Commit 5cbf81c2 authored by Yuxin Wu

use HookedSession to do inference. fix #161

parent d3802e79
...@@ -91,8 +91,6 @@ class ScalarStats(Inferencer): ...@@ -91,8 +91,6 @@ class ScalarStats(Inferencer):
self.stats = [] self.stats = []
def _datapoint(self, output): def _datapoint(self, output):
for o in output:
assert isinstance(o, (float, np.float32)), type(o)
self.stats.append(output) self.stats.append(output)
def _after_inference(self): def _after_inference(self):
......
...@@ -4,7 +4,10 @@ ...@@ -4,7 +4,10 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf import tensorflow as tf
from collections import namedtuple from tensorflow.python.training.monitored_session \
import _HookedSession as HookedSession
from abc import ABCMeta, abstractmethod
import tqdm import tqdm
import six import six
import copy import copy
...@@ -14,7 +17,7 @@ from ..utils import logger, get_tqdm_kwargs ...@@ -14,7 +17,7 @@ from ..utils import logger, get_tqdm_kwargs
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
from ..train.input_data import TensorInput, FeedInput from ..train.input_data import TensorInput, FeedInput
from ..predict import PredictorTowerBuilder, OnlinePredictor from ..predict import PredictorTowerBuilder
from .base import Triggerable from .base import Triggerable
from .inference import Inferencer from .inference import Inferencer
...@@ -22,36 +25,16 @@ from .inference import Inferencer ...@@ -22,36 +25,16 @@ from .inference import Inferencer
__all__ = ['InferenceRunner', 'FeedfreeInferenceRunner'] __all__ = ['InferenceRunner', 'FeedfreeInferenceRunner']
class OutputTensorDispatcher(object): class InferencerToHook(tf.train.SessionRunHook):
def __init__(self): def __init__(self, inf, fetches):
self._names = [] self._inf = inf
self._idxs = [] self._fetches = fetches
# each element in idxs is a list
# len(idxs) == len(inferencer)
# the list contains the indices into names
def add_entry(self, names):
v = []
for n in names:
tensorname = get_op_tensor_name(n)[1]
if tensorname in self._names:
v.append(self._names.index(tensorname))
else:
self._names.append(tensorname)
v.append(len(self._names) - 1)
self._idxs.append(v)
def get_all_names(self):
return self._names
def get_idx_for_each_entry(self): def before_run(self, _):
return self._idxs return tf.train.SessionRunArgs(fetches=self._fetches)
def get_names_for_each_entry(self): def after_run(self, _, run_values):
ret = [] self._inf.datapoint(run_values.results)
for t in self._idxs:
ret.append([self._names[k] for k in t])
return ret
def summary_inferencer(trainer, infs): def summary_inferencer(trainer, infs):
...@@ -68,33 +51,34 @@ def summary_inferencer(trainer, infs): ...@@ -68,33 +51,34 @@ def summary_inferencer(trainer, infs):
trainer.monitors.put(k, v) trainer.monitors.put(k, v)
class InferenceRunner(Triggerable): @six.add_metaclass(ABCMeta)
""" class InferenceRunnerBase(Triggerable):
A callback that runs a list of :class:`Inferencer` on some """ Base methods for inference runner"""
:class:`DataFlow`. def __init__(self, input, infs, input_names=None, prefix=''):
"""
_IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])
def __init__(self, ds, infs, input_names=None):
""" """
Args: Args:
ds (DataFlow): the DataFlow to run inferencer on. input (InputData): the input to use. Must have ``size()``.
infs (list): a list of `Inferencer` instances. infs (list): list of :class:`Inferencer` to run.
input_names(list): list of tensors to feed the dataflow to. input_names (list): must be a subset of the names in InputDesc.
Defaults to all the input placeholders. prefix (str): a prefix used to build the tower. Must be set
differently if more than one :class:`InferenceRunner` is used.
""" """
if isinstance(ds, DataFlow): self._input_data = input
self._input_data = FeedInput(ds)
assert isinstance(self._input_data, FeedInput), self._input_data
if not isinstance(infs, list): if not isinstance(infs, list):
self.infs = [infs] self.infs = [infs]
else: else:
self.infs = infs self.infs = infs
for v in self.infs: for v in self.infs:
assert isinstance(v, Inferencer), v assert isinstance(v, Inferencer), v
self.input_names = input_names # names actually if input_names is not None:
self._prefix = '' assert isinstance(input_names, list)
self.input_names = input_names
try:
self._size = input.size()
except NotImplementedError:
raise ValueError("Input used in InferenceRunner must have a size!")
self._prefix = prefix
def _setup_input_names(self): def _setup_input_names(self):
# just use all the placeholders, if input_name is None # just use all the placeholders, if input_name is None
...@@ -110,34 +94,9 @@ class InferenceRunner(Triggerable): ...@@ -110,34 +94,9 @@ class InferenceRunner(Triggerable):
# return x.op.name.split('/')[0] # return x.op.name.split('/')[0]
# return x.name # return x.name
def _setup_output_names(self):
dispatcher = OutputTensorDispatcher()
for inf in self.infs:
dispatcher.add_entry(inf.get_output_tensors())
all_names = dispatcher.get_all_names()
# output names can be input placeholders, use IOTensor
self.output_names = list(filter(
lambda x: x not in self.input_names, all_names))
IOTensor = InferenceRunner._IOTensor
def find_tensors(names):
ret = []
for name in names:
if name in self.input_names:
ret.append(IOTensor(self.input_names.index(name), False))
else:
ret.append(IOTensor(self.output_names.index(name), True))
return ret
self.inf_to_tensors = [find_tensors(t) for t in dispatcher.get_names_for_each_entry()]
# list of list of IOTensor
def _setup_graph(self): def _setup_graph(self):
self._input_data.setup(self.trainer.model) self._input_data.setup(self.trainer.model)
self._setup_input_names() self._setup_input_names()
# set self.output_names from inferencers, as well as the name dispatcher
self._setup_output_names()
in_tensors = self._find_input_tensors() in_tensors = self._find_input_tensors()
with tf.variable_scope(tf.get_variable_scope(), reuse=True): with tf.variable_scope(tf.get_variable_scope(), reuse=True):
...@@ -145,42 +104,73 @@ class InferenceRunner(Triggerable): ...@@ -145,42 +104,73 @@ class InferenceRunner(Triggerable):
self.trainer.model.build_graph(in_tensors) self.trainer.model.build_graph(in_tensors)
PredictorTowerBuilder(fn, self._prefix).build(0) PredictorTowerBuilder(fn, self._prefix).build(0)
feed_tensors = self._find_feed_tensors() self._feed_tensors = self._find_feed_tensors()
out_tensors = self._find_output_tensors() self._hooks = [self._build_hook(inf) for inf in self.infs]
self.predictor = OnlinePredictor(feed_tensors, out_tensors)
def _find_input_tensors(self): def _before_train(self):
return self.trainer.model.get_reused_placehdrs() self._hooked_sess = HookedSession(self.trainer.sess, self._hooks)
def _find_feed_tensors(self): def _get_tensors_maybe_in_tower(self, names):
placeholder_names = set([k.name for k in self.trainer.model.get_inputs_desc()]) placeholder_names = set([k.name for k in self.trainer.model.get_inputs_desc()])
get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
return get_tensor_fn(placeholder_names, self.input_names, 0, prefix=self._prefix) return get_tensor_fn(placeholder_names, names, 0, prefix=self._prefix)
def _find_output_tensors(self): @abstractmethod
placeholder_names = set([k.name for k in self.trainer.model.get_inputs_desc()]) def _find_input_tensors(self):
get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower pass
return get_tensor_fn(placeholder_names, self.output_names, 0, prefix=self._prefix)
@abstractmethod
def _find_feed_tensors(self):
pass
@abstractmethod
def _build_hook(self, inf):
pass
def _trigger(self): def _trigger(self):
for inf in self.infs: for inf in self.infs:
inf.before_inference() inf.before_inference()
# iterate over the data, and run the hooked session
self._input_data.reset_state() self._input_data.reset_state()
for _ in tqdm.trange(self._input_data.size(), **get_tqdm_kwargs()): for _ in tqdm.trange(self._input_data.size(), **get_tqdm_kwargs()):
dp = self._input_data.next_feed() dp = self._input_data.next_feed()
outputs = self.predictor(dp) feed = dict(zip(self._feed_tensors, dp))
for inf, tensormap in zip(self.infs, self.inf_to_tensors): self._hooked_sess.run(fetches=[], feed_dict=feed)
inf_output = [(outputs if k.isOutput else dp)[k.index]
for k in tensormap]
inf.datapoint(inf_output)
self._write_summary_after_inference()
def _write_summary_after_inference(self):
summary_inferencer(self.trainer, self.infs) summary_inferencer(self.trainer, self.infs)
class FeedfreeInferenceRunner(InferenceRunner): class InferenceRunner(InferenceRunnerBase):
"""
A callback that runs a list of :class:`Inferencer` on some
:class:`DataFlow`.
"""
def __init__(self, ds, infs, input_names=None):
"""
Args:
ds (DataFlow): the DataFlow to run inferencer on.
infs (list): a list of `Inferencer` instances.
input_names(list): list of tensors to feed the dataflow to.
Defaults to all the input placeholders.
"""
assert isinstance(ds, DataFlow), ds
input = FeedInput(ds)
super(InferenceRunner, self).__init__(input, infs, input_names, prefix='')
def _find_input_tensors(self):
return self.trainer.model.get_reused_placehdrs()
def _find_feed_tensors(self):
return self._get_tensors_maybe_in_tower(self.input_names)
def _build_hook(self, inf):
out_names = inf.get_output_tensors()
fetches = self._get_tensors_maybe_in_tower(out_names)
return InferencerToHook(inf, fetches)
class FeedfreeInferenceRunner(InferenceRunnerBase):
""" A callback that runs a list of :class:`Inferencer` on some """ A callback that runs a list of :class:`Inferencer` on some
:class:`TensorInput`, such as some tensor from a TensorFlow data reading :class:`TensorInput`, such as some tensor from a TensorFlow data reading
pipeline. pipeline.
...@@ -196,22 +186,7 @@ class FeedfreeInferenceRunner(InferenceRunner): ...@@ -196,22 +186,7 @@ class FeedfreeInferenceRunner(InferenceRunner):
differently if more than one :class:`FeedfreeInferenceRunner` is used. differently if more than one :class:`FeedfreeInferenceRunner` is used.
""" """
assert isinstance(input, TensorInput), input assert isinstance(input, TensorInput), input
self._input_data = input super(FeedfreeInferenceRunner, self).__init__(input, infs, input_names, prefix)
if not isinstance(infs, list):
self.infs = [infs]
else:
self.infs = infs
for v in self.infs:
assert isinstance(v, Inferencer), v
if input_names is not None:
assert isinstance(input_names, list)
self.input_names = input_names
try:
self._size = input.size()
except NotImplementedError:
raise ValueError("Input used in FeedfreeInferencecRunner must have a size!")
self._prefix = prefix
def _setup_input_names(self): def _setup_input_names(self):
super(FeedfreeInferenceRunner, self)._setup_input_names() super(FeedfreeInferenceRunner, self)._setup_input_names()
...@@ -221,22 +196,6 @@ class FeedfreeInferenceRunner(InferenceRunner): ...@@ -221,22 +196,6 @@ class FeedfreeInferenceRunner(InferenceRunner):
assert opname in placeholder_names, \ assert opname in placeholder_names, \
"[FeedfreeInferenceRunner] name {} is not a model input!".format(n) "[FeedfreeInferenceRunner] name {} is not a model input!".format(n)
def _setup_output_names(self):
dispatcher = OutputTensorDispatcher()
for inf in self.infs:
dispatcher.add_entry(inf.get_output_tensors())
self.output_names = dispatcher.get_all_names()
# TODO check names. doesn't support output an input tensor (but can support)
IOTensor = InferenceRunner._IOTensor
def find_tensors(names):
return [IOTensor(self.output_names.index(n), True) for n in names]
self.inf_to_tensors = [find_tensors(t) for t in dispatcher.get_names_for_each_entry()]
def _find_feed_tensors(self):
return []
def _find_input_tensors(self): def _find_input_tensors(self):
tensors = self._input_data.get_input_tensors() tensors = self._input_data.get_input_tensors()
...@@ -252,8 +211,22 @@ class FeedfreeInferenceRunner(InferenceRunner): ...@@ -252,8 +211,22 @@ class FeedfreeInferenceRunner(InferenceRunner):
ret[idx] = tensor ret[idx] = tensor
break break
else: else:
assert tname in set([k.name for k in ret]), tname assert tname in set([k.name for k in ret]), \
"Input name {} is not among model inputs: {}!".format(tname, ret)
self._input_tensors = ret
return ret return ret
def _write_summary_after_inference(self): def _find_feed_tensors(self):
summary_inferencer(self.trainer, self.infs) return []
def _build_hook(self, inf):
out_names = inf.get_output_tensors() # all of these are tensor names
placeholder_names = [k.name + ':0' for k in self.trainer.model.get_inputs_desc()]
ret = []
for name in out_names:
if name not in placeholder_names:
ret.append(self._get_tensors_maybe_in_tower([name])[0])
else: # requesting an input
idx = placeholder_names.index(name)
ret.append(self._input_tensors[idx])
return InferencerToHook(inf, ret)
...@@ -194,6 +194,10 @@ class PredictorTowerBuilder(object): ...@@ -194,6 +194,10 @@ class PredictorTowerBuilder(object):
@staticmethod @staticmethod
def get_tensors_maybe_in_tower(placeholder_names, names, k, prefix=''): def get_tensors_maybe_in_tower(placeholder_names, names, k, prefix=''):
"""
Args:
placeholder_names (list): A list of placeholder __op__ names.
"""
def maybe_inside_tower(name): def maybe_inside_tower(name):
name = get_op_tensor_name(name)[0] name = get_op_tensor_name(name)[0]
if name in placeholder_names: if name in placeholder_names:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment