Commit 76fa8e38 authored by Yuxin Wu's avatar Yuxin Wu

Simplify inference_runner: 1. move input_names mapping to InputSource 2. add DataParallelFeedInput

parent 48f6c267
@@ -10,14 +10,14 @@ from tensorflow.python.training.monitored_session \
from abc import ABCMeta, abstractmethod
import tqdm
import six
import copy
from six.moves import zip
from six.moves import range
from ..utils import logger, get_tqdm_kwargs
from ..dataflow import DataFlow
from ..tfutils.common import get_op_tensor_name, get_tensors_by_names
from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import TowerContext
from ..train.input_source import TensorInput, FeedInput
from ..train.input_source import TensorInput, FeedInput, DataParallelFeedInput
from ..train.utils import get_tensors_inputs
from ..predict import PredictorTowerBuilder
from .base import Callback
@@ -60,11 +60,12 @@ class InferenceRunnerBase(Callback):
"""
Args:
input (InputSource): the input to use. Must have ``size()``.
infs (list): list of :class:`Inferencer` to run.
input_names (list): must be a subset of the names in InputDesc.
infs (list[Inferencer]): list of :class:`Inferencer` to run.
input_names (list[str]): list of input names that the datapoints from ``input`` correspond to;
must be a subset of the names in the model's InputDesc. Defaults to all the inputs of the model.
prefix (str): a prefix used to build the tower. Must be set
differently if more than one :class:`InferenceRunner` is used.
extra_hooks (list): extra ``SessionRunHook`` to run with the evaluation.
extra_hooks (list[SessionRunHook]): extra :class:`SessionRunHook` to run with the evaluation.
"""
self._input_source = input
if not isinstance(infs, list):
@@ -87,33 +88,17 @@ class InferenceRunnerBase(Callback):
extra_hooks = []
self._extra_hooks = extra_hooks
def _setup_input_names(self):
# just use all the placeholders, if input_name is None
if self.input_names is None:
inputs = self.trainer.model.get_reused_placehdrs()
self.input_names = [x.name for x in inputs]
# TODO sparse. even if it works here, sparse still is unavailable
# because get_tensor_by_name doesn't work for sparse
# def get_name(x):
# if isinstance(x, tf.SparseTensor):
# return x.op.name.split('/')[0]
# return x.name
def _setup_graph(self):
self._input_source.setup(self.trainer.model)
self._setup_input_names()
# Use the predict_tower from the train config: either a GPU id or -1
self._predict_tower_id = self.trainer.config.predict_tower[0]
in_tensors = self._find_input_tensors()
assert isinstance(in_tensors, list), in_tensors
def fn(_):
in_tensors = self._find_input_tensors()
assert isinstance(in_tensors, list), in_tensors
self.trainer.model.build_graph(in_tensors)
PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
self._feed_tensors = self._find_feed_tensors()
self._hooks = [self._build_hook(inf) for inf in self.infs]
def _before_train(self):
@@ -128,10 +113,6 @@ class InferenceRunnerBase(Callback):
def _find_input_tensors(self):
pass
@abstractmethod
def _find_feed_tensors(self):
pass
@abstractmethod
def _build_hook(self, inf):
pass
@@ -143,8 +124,7 @@ class InferenceRunnerBase(Callback):
# iterate over the data, and run the hooked session
self._input_source.reset_state()
for _ in tqdm.trange(self._input_source.size(), **get_tqdm_kwargs()):
dp = self._input_source.next_feed()
feed = dict(zip(self._feed_tensors, dp))
feed = self._input_source.next_feed()
self._hooked_sess.run(fetches=[], feed_dict=feed)
summary_inferencer(self.trainer, self.infs)
@@ -160,19 +140,15 @@ class InferenceRunner(InferenceRunnerBase):
Args:
ds (DataFlow): the DataFlow to run inferencer on.
infs (list): a list of `Inferencer` instances.
input_names(list): list of tensors to feed the dataflow to.
Defaults to all the input placeholders.
input_names (list[str]): same as in :class:`InferenceRunnerBase`.
"""
assert isinstance(ds, DataFlow), ds
input = FeedInput(ds)
input = FeedInput(ds, input_names)
super(InferenceRunner, self).__init__(
input, infs, input_names, prefix='', extra_hooks=extra_hooks)
def _find_input_tensors(self):
return self.trainer.model.get_reused_placehdrs()
def _find_feed_tensors(self):
return self._get_tensors_maybe_in_tower(self.input_names)
return self._input_source.get_input_tensors()
def _build_hook(self, inf):
out_names = inf.get_output_tensors()
@@ -191,7 +167,7 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
Args:
input (TensorInput): the input to use. Must have ``size()``.
infs (list): list of :class:`Inferencer` to run.
input_names (list): must be a subset of the names in InputDesc.
input_names (list[str]): same as in :class:`InferenceRunnerBase`.
prefix (str): a prefix used to build the tower. Must be set
differently if more than one :class:`FeedfreeInferenceRunner` is used.
"""
@@ -199,36 +175,14 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
super(FeedfreeInferenceRunner, self).__init__(
input, infs, input_names, prefix=prefix, extra_hooks=extra_hooks)
def _setup_input_names(self):
super(FeedfreeInferenceRunner, self)._setup_input_names()
placeholder_names = set([k.name for k in self.trainer.model.get_inputs_desc()])
for n in self.input_names:
opname = get_op_tensor_name(n)[0]
assert opname in placeholder_names, \
"[FeedfreeInferenceRunner] name {} is not a model input!".format(n)
def _find_input_tensors(self):
# TODO move mapping to InputSource
tensors = self._input_source.get_input_tensors()
assert len(self.input_names) == len(tensors), \
"[FeedfreeInferenceRunner] Input names must match the " \
"length of the input data, but {} != {}".format(len(self.input_names), len(tensors))
# use placeholders for the unused inputs, use TensorInput for the used inputs
ret = copy.copy(self.trainer.model.get_reused_placehdrs())
for name, tensor in zip(self.input_names, tensors):
tname = get_op_tensor_name(name)[1]
for idx, hdr in enumerate(ret):
if hdr.name == tname:
ret[idx] = tensor
break
placeholders = self.trainer.model.get_reused_placehdrs()
if self.input_names is None:
return tensors
else:
assert tname in set([k.name for k in ret]), \
"Input name {} is not among model inputs: {}!".format(tname, ret)
self._input_tensors = ret
return ret
def _find_feed_tensors(self):
return []
return get_tensors_inputs(placeholders, tensors, self.input_names)
def _build_hook(self, inf):
out_names = inf.get_output_tensors()  # all are tensor names
@@ -243,22 +197,24 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
return InferencerToHook(inf, ret)
class DataParallelInferenceRunner(InferenceRunner):
class DataParallelInferenceRunner(InferenceRunnerBase):
def __init__(self, ds, infs, gpus, input_names=None):
super(DataParallelInferenceRunner, self).__init__(ds, infs, input_names)
self._tower_names = [TowerContext.get_predict_tower_name(k)
for k in range(len(gpus))]
input = DataParallelFeedInput(
ds, self._tower_names, input_names=input_names)
super(DataParallelInferenceRunner, self).__init__(
input, infs, input_names)
self._gpus = gpus
def _setup_graph(self):
model = self.trainer.model
self._input_source.setup(model)
self._setup_input_names()
# build graph
def build_tower(k):
towername = TowerContext.get_predict_tower_name(k)
# inputs (placeholders) for this tower only
input_tensors = model.build_placeholders(
prefix=towername + '/')
input_tensors = self._input_source.get_input_tensors()
model.build_graph(input_tensors)
builder = PredictorTowerBuilder(build_tower, prefix=self._prefix)
@@ -267,7 +223,6 @@ class DataParallelInferenceRunner(InferenceRunner):
builder.build(t)
# setup feeds and hooks
self._feed_tensors = self._find_feed_tensors()
self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs]
self._hooks = [self._build_hook(inf) for inf in self.infs]
@@ -278,10 +233,6 @@ class DataParallelInferenceRunner(InferenceRunner):
'/' + n for n in names])
return ret
def _find_feed_tensors(self):
names = self._duplicate_names_across_towers(self.input_names)
return get_tensors_by_names(names)
class InferencerToHookDataParallel(InferencerToHook):
def __init__(self, inf, fetches, size):
super(DataParallelInferenceRunner.InferencerToHookDataParallel, self).__init__(inf, fetches)
@@ -322,16 +273,13 @@ class DataParallelInferenceRunner(InferenceRunner):
nr_tower = len(self._gpus)
with tqdm.tqdm(total=total, **get_tqdm_kwargs()) as pbar:
while total >= nr_tower:
dps = []
for k in self._gpus:
dps.extend(self._input_source.next_feed())
feed = dict(zip(self._feed_tensors, dps))
feed = self._input_source.next_feed()
self._parallel_hooked_sess.run(fetches=[], feed_dict=feed)
pbar.update(nr_tower)
total -= nr_tower
# take care of the rest
while total > 0:
dp = self._input_source.next_feed()
feed = dict(zip(self._feed_tensors[:len(dp)], dp))
feed = self._input_source.next_feed(cnt=1)
self._hooked_sess.run(fetches=[], feed_dict=feed)
pbar.update(1)
summary_inferencer(self.trainer, self.infs)
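A rough usage sketch of the two runners above, not taken from this commit: ``val_ds``, the tensor name 'cost', and the availability of ``ScalarStats`` are assumptions, and the import paths simply mirror the modules touched by this diff.
from tensorpack.callbacks import ScalarStats
from tensorpack.callbacks.inference_runner import (
    InferenceRunner, DataParallelInferenceRunner)

callbacks = [
    # single-tower evaluation: a FeedInput is created internally from the
    # DataFlow, and input_names defaults to all model inputs
    InferenceRunner(val_ds, [ScalarStats('cost')]),
    # data-parallel evaluation on GPUs 0 and 1, driven by DataParallelFeedInput
    DataParallelInferenceRunner(val_ds, [ScalarStats('cost')], gpus=[0, 1]),
]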
@@ -12,8 +12,9 @@ except ImportError:
from itertools import chain
from abc import ABCMeta, abstractmethod
import six
from six.moves import range
from six.moves import range, zip
from .utils import get_placeholders_by_names
from ..dataflow import DataFlow, RepeatedData
from ..tfutils.summary import add_moving_summary
from ..tfutils import get_op_tensor_name
@@ -24,7 +25,7 @@ from ..utils.concurrency import ShareSessionThread
from ..callbacks.concurrency import StartProcOrThread
from ..callbacks.base import Callback
__all__ = ['InputSource', 'FeedfreeInput',
__all__ = ['InputSource', 'FeedfreeInput', 'DataParallelFeedInput',
'QueueInput', 'BatchQueueInput',
'ZMQInput',
'DummyConstantInput', 'TensorInput', 'StagingInputWrapper']
@@ -38,8 +39,9 @@ class InputSource(object):
def get_input_tensors(self):
"""
Returns:
list: A list of tensors corresponding to the inputs of the model.
Always create and return a list of new input tensors when called.
list: A list of tensors corresponding to the inputs of the model,
used as input of :func:`build_graph`.
For non-placeholder tensors, should always create and return new tensors when called.
"""
def setup(self, model):
@@ -53,27 +55,37 @@ class InputSource(object):
pass
def next_feed(self):
return []
"""
Returns:
a feed_dict of {Tensor: data}, to be used to run one step
"""
return {}
class FeedInput(InputSource):
""" Input by iterating over a DataFlow and feed datapoints. """
def __init__(self, ds):
def __init__(self, ds, input_names=None):
"""
Args:
ds (DataFlow): the input DataFlow.
input_names (list[str]): input names the datapoints from ``ds`` map to. Defaults to all inputs of the model.
"""
assert isinstance(ds, DataFlow), ds
self.ds = ds
self._input_names = input_names
def size(self):
return self.ds.size()
def setup(self, model):
self.input_placehdrs = model.get_reused_placehdrs()
rds = RepeatedData(self.ds, -1)
rds.reset_state()
self.data_producer = rds.get_data()
self._all_placehdrs = model.get_reused_placehdrs()
if self._input_names is None:
self._placehdrs_to_feed = self._all_placehdrs
else:
self._placehdrs_to_feed = get_placeholders_by_names(
self._all_placehdrs, self._input_names)
self.reset_state()
def reset_state(self):
rds = RepeatedData(self.ds, -1)
@@ -81,10 +93,61 @@ class FeedInput(InputSource):
self.data_producer = rds.get_data()
def get_input_tensors(self):
return self.input_placehdrs
return self._all_placehdrs
def next_feed(self):
return next(self.data_producer)
dp = next(self.data_producer)
return dict(zip(self._placehdrs_to_feed, dp))
class DataParallelFeedInput(FeedInput):
"""
Input by feeding k datapoints to k copies of placeholders located on k towers.
"""
def __init__(self, ds, tower_names, input_names=None):
super(DataParallelFeedInput, self).__init__(ds, input_names)
self._tower_names = tower_names
self._nr_tower = len(tower_names)
def setup(self, model):
self._placehdrs_per_tower = []
self._feed_placehdrs_per_tower = []
for tname in self._tower_names:
# build a list of placeholders for each tower
self._placehdrs_per_tower.append(
model.build_placeholders(
prefix=tname + '/'))
# apply input mapping and store results in feed_placehdrs_per_tower
if self._input_names is None:
self._feed_placehdrs_per_tower = self._placehdrs_per_tower
else:
for phdrs, tname in zip(
self._placehdrs_per_tower, self._tower_names):
input_names = [tname + '/' + n for n in self._input_names]
# input_names to be used for this specific tower
self._feed_placehdrs_per_tower.append(
get_placeholders_by_names(phdrs, input_names))
self.reset_state()
def get_input_tensors(self):
# return the placeholders of the current tower, determined by the tower context
ctx = get_current_tower_context()
return self._placehdrs_per_tower[ctx.index]
def next_feed(self, cnt=None):
"""
Args:
cnt (int): how many towers to feed. Defaults to the total number of towers.
"""
if cnt is None:
cnt = self._nr_tower
feed = {}
for t in range(cnt):
dp = next(self.data_producer)
f = dict(zip(self._feed_placehdrs_per_tower[t], dp))
feed.update(f)
return feed
class FeedfreeInput(InputSource):
......
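A small sketch of the name mapping that ``FeedInput`` and ``DataParallelFeedInput`` now perform, not part of this commit: ``model``, ``ds`` and the input names 'image'/'label' are assumptions; the import paths follow the modules shown in this diff.
from tensorpack.tfutils.tower import TowerContext
from tensorpack.train.input_source import FeedInput, DataParallelFeedInput

# single-tower feeding: datapoints from ds correspond to ['image', 'label']
feed_in = FeedInput(ds, input_names=['image', 'label'])
feed_in.setup(model)         # resolves the names to the model's placeholders
feed = feed_in.next_feed()   # {image_placeholder: arr, label_placeholder: arr}

# the same DataFlow feeding two prediction towers
tower_names = [TowerContext.get_predict_tower_name(k) for k in range(2)]
par_in = DataParallelFeedInput(ds, tower_names, input_names=['image', 'label'])
par_in.setup(model)                  # builds one set of placeholders per tower
feed_all = par_in.next_feed()        # one datapoint per tower
feed_one = par_in.next_feed(cnt=1)   # feed only the first tower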
@@ -3,8 +3,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from six.moves import zip
from .base import Trainer
from ..utils import logger
@@ -33,8 +31,7 @@ class SimpleTrainer(Trainer):
def run_step(self):
""" Feed data into the graph and run the updates. """
dp = self._input_source.next_feed()
feed = dict(zip(self.inputs, dp))
feed = self._input_source.next_feed()
self.hooked_sess.run(self.train_op, feed_dict=feed)
def _setup(self):
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: utils.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import copy
from six.moves import zip
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
__all__ = ['get_tensors_inputs', 'get_placeholders_by_names']
def get_tensors_inputs(placeholders, tensors, names):
"""
Quite often we want to call `build_graph()` with normal tensors
(rather than placeholders).
Args:
placeholders (list[Tensor]): the model's input placeholders.
tensors (list[Tensor]): list of tf.Tensor
names (list[str]): names matching the tensors
Returns:
list[Tensor]: inputs to be used with build_graph(),
with the corresponding placeholders replaced by the given tensors.
"""
assert len(tensors) == len(names), \
"Input tensors {} and input names {} have different length!".format(
tensors, names)
ret = copy.copy(placeholders)
placeholder_names = [p.name for p in placeholders]
for name, tensor in zip(names, tensors):
tensorname = get_op_tensor_name(name)[1]
try:
idx = placeholder_names.index(tensorname)
except ValueError:
logger.error("Name {} is not a model input!".format(tensorname))
raise
ret[idx] = tensor
return ret
def get_placeholders_by_names(placeholders, names):
"""
Returns:
list[Tensor]: the sublist of placeholders matching the given names, in that order.
"""
placeholder_names = [p.name for p in placeholders]
ret = []
for name in names:
tensorname = get_op_tensor_name(name)[1]
try:
idx = placeholder_names.index(tensorname)
except ValueError:
logger.error("Name {} is not a model input!".format(tensorname))
raise
ret.append(placeholders[idx])
return ret
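A short, self-contained sketch of the two helpers above, using made-up placeholder names ('image', 'label') and a dummy tensor standing in for a queue output; the import path mirrors the ``..train.utils`` import used earlier in this diff.
import tensorflow as tf
from tensorpack.train.utils import get_tensors_inputs, get_placeholders_by_names

image = tf.placeholder(tf.float32, [None, 28, 28], name='image')
label = tf.placeholder(tf.int32, [None], name='label')
dequeued = tf.zeros([32, 28, 28])   # stand-in for a tensor dequeued from a queue

# replace the 'image' placeholder by the tensor, keep 'label' as a placeholder
inputs = get_tensors_inputs([image, label], [dequeued], ['image'])
# inputs == [dequeued, label]

# pick a subset of the placeholders by name
feed_phdrs = get_placeholders_by_names([image, label], ['label'])
# feed_phdrs == [label]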