Commit 49cdbf4f authored by Yuxin Wu's avatar Yuxin Wu

feedfreeinputinferencer

parent 84790b78
......@@ -10,7 +10,7 @@ from six.moves import zip
from ..utils import logger
from ..utils.stats import RatioCounter, BinaryStatistics
from ..tfutils import get_op_var_name
from ..tfutils import get_op_tensor_name
__all__ = ['ScalarStats', 'Inferencer',
'ClassificationError', 'BinaryClassificationStats']
......@@ -56,9 +56,10 @@ class Inferencer(object):
def get_output_tensors(self):
"""
Return a list of tensor names this inferencer needed.
Return a list of tensor names (guaranteed not to be op names) this inferencer needs.
"""
return self._get_output_tensors()
ret = self._get_output_tensors()
return [get_op_tensor_name(n)[1] for n in ret]
@abstractmethod
def _get_output_tensors(self):
......@@ -101,7 +102,7 @@ class ScalarStats(Inferencer):
ret = {}
for stat, name in zip(self.stats, self.names):
opname, _ = get_op_var_name(name)
opname, _ = get_op_tensor_name(name)
name = '{}_{}'.format(self.prefix, opname) if self.prefix else opname
ret[name] = stat
return ret
......@@ -129,11 +130,11 @@ class ClassificationError(Inferencer):
:meth:`prediction_incorrect`.
summary_name(str): the name for logging.
"""
self.wrong_var_name = wrong_tensor_name
self.wrong_tensor_name = wrong_tensor_name
self.summary_name = summary_name
def _get_output_tensors(self):
return [self.wrong_var_name]
return [self.wrong_tensor_name]
def _before_inference(self):
self.err_stat = RatioCounter()
......@@ -145,7 +146,7 @@ class ClassificationError(Inferencer):
sys.exit(1)
else:
# TODO put shape assertion into inferencerrunner
assert vec.ndim == 1, "{} is not a vector!".format(self.wrong_var_name)
assert vec.ndim == 1, "{} is not a vector!".format(self.wrong_tensor_name)
batch_size = len(vec)
wrong = np.sum(vec)
self.err_stat.feed(wrong, batch_size)
......
......@@ -9,15 +9,15 @@ import six
from six.moves import zip, range
from ..dataflow import DataFlow
from ..utils import logger, get_tqdm, PREDICT_TOWER
from ..tfutils.common import get_op_tensor_name
from ..utils import logger, get_tqdm, PREDICT_TOWER, SUMMARY_BACKUP_KEYS
from ..tfutils.common import get_op_tensor_name, freeze_collection
from ..train.input_data import FeedfreeInput
from ..predict import build_prediction_graph
from .base import Callback
from .inference import Inferencer
__all__ = ['InferenceRunner']
__all__ = ['InferenceRunner', 'FeedfreeInferenceRunner']
class OutputTensorDispatcer(object):
......@@ -42,6 +42,12 @@ class OutputTensorDispatcer(object):
def get_idx_for_each_entry(self):
return self._idxs
def get_names_for_each_entry(self):
ret = []
for t in self._idxs:
ret.append([self._names[k] for k in t])
return ret
def summary_inferencer(trainer, infs):
for inf in infs:
......@@ -109,17 +115,16 @@ class InferenceRunner(Callback):
self.output_tensors = list(filter(
lambda x: x not in self.input_tensors, all_names))
def find_oid(idxs):
def find_tensors(names):
ret = []
for idx in idxs:
name = all_names[idx]
for name in names:
if name in self.input_tensors:
ret.append(IOTensor(self.input_tensors.index(name), False))
else:
ret.append(IOTensor(self.output_tensors.index(name), True))
return ret
self.inf_to_tensors = [find_oid(t) for t in dispatcer.get_idx_for_each_entry()]
# list of list of (var_name: IOTensor)
self.inf_to_tensors = [find_tensors(t) for t in dispatcer.get_names_for_each_entry()]
# list of list of IOTensor
def _trigger_epoch(self):
for inf in self.infs:
......@@ -141,11 +146,16 @@ class InferenceRunner(Callback):
class FeedfreeInferenceRunner(Callback):
IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])
""" A callback that runs a list of :class:`Inferencer` on some
:class:`FeedfreeInput`, such as some tensor from a TensorFlow data reading
pipeline.
"""
def __init__(self, input, infs, input_names=None):
"""
Args:
input (FeedfreeInput): the input to use. Must have ``size()``.
infs (list): list of :class:`Inferencer` to run.
input_names (list): must be a subset of the names of InputVar.
"""
assert isinstance(input, FeedfreeInput), input
......@@ -168,9 +178,14 @@ class FeedfreeInferenceRunner(Callback):
def _setup_graph(self):
self._find_input_tensors() # tensors
def fn(_):
self.trainer.model.build_graph(self._input_tensors)
build_prediction_graph(fn, [0])
tf.get_variable_scope().reuse_variables()
# overwrite the FeedfreeInferenceRunner scope
with tf.name_scope(None), \
freeze_collection(SUMMARY_BACKUP_KEYS):
def fn(_):
self.trainer.model.build_graph(self._input_tensors)
build_prediction_graph(fn, [0])
self._tower_prefix = PREDICT_TOWER + '0'
self._find_output_tensors()
......@@ -180,9 +195,9 @@ class FeedfreeInferenceRunner(Callback):
# only 1 prediction tower will be used for inference
self._input_tensors = self._input_data.get_input_tensors()
model_placehdrs = self.trainer.model.get_reuse_placehdrs()
if self.input_names is not None:
if self._input_names is not None:
raise NotImplementedError("Random code. Not tested.")
assert len(self.input_names) == len(self._input_tensors), \
assert len(self._input_names) == len(self._input_tensors), \
"[FeedfreeInferenceRunner] input_names must have the same length as the input data."
for n, tensor in zip(self.input_names, self._input_tensors):
opname, _ = get_op_tensor_name(n)
......@@ -200,36 +215,28 @@ class FeedfreeInferenceRunner(Callback):
def _find_output_tensors(self):
# TODO doesn't support output an input tensor
# TODO find tensors, not names
dispatcer = OutputTensorDispatcer()
for inf in self.infs:
dispatcer.add_entry(inf.get_output_tensors())
all_names = dispatcer.get_all_names()
G = tf.get_default_graph()
self._output_tensors = [G.get_tensor_by_name(
self._tower_prefix + '/' + n) for n in all_names]
self._sess = self.trainer.sess
IOTensor = FeedfreeInferenceRunner.IOTensor
self.output_tensors = all_names
def find_oid(idxs):
ret = []
for idx in idxs:
name = all_names[idx]
ret.append(IOTensor(self.output_tensors.index(name), True))
return ret
self.inf_to_tensors = [find_oid(t) for t in dispatcer.get_idx_for_each_entry()]
# list of list of (var_name: IOTensor)
# list of list of id
self.inf_to_idxs = dispatcer.get_idx_for_each_entry()
def _trigger_epoch(self):
for inf in self.infs:
inf.before_inference()
sz = self._input_data.size()
with get_tqdm(total=sz) as pbar:
for _ in range(sz):
# outputs = self.pred_func(dp)
# for inf, tensormap in zip(self.infs, self.inf_to_tensors):
# inf_output = [(outputs if k.isOutput else dp)[k.index]
# for k in tensormap]
# inf.datapoint(inf_output)
with get_tqdm(total=self._size) as pbar:
for _ in range(self._size):
outputs = self._sess.run(fetches=self._output_tensors)
for inf, idlist in zip(self.infs, self.inf_to_idxs):
inf_output = [outputs[k] for k in idlist]
inf.datapoint(inf_output)
pbar.update()
self._write_summary_after_inference()
......
......@@ -72,14 +72,16 @@ def get_global_step():
def get_op_tensor_name(name):
"""
Tensor name is assumed to be ``op_name + ':0'``
Will automatically determine if ``name`` is a tensor name (ends with ':x')
or an op name.
If it is an op name, the corresponding tensor name is assumed to be ``op_name + ':0'``.
Args:
name(str): name of an op or a tensor
Returns:
tuple: (op_name, tensor_name)
"""
if name.endswith(':0'):
if name[-2] == ':':
return name[:-2], name
else:
return name, name + ':0'
......
......@@ -75,7 +75,7 @@ def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss
beta = count_neg / (count_neg + count_pos)
pos_weight = beta / (1 - beta)
cost = tf.nn.weighted_cross_entropy_with_logits(logits, y, pos_weight)
cost = tf.nn.weighted_cross_entropy_with_logits(logits=logits, targets=y, pos_weight=pos_weight)
cost = tf.reduce_mean(cost * (1 - beta), name=name)
return cost
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment