Commit 7db397d7 authored by Yuxin Wu's avatar Yuxin Wu

refactor inference runners

parent 3bc0bed2
......@@ -8,14 +8,13 @@ from collections import namedtuple
import tqdm
import six
import copy
from six.moves import zip, range
from six.moves import zip
from ..utils import logger, get_tqdm_kwargs, get_tqdm
from ..utils import logger, get_tqdm_kwargs
from ..dataflow import DataFlow
from ..tfutils.common import get_op_tensor_name
from ..tfutils import TowerContext
from ..train.input_data import TensorInput, FeedInput
from ..predict import PredictorTowerBuilder
from ..predict import PredictorTowerBuilder, OnlinePredictor
from .base import Triggerable
from .inference import Inferencer
......@@ -27,6 +26,9 @@ class OutputTensorDispatcher(object):
def __init__(self):
    """Track which output tensor names each registered inferencer asked for."""
    # flat list of tensor names registered via add_entry
    self._names = []
    # one element per inferencer (len(self._idxs) == len(inferencers));
    # each element is a list of indices into self._names
    self._idxs = []
def add_entry(self, names):
v = []
......@@ -72,52 +74,50 @@ class InferenceRunner(Triggerable):
_IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])
def __init__(self, ds, infs, input_tensor_names=None):
def __init__(self, ds, infs, input_names=None):
"""
Args:
ds (DataFlow): the DataFlow to run inferencer on.
infs (list): a list of `Inferencer` instances.
input_tensor_names(list): list of tensors to feed the dataflow to.
input_names(list): list of tensors to feed the dataflow to.
Defaults to all the input placeholders.
"""
if isinstance(ds, DataFlow):
self.ds = FeedInput(ds)
assert isinstance(self.ds, FeedInput), self.ds
self._input_data = FeedInput(ds)
assert isinstance(self._input_data, FeedInput), self._input_data
if not isinstance(infs, list):
self.infs = [infs]
else:
self.infs = infs
for v in self.infs:
assert isinstance(v, Inferencer), v
self.input_names = input_tensor_names # names actually
self.input_names = input_names # names actually
self._prefix = ''
def _setup_graph(self):
self._find_input_tensors() # these are all tensor names
self._find_output_tensors() # may be either tensor name or op name
self.predictor = self.trainer.get_predictor(
self.input_names, self.output_names)
def _find_input_tensors(self):
def _setup_input_names(self):
# just use all the placeholders, if input_name is None
if self.input_names is None:
input_vars = self.trainer.model.get_reused_placehdrs()
# TODO even if it works here, sparse still is unavailable
inputs = self.trainer.model.get_reused_placehdrs()
self.input_names = [x.name for x in inputs]
# TODO sparse. even if it works here, sparse still is unavailable
# because get_tensor_by_name doesn't work for sparse
def get_name(x):
if isinstance(x, tf.SparseTensor):
return x.op.name.split('/')[0]
return x.name
self.input_names = [get_name(x) for x in input_vars]
# def get_name(x):
# if isinstance(x, tf.SparseTensor):
# return x.op.name.split('/')[0]
# return x.name
def _find_output_tensors(self):
def _setup_output_names(self):
dispatcher = OutputTensorDispatcher()
for inf in self.infs:
dispatcher.add_entry(inf.get_output_tensors())
all_names = dispatcher.get_all_names()
IOTensor = InferenceRunner._IOTensor
# output names can be input placeholders, use IOTensor
self.output_names = list(filter(
lambda x: x not in self.input_names, all_names))
IOTensor = InferenceRunner._IOTensor
def find_tensors(names):
ret = []
......@@ -130,13 +130,43 @@ class InferenceRunner(Triggerable):
self.inf_to_tensors = [find_tensors(t) for t in dispatcher.get_names_for_each_entry()]
# list of list of IOTensor
def _setup_graph(self):
    """Build one prediction tower and the OnlinePredictor used at trigger time.

    Order matters: the input data and names must be resolved before the
    tower is built, and feed/output tensors can only be located afterwards.
    """
    self._input_data.setup(self.trainer.model)
    self._setup_input_names()
    # set self.output_names from inferencers, as well as the name dispatcher
    self._setup_output_names()
    input_tensors = self._find_input_tensors()

    # reuse the training variables inside the prediction tower
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        def build_tower(_):
            self.trainer.model.build_graph(input_tensors)
        PredictorTowerBuilder(build_tower, self._prefix).build(0)

    self.predictor = OnlinePredictor(
        self._find_feed_tensors(), self._find_output_tensors())
def _find_input_tensors(self):
    """Return the model's reused input placeholders to build the tower with."""
    model = self.trainer.model
    return model.get_reused_placehdrs()
def _find_feed_tensors(self):
    """Resolve self.input_names to feedable tensors, possibly inside the tower."""
    placeholder_names = {k.name for k in self.trainer.model.get_inputs_desc()}
    return PredictorTowerBuilder.get_tensors_maybe_in_tower(
        placeholder_names, self.input_names, 0, prefix=self._prefix)
def _find_output_tensors(self):
    """Resolve self.output_names to tensors, possibly inside the tower."""
    placeholder_names = {k.name for k in self.trainer.model.get_inputs_desc()}
    return PredictorTowerBuilder.get_tensors_maybe_in_tower(
        placeholder_names, self.output_names, 0, prefix=self._prefix)
def _trigger(self):
for inf in self.infs:
inf.before_inference()
self.ds.reset_state()
for _ in tqdm.trange(self.ds.size(), **get_tqdm_kwargs()):
dp = self.ds.next_feed()
self._input_data.reset_state()
for _ in tqdm.trange(self._input_data.size(), **get_tqdm_kwargs()):
dp = self._input_data.next_feed()
outputs = self.predictor(dp)
for inf, tensormap in zip(self.infs, self.inf_to_tensors):
inf_output = [(outputs if k.isOutput else dp)[k.index]
......@@ -148,16 +178,16 @@ class InferenceRunner(Triggerable):
summary_inferencer(self.trainer, self.infs)
class FeedfreeInferenceRunner(Triggerable):
class FeedfreeInferenceRunner(InferenceRunner):
""" A callback that runs a list of :class:`Inferencer` on some
:class:`FeedfreeInput`, such as some tensor from a TensorFlow data reading
:class:`TensorInput`, such as some tensor from a TensorFlow data reading
pipeline.
"""
def __init__(self, input, infs, input_names=None, prefix=''):
"""
Args:
input (FeedfreeInput): the input to use. Must have ``size()``.
input (TensorInput): the input to use. Must have ``size()``.
infs (list): list of :class:`Inferencer` to run.
input_names (list): must be a subset of the names in InputDesc.
prefix(str): an prefix used to build the tower. Must be set
......@@ -173,7 +203,7 @@ class FeedfreeInferenceRunner(Triggerable):
assert isinstance(v, Inferencer), v
if input_names is not None:
assert isinstance(input_names, list)
self._input_names = input_names
self.input_names = input_names
try:
self._size = input.size()
......@@ -181,68 +211,47 @@ class FeedfreeInferenceRunner(Triggerable):
raise ValueError("Input used in FeedfreeInferencecRunner must have a size!")
self._prefix = prefix
def _setup_graph(self):
self._find_input_tensors() # tensors
# TODO can we reuse predictor factory?
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
def fn(_):
self.trainer.model.build_graph(self._input_tensors)
PredictorTowerBuilder(fn, self._prefix).build(0)
self._tower_prefix = TowerContext.get_predict_tower_name(0, self._prefix)
self._find_output_tensors()
def _setup_input_names(self):
    """Fill self.input_names via the parent, then verify each is a model input."""
    super(FeedfreeInferenceRunner, self)._setup_input_names()
    placeholder_names = {k.name for k in self.trainer.model.get_inputs_desc()}
    for name in self.input_names:
        op_name, _ = get_op_tensor_name(name)
        assert op_name in placeholder_names, \
            "[FeedfreeInferenceRunner] name {} is not a model input!".format(name)
def _find_input_tensors(self):
self._input_data.setup(self.trainer.model)
# only 1 prediction tower will be used for inference
self._input_tensors = self._input_data.get_input_tensors()
model_placehdrs = copy.copy(self.trainer.model.get_reused_placehdrs())
if self._input_names is not None:
raise NotImplementedError("Random code. Not tested.")
assert len(self._input_names) == len(self._input_tensors), \
"[FeedfreeInferenceRunner] input_names must have the same length as the input data."
for n, tensor in zip(self.input_names, self._input_tensors):
opname, _ = get_op_tensor_name(n)
for idx, hdr in enumerate(model_placehdrs):
if hdr.name == opname:
model_placehdrs[idx] = tensor
break
else:
raise ValueError(
"{} doesn't appear in the InputDesc of the model!".format(n))
self._input_tensors = model_placehdrs
assert len(self._input_tensors) == len(model_placehdrs), \
"[FeedfreeInferenceRunner] Unmatched length of input tensors!"
def _find_output_tensors(self):
# TODO doesn't support output an input tensor
def _setup_output_names(self):
dispatcher = OutputTensorDispatcher()
for inf in self.infs:
dispatcher.add_entry(inf.get_output_tensors())
all_names = dispatcher.get_all_names()
G = tf.get_default_graph()
self._output_tensors = [G.get_tensor_by_name(
self._tower_prefix + '/' + n) for n in all_names]
self.output_names = dispatcher.get_all_names()
# TODO check names. doesn't support output an input tensor (but can support)
# list of list of id
self.inf_to_idxs = dispatcher.get_idx_for_each_entry()
IOTensor = InferenceRunner._IOTensor
def _trigger(self):
sess = tf.get_default_session()
def find_tensors(names):
return [IOTensor(self.output_names.index(n), True) for n in names]
self.inf_to_tensors = [find_tensors(t) for t in dispatcher.get_names_for_each_entry()]
for inf in self.infs:
inf.before_inference()
def _find_feed_tensors(self):
return []
with get_tqdm(total=self._size) as pbar:
for _ in range(self._size):
outputs = sess.run(fetches=self._output_tensors)
for inf, idlist in zip(self.infs, self.inf_to_idxs):
inf_output = [outputs[k] for k in idlist]
inf.datapoint(inf_output)
pbar.update()
self._write_summary_after_inference()
def _find_input_tensors(self):
    """Substitute the named input placeholders with tensors from the input data.

    Returns the model's placeholder list, with every entry whose name appears
    in self.input_names replaced by the corresponding TensorInput tensor.
    """
    tensors = self._input_data.get_input_tensors()
    assert len(self.input_names) == len(tensors), \
        "[FeedfreeInferenceRunner] Input names must match the " \
        "length of the input data, but {} != {}".format(len(self.input_names), len(tensors))
    # use placeholders for the unused inputs, use TensorInput for the used inputs
    ret = copy.copy(self.trainer.model.get_reused_placehdrs())
    for name, tensor in zip(self.input_names, tensors):
        tname = get_op_tensor_name(name)[1]
        pos = next((i for i, hdr in enumerate(ret) if hdr.name == tname), None)
        if pos is not None:
            ret[pos] = tensor
        else:
            # no placeholder matched: this assert necessarily fails and
            # reports the offending name (same behavior as the for/else form)
            assert tname in {k.name for k in ret}, tname
    return ret
def _write_summary_after_inference(self):
    """Forward the inferencers' accumulated results to the trainer's summary."""
    trainer, infs = self.trainer, self.infs
    summary_inferencer(trainer, infs)
......@@ -80,11 +80,12 @@ class ModelDesc(object):
for v in input_vars:
tf.add_to_collection(INPUTS_KEY, v.dumps())
ret = []
for v in input_vars:
placehdr_f = tf.placeholder if not v.sparse else tf.sparse_placeholder
ret.append(placehdr_f(
v.type, shape=v.shape,
name=prefix + v.name))
with tf.name_scope(None): # clear any name scope it might get called in
for v in input_vars:
placehdr_f = tf.placeholder if not v.sparse else tf.sparse_placeholder
ret.append(placehdr_f(
v.type, shape=v.shape,
name=prefix + v.name))
return ret
def get_inputs_desc(self):
......
......@@ -85,6 +85,7 @@ class FeedfreeInput(InputData):
pass
# TODO enqueu_many? https://github.com/tensorflow/tensorflow/issues/7817#issuecomment-282053155
class EnqueueThread(ShareSessionThread):
def __init__(self, queue, ds, input_placehdrs):
super(EnqueueThread, self).__init__()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment