Commit 7db397d7 authored by Yuxin Wu

refactor inference runners

parent 3bc0bed2
...@@ -8,14 +8,13 @@ from collections import namedtuple ...@@ -8,14 +8,13 @@ from collections import namedtuple
import tqdm import tqdm
import six import six
import copy import copy
from six.moves import zip, range from six.moves import zip
from ..utils import logger, get_tqdm_kwargs, get_tqdm from ..utils import logger, get_tqdm_kwargs
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
from ..tfutils import TowerContext
from ..train.input_data import TensorInput, FeedInput from ..train.input_data import TensorInput, FeedInput
from ..predict import PredictorTowerBuilder from ..predict import PredictorTowerBuilder, OnlinePredictor
from .base import Triggerable from .base import Triggerable
from .inference import Inferencer from .inference import Inferencer
...@@ -27,6 +26,9 @@ class OutputTensorDispatcher(object): ...@@ -27,6 +26,9 @@ class OutputTensorDispatcher(object):
def __init__(self): def __init__(self):
self._names = [] self._names = []
self._idxs = [] self._idxs = []
# each element in idxs is a list
# len(idxs) == len(inferencer)
# the list contains the indices into names
def add_entry(self, names): def add_entry(self, names):
v = [] v = []
...@@ -72,52 +74,50 @@ class InferenceRunner(Triggerable): ...@@ -72,52 +74,50 @@ class InferenceRunner(Triggerable):
_IOTensor = namedtuple('IOTensor', ['index', 'isOutput']) _IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])
def __init__(self, ds, infs, input_tensor_names=None): def __init__(self, ds, infs, input_names=None):
""" """
Args: Args:
ds (DataFlow): the DataFlow to run inferencer on. ds (DataFlow): the DataFlow to run inferencer on.
infs (list): a list of `Inferencer` instances. infs (list): a list of `Inferencer` instances.
input_tensor_names(list): list of tensors to feed the dataflow to. input_names(list): list of tensors to feed the dataflow to.
Defaults to all the input placeholders. Defaults to all the input placeholders.
""" """
if isinstance(ds, DataFlow): if isinstance(ds, DataFlow):
self.ds = FeedInput(ds) self._input_data = FeedInput(ds)
assert isinstance(self.ds, FeedInput), self.ds assert isinstance(self._input_data, FeedInput), self._input_data
if not isinstance(infs, list): if not isinstance(infs, list):
self.infs = [infs] self.infs = [infs]
else: else:
self.infs = infs self.infs = infs
for v in self.infs: for v in self.infs:
assert isinstance(v, Inferencer), v assert isinstance(v, Inferencer), v
self.input_names = input_tensor_names # names actually self.input_names = input_names # names actually
self._prefix = ''
def _setup_graph(self): def _setup_input_names(self):
self._find_input_tensors() # these are all tensor names # just use all the placeholders, if input_names is None
self._find_output_tensors() # may be either tensor name or op name
self.predictor = self.trainer.get_predictor(
self.input_names, self.output_names)
def _find_input_tensors(self):
if self.input_names is None: if self.input_names is None:
input_vars = self.trainer.model.get_reused_placehdrs() inputs = self.trainer.model.get_reused_placehdrs()
# TODO even if it works here, sparse still is unavailable self.input_names = [x.name for x in inputs]
# TODO sparse. even if it works here, sparse still is unavailable
# because get_tensor_by_name doesn't work for sparse # because get_tensor_by_name doesn't work for sparse
def get_name(x): # def get_name(x):
if isinstance(x, tf.SparseTensor): # if isinstance(x, tf.SparseTensor):
return x.op.name.split('/')[0] # return x.op.name.split('/')[0]
return x.name # return x.name
self.input_names = [get_name(x) for x in input_vars]
def _find_output_tensors(self): def _setup_output_names(self):
dispatcher = OutputTensorDispatcher() dispatcher = OutputTensorDispatcher()
for inf in self.infs: for inf in self.infs:
dispatcher.add_entry(inf.get_output_tensors()) dispatcher.add_entry(inf.get_output_tensors())
all_names = dispatcher.get_all_names() all_names = dispatcher.get_all_names()
IOTensor = InferenceRunner._IOTensor # output names can be input placeholders, use IOTensor
self.output_names = list(filter( self.output_names = list(filter(
lambda x: x not in self.input_names, all_names)) lambda x: x not in self.input_names, all_names))
IOTensor = InferenceRunner._IOTensor
def find_tensors(names): def find_tensors(names):
ret = [] ret = []
...@@ -130,13 +130,43 @@ class InferenceRunner(Triggerable): ...@@ -130,13 +130,43 @@ class InferenceRunner(Triggerable):
self.inf_to_tensors = [find_tensors(t) for t in dispatcher.get_names_for_each_entry()] self.inf_to_tensors = [find_tensors(t) for t in dispatcher.get_names_for_each_entry()]
# list of list of IOTensor # list of list of IOTensor
def _setup_graph(self):
self._input_data.setup(self.trainer.model)
self._setup_input_names()
# set self.output_names from inferencers, as well as the name dispatcher
self._setup_output_names()
in_tensors = self._find_input_tensors()
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
def fn(_):
self.trainer.model.build_graph(in_tensors)
PredictorTowerBuilder(fn, self._prefix).build(0)
feed_tensors = self._find_feed_tensors()
out_tensors = self._find_output_tensors()
self.predictor = OnlinePredictor(feed_tensors, out_tensors)
def _find_input_tensors(self):
return self.trainer.model.get_reused_placehdrs()
def _find_feed_tensors(self):
placeholder_names = set([k.name for k in self.trainer.model.get_inputs_desc()])
get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
return get_tensor_fn(placeholder_names, self.input_names, 0, prefix=self._prefix)
def _find_output_tensors(self):
placeholder_names = set([k.name for k in self.trainer.model.get_inputs_desc()])
get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
return get_tensor_fn(placeholder_names, self.output_names, 0, prefix=self._prefix)
def _trigger(self): def _trigger(self):
for inf in self.infs: for inf in self.infs:
inf.before_inference() inf.before_inference()
self.ds.reset_state() self._input_data.reset_state()
for _ in tqdm.trange(self.ds.size(), **get_tqdm_kwargs()): for _ in tqdm.trange(self._input_data.size(), **get_tqdm_kwargs()):
dp = self.ds.next_feed() dp = self._input_data.next_feed()
outputs = self.predictor(dp) outputs = self.predictor(dp)
for inf, tensormap in zip(self.infs, self.inf_to_tensors): for inf, tensormap in zip(self.infs, self.inf_to_tensors):
inf_output = [(outputs if k.isOutput else dp)[k.index] inf_output = [(outputs if k.isOutput else dp)[k.index]
...@@ -148,16 +178,16 @@ class InferenceRunner(Triggerable): ...@@ -148,16 +178,16 @@ class InferenceRunner(Triggerable):
summary_inferencer(self.trainer, self.infs) summary_inferencer(self.trainer, self.infs)
class FeedfreeInferenceRunner(Triggerable): class FeedfreeInferenceRunner(InferenceRunner):
""" A callback that runs a list of :class:`Inferencer` on some """ A callback that runs a list of :class:`Inferencer` on some
:class:`FeedfreeInput`, such as some tensor from a TensorFlow data reading :class:`TensorInput`, such as some tensor from a TensorFlow data reading
pipeline. pipeline.
""" """
def __init__(self, input, infs, input_names=None, prefix=''): def __init__(self, input, infs, input_names=None, prefix=''):
""" """
Args: Args:
input (FeedfreeInput): the input to use. Must have ``size()``. input (TensorInput): the input to use. Must have ``size()``.
infs (list): list of :class:`Inferencer` to run. infs (list): list of :class:`Inferencer` to run.
input_names (list): must be a subset of the names in InputDesc. input_names (list): must be a subset of the names in InputDesc.
prefix(str): a prefix used to build the tower. Must be set prefix(str): a prefix used to build the tower. Must be set
...@@ -173,7 +203,7 @@ class FeedfreeInferenceRunner(Triggerable): ...@@ -173,7 +203,7 @@ class FeedfreeInferenceRunner(Triggerable):
assert isinstance(v, Inferencer), v assert isinstance(v, Inferencer), v
if input_names is not None: if input_names is not None:
assert isinstance(input_names, list) assert isinstance(input_names, list)
self._input_names = input_names self.input_names = input_names
try: try:
self._size = input.size() self._size = input.size()
...@@ -181,68 +211,47 @@ class FeedfreeInferenceRunner(Triggerable): ...@@ -181,68 +211,47 @@ class FeedfreeInferenceRunner(Triggerable):
raise ValueError("Input used in FeedfreeInferencecRunner must have a size!") raise ValueError("Input used in FeedfreeInferencecRunner must have a size!")
self._prefix = prefix self._prefix = prefix
def _setup_graph(self): def _setup_input_names(self):
self._find_input_tensors() # tensors super(FeedfreeInferenceRunner, self)._setup_input_names()
placeholder_names = set([k.name for k in self.trainer.model.get_inputs_desc()])
for n in self.input_names:
opname = get_op_tensor_name(n)[0]
assert opname in placeholder_names, \
"[FeedfreeInferenceRunner] name {} is not a model input!".format(n)
# TODO can we reuse predictor factory? def _setup_output_names(self):
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
def fn(_):
self.trainer.model.build_graph(self._input_tensors)
PredictorTowerBuilder(fn, self._prefix).build(0)
self._tower_prefix = TowerContext.get_predict_tower_name(0, self._prefix)
self._find_output_tensors()
def _find_input_tensors(self):
self._input_data.setup(self.trainer.model)
# only 1 prediction tower will be used for inference
self._input_tensors = self._input_data.get_input_tensors()
model_placehdrs = copy.copy(self.trainer.model.get_reused_placehdrs())
if self._input_names is not None:
raise NotImplementedError("Random code. Not tested.")
assert len(self._input_names) == len(self._input_tensors), \
"[FeedfreeInferenceRunner] input_names must have the same length as the input data."
for n, tensor in zip(self.input_names, self._input_tensors):
opname, _ = get_op_tensor_name(n)
for idx, hdr in enumerate(model_placehdrs):
if hdr.name == opname:
model_placehdrs[idx] = tensor
break
else:
raise ValueError(
"{} doesn't appear in the InputDesc of the model!".format(n))
self._input_tensors = model_placehdrs
assert len(self._input_tensors) == len(model_placehdrs), \
"[FeedfreeInferenceRunner] Unmatched length of input tensors!"
def _find_output_tensors(self):
# TODO doesn't support output an input tensor
dispatcher = OutputTensorDispatcher() dispatcher = OutputTensorDispatcher()
for inf in self.infs: for inf in self.infs:
dispatcher.add_entry(inf.get_output_tensors()) dispatcher.add_entry(inf.get_output_tensors())
all_names = dispatcher.get_all_names() self.output_names = dispatcher.get_all_names()
G = tf.get_default_graph() # TODO check names. doesn't support output an input tensor (but can support)
self._output_tensors = [G.get_tensor_by_name(
self._tower_prefix + '/' + n) for n in all_names]
# list of list of id IOTensor = InferenceRunner._IOTensor
self.inf_to_idxs = dispatcher.get_idx_for_each_entry()
def _trigger(self): def find_tensors(names):
sess = tf.get_default_session() return [IOTensor(self.output_names.index(n), True) for n in names]
self.inf_to_tensors = [find_tensors(t) for t in dispatcher.get_names_for_each_entry()]
for inf in self.infs: def _find_feed_tensors(self):
inf.before_inference() return []
with get_tqdm(total=self._size) as pbar: def _find_input_tensors(self):
for _ in range(self._size): tensors = self._input_data.get_input_tensors()
outputs = sess.run(fetches=self._output_tensors)
for inf, idlist in zip(self.infs, self.inf_to_idxs): assert len(self.input_names) == len(tensors), \
inf_output = [outputs[k] for k in idlist] "[FeedfreeInferenceRunner] Input names must match the " \
inf.datapoint(inf_output) "length of the input data, but {} != {}".format(len(self.input_names), len(tensors))
pbar.update() # use placeholders for the unused inputs, use TensorInput for the used inputs
self._write_summary_after_inference() ret = copy.copy(self.trainer.model.get_reused_placehdrs())
for name, tensor in zip(self.input_names, tensors):
tname = get_op_tensor_name(name)[1]
for idx, hdr in enumerate(ret):
if hdr.name == tname:
ret[idx] = tensor
break
else:
assert tname in set([k.name for k in ret]), tname
return ret
def _write_summary_after_inference(self): def _write_summary_after_inference(self):
summary_inferencer(self.trainer, self.infs) summary_inferencer(self.trainer, self.infs)
...@@ -80,6 +80,7 @@ class ModelDesc(object): ...@@ -80,6 +80,7 @@ class ModelDesc(object):
for v in input_vars: for v in input_vars:
tf.add_to_collection(INPUTS_KEY, v.dumps()) tf.add_to_collection(INPUTS_KEY, v.dumps())
ret = [] ret = []
with tf.name_scope(None): # clear any name scope it might get called in
for v in input_vars: for v in input_vars:
placehdr_f = tf.placeholder if not v.sparse else tf.sparse_placeholder placehdr_f = tf.placeholder if not v.sparse else tf.sparse_placeholder
ret.append(placehdr_f( ret.append(placehdr_f(
......
...@@ -85,6 +85,7 @@ class FeedfreeInput(InputData): ...@@ -85,6 +85,7 @@ class FeedfreeInput(InputData):
pass pass
# TODO enqueue_many? https://github.com/tensorflow/tensorflow/issues/7817#issuecomment-282053155
class EnqueueThread(ShareSessionThread): class EnqueueThread(ShareSessionThread):
def __init__(self, queue, ds, input_placehdrs): def __init__(self, queue, ds, input_placehdrs):
super(EnqueueThread, self).__init__() super(EnqueueThread, self).__init__()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment