Commit 49cdbf4f authored by Yuxin Wu's avatar Yuxin Wu

feedfreeinputinferencer

parent 84790b78
...@@ -10,7 +10,7 @@ from six.moves import zip ...@@ -10,7 +10,7 @@ from six.moves import zip
from ..utils import logger from ..utils import logger
from ..utils.stats import RatioCounter, BinaryStatistics from ..utils.stats import RatioCounter, BinaryStatistics
from ..tfutils import get_op_var_name from ..tfutils import get_op_tensor_name
__all__ = ['ScalarStats', 'Inferencer', __all__ = ['ScalarStats', 'Inferencer',
'ClassificationError', 'BinaryClassificationStats'] 'ClassificationError', 'BinaryClassificationStats']
...@@ -56,9 +56,10 @@ class Inferencer(object): ...@@ -56,9 +56,10 @@ class Inferencer(object):
def get_output_tensors(self): def get_output_tensors(self):
""" """
Return a list of tensor names this inferencer needed. Return a list of tensor names (guranteed not op name) this inferencer needs.
""" """
return self._get_output_tensors() ret = self._get_output_tensors()
return [get_op_tensor_name(n)[1] for n in ret]
@abstractmethod @abstractmethod
def _get_output_tensors(self): def _get_output_tensors(self):
...@@ -101,7 +102,7 @@ class ScalarStats(Inferencer): ...@@ -101,7 +102,7 @@ class ScalarStats(Inferencer):
ret = {} ret = {}
for stat, name in zip(self.stats, self.names): for stat, name in zip(self.stats, self.names):
opname, _ = get_op_var_name(name) opname, _ = get_op_tensor_name(name)
name = '{}_{}'.format(self.prefix, opname) if self.prefix else opname name = '{}_{}'.format(self.prefix, opname) if self.prefix else opname
ret[name] = stat ret[name] = stat
return ret return ret
...@@ -129,11 +130,11 @@ class ClassificationError(Inferencer): ...@@ -129,11 +130,11 @@ class ClassificationError(Inferencer):
:meth:`prediction_incorrect`. :meth:`prediction_incorrect`.
summary_name(str): the name for logging. summary_name(str): the name for logging.
""" """
self.wrong_var_name = wrong_tensor_name self.wrong_tensor_name = wrong_tensor_name
self.summary_name = summary_name self.summary_name = summary_name
def _get_output_tensors(self): def _get_output_tensors(self):
return [self.wrong_var_name] return [self.wrong_tensor_name]
def _before_inference(self): def _before_inference(self):
self.err_stat = RatioCounter() self.err_stat = RatioCounter()
...@@ -145,7 +146,7 @@ class ClassificationError(Inferencer): ...@@ -145,7 +146,7 @@ class ClassificationError(Inferencer):
sys.exit(1) sys.exit(1)
else: else:
# TODO put shape assertion into inferencerrunner # TODO put shape assertion into inferencerrunner
assert vec.ndim == 1, "{} is not a vector!".format(self.wrong_var_name) assert vec.ndim == 1, "{} is not a vector!".format(self.wrong_tensor_name)
batch_size = len(vec) batch_size = len(vec)
wrong = np.sum(vec) wrong = np.sum(vec)
self.err_stat.feed(wrong, batch_size) self.err_stat.feed(wrong, batch_size)
......
...@@ -9,15 +9,15 @@ import six ...@@ -9,15 +9,15 @@ import six
from six.moves import zip, range from six.moves import zip, range
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..utils import logger, get_tqdm, PREDICT_TOWER from ..utils import logger, get_tqdm, PREDICT_TOWER, SUMMARY_BACKUP_KEYS
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name, freeze_collection
from ..train.input_data import FeedfreeInput from ..train.input_data import FeedfreeInput
from ..predict import build_prediction_graph from ..predict import build_prediction_graph
from .base import Callback from .base import Callback
from .inference import Inferencer from .inference import Inferencer
__all__ = ['InferenceRunner'] __all__ = ['InferenceRunner', 'FeedfreeInferenceRunner']
class OutputTensorDispatcer(object): class OutputTensorDispatcer(object):
...@@ -42,6 +42,12 @@ class OutputTensorDispatcer(object): ...@@ -42,6 +42,12 @@ class OutputTensorDispatcer(object):
def get_idx_for_each_entry(self): def get_idx_for_each_entry(self):
return self._idxs return self._idxs
def get_names_for_each_entry(self):
ret = []
for t in self._idxs:
ret.append([self._names[k] for k in t])
return ret
def summary_inferencer(trainer, infs): def summary_inferencer(trainer, infs):
for inf in infs: for inf in infs:
...@@ -109,17 +115,16 @@ class InferenceRunner(Callback): ...@@ -109,17 +115,16 @@ class InferenceRunner(Callback):
self.output_tensors = list(filter( self.output_tensors = list(filter(
lambda x: x not in self.input_tensors, all_names)) lambda x: x not in self.input_tensors, all_names))
def find_oid(idxs): def find_tensors(names):
ret = [] ret = []
for idx in idxs: for name in names:
name = all_names[idx]
if name in self.input_tensors: if name in self.input_tensors:
ret.append(IOTensor(self.input_tensors.index(name), False)) ret.append(IOTensor(self.input_tensors.index(name), False))
else: else:
ret.append(IOTensor(self.output_tensors.index(name), True)) ret.append(IOTensor(self.output_tensors.index(name), True))
return ret return ret
self.inf_to_tensors = [find_oid(t) for t in dispatcer.get_idx_for_each_entry()] self.inf_to_tensors = [find_tensors(t) for t in dispatcer.get_names_for_each_entry()]
# list of list of (var_name: IOTensor) # list of list of IOTensor
def _trigger_epoch(self): def _trigger_epoch(self):
for inf in self.infs: for inf in self.infs:
...@@ -141,11 +146,16 @@ class InferenceRunner(Callback): ...@@ -141,11 +146,16 @@ class InferenceRunner(Callback):
class FeedfreeInferenceRunner(Callback): class FeedfreeInferenceRunner(Callback):
IOTensor = namedtuple('IOTensor', ['index', 'isOutput']) """ A callback that runs a list of :class:`Inferencer` on some
:class:`FeedfreeInput`, such as some tensor from a TensorFlow data reading
pipeline.
"""
def __init__(self, input, infs, input_names=None): def __init__(self, input, infs, input_names=None):
""" """
Args: Args:
input (FeedfreeInput): the input to use. Must have ``size()``.
infs (list): list of :class:`Inferencer` to run.
input_names (list): must be a subset of the names of InputVar. input_names (list): must be a subset of the names of InputVar.
""" """
assert isinstance(input, FeedfreeInput), input assert isinstance(input, FeedfreeInput), input
...@@ -168,6 +178,11 @@ class FeedfreeInferenceRunner(Callback): ...@@ -168,6 +178,11 @@ class FeedfreeInferenceRunner(Callback):
def _setup_graph(self): def _setup_graph(self):
self._find_input_tensors() # tensors self._find_input_tensors() # tensors
tf.get_variable_scope().reuse_variables()
# overwrite the FeedfreeInferenceRunner scope
with tf.name_scope(None), \
freeze_collection(SUMMARY_BACKUP_KEYS):
def fn(_): def fn(_):
self.trainer.model.build_graph(self._input_tensors) self.trainer.model.build_graph(self._input_tensors)
build_prediction_graph(fn, [0]) build_prediction_graph(fn, [0])
...@@ -180,9 +195,9 @@ class FeedfreeInferenceRunner(Callback): ...@@ -180,9 +195,9 @@ class FeedfreeInferenceRunner(Callback):
# only 1 prediction tower will be used for inference # only 1 prediction tower will be used for inference
self._input_tensors = self._input_data.get_input_tensors() self._input_tensors = self._input_data.get_input_tensors()
model_placehdrs = self.trainer.model.get_reuse_placehdrs() model_placehdrs = self.trainer.model.get_reuse_placehdrs()
if self.input_names is not None: if self._input_names is not None:
raise NotImplementedError("Random code. Not tested.") raise NotImplementedError("Random code. Not tested.")
assert len(self.input_names) == len(self._input_tensors), \ assert len(self._input_names) == len(self._input_tensors), \
"[FeedfreeInferenceRunner] input_names must have the same length as the input data." "[FeedfreeInferenceRunner] input_names must have the same length as the input data."
for n, tensor in zip(self.input_names, self._input_tensors): for n, tensor in zip(self.input_names, self._input_tensors):
opname, _ = get_op_tensor_name(n) opname, _ = get_op_tensor_name(n)
...@@ -200,36 +215,28 @@ class FeedfreeInferenceRunner(Callback): ...@@ -200,36 +215,28 @@ class FeedfreeInferenceRunner(Callback):
def _find_output_tensors(self): def _find_output_tensors(self):
# TODO doesn't support output an input tensor # TODO doesn't support output an input tensor
# TODO find tensors, not names
dispatcer = OutputTensorDispatcer() dispatcer = OutputTensorDispatcer()
for inf in self.infs: for inf in self.infs:
dispatcer.add_entry(inf.get_output_tensors()) dispatcer.add_entry(inf.get_output_tensors())
all_names = dispatcer.get_all_names() all_names = dispatcer.get_all_names()
G = tf.get_default_graph()
self._output_tensors = [G.get_tensor_by_name(
self._tower_prefix + '/' + n) for n in all_names]
self._sess = self.trainer.sess
IOTensor = FeedfreeInferenceRunner.IOTensor # list of list of id
self.output_tensors = all_names self.inf_to_idxs = dispatcer.get_idx_for_each_entry()
def find_oid(idxs):
ret = []
for idx in idxs:
name = all_names[idx]
ret.append(IOTensor(self.output_tensors.index(name), True))
return ret
self.inf_to_tensors = [find_oid(t) for t in dispatcer.get_idx_for_each_entry()]
# list of list of (var_name: IOTensor)
def _trigger_epoch(self): def _trigger_epoch(self):
for inf in self.infs: for inf in self.infs:
inf.before_inference() inf.before_inference()
sz = self._input_data.size() with get_tqdm(total=self._size) as pbar:
with get_tqdm(total=sz) as pbar: for _ in range(self._size):
for _ in range(sz): outputs = self._sess.run(fetches=self._output_tensors)
# outputs = self.pred_func(dp) for inf, idlist in zip(self.infs, self.inf_to_idxs):
# for inf, tensormap in zip(self.infs, self.inf_to_tensors): inf_output = [outputs[k] for k in idlist]
# inf_output = [(outputs if k.isOutput else dp)[k.index] inf.datapoint(inf_output)
# for k in tensormap]
# inf.datapoint(inf_output)
pbar.update() pbar.update()
self._write_summary_after_inference() self._write_summary_after_inference()
......
...@@ -72,14 +72,16 @@ def get_global_step(): ...@@ -72,14 +72,16 @@ def get_global_step():
def get_op_tensor_name(name): def get_op_tensor_name(name):
""" """
Tensor name is assumed to be ``op_name + ':0'`` Will automatically determine if ``name`` is a tensor name (ends with ':x')
or a op name.
If it is an op name, the corresponding tensor name is assumed to be ``op_name + ':0'``.
Args: Args:
name(str): name of an op or a tensor name(str): name of an op or a tensor
Returns: Returns:
tuple: (op_name, tensor_name) tuple: (op_name, tensor_name)
""" """
if name.endswith(':0'): if name[-2] == ':':
return name[:-2], name return name[:-2], name
else: else:
return name, name + ':0' return name, name + ':0'
......
...@@ -75,7 +75,7 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss ...@@ -75,7 +75,7 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss
beta = count_neg / (count_neg + count_pos) beta = count_neg / (count_neg + count_pos)
pos_weight = beta / (1 - beta) pos_weight = beta / (1 - beta)
cost = tf.nn.weighted_cross_entropy_with_logits(logits, y, pos_weight) cost = tf.nn.weighted_cross_entropy_with_logits(logits=logits, targets=y, pos_weight=pos_weight)
cost = tf.reduce_mean(cost * (1 - beta), name=name) cost = tf.reduce_mean(cost * (1 - beta), name=name)
return cost return cost
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment