Commit adf51f22 authored by Yuxin Wu

Allow remapping on every InputSource; remove the old 'names' options accordingly

parent 20d7fe7f
......@@ -16,7 +16,8 @@ from ..utils import logger, get_tqdm_kwargs
from ..dataflow import DataFlow
from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import TowerContext
from ..train.input_source import FeedInput, DataParallelFeedInput, FeedfreeInput
from ..train.input_source import (
FeedInput, DataParallelFeedInput, FeedfreeInput, InputSource)
from ..predict import PredictorTowerBuilder
from .base import Callback
......@@ -128,16 +129,15 @@ class InferenceRunner(InferenceRunnerBase):
:class:`DataFlow`.
"""
def __init__(self, ds, infs, input_names=None, extra_hooks=None):
def __init__(self, input, infs, extra_hooks=None):
"""
Args:
ds (DataFlow): the DataFlow to run inferencer on.
input (FeedInput or DataFlow): the FeedInput, or the DataFlow to run the inferencers on.
infs (list): a list of `Inferencer` instances.
input_names (list[str]): list of names to match ``input``, must be a subset of the names in
InputDesc of the model. Defaults to be all the inputs of the model.
"""
assert isinstance(ds, DataFlow), ds
input = FeedInput(ds, input_names)
if isinstance(input, DataFlow):
input = FeedInput(input)
assert isinstance(input, FeedInput), input
super(InferenceRunner, self).__init__(
input, infs, prefix='', extra_hooks=extra_hooks)
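A minimal usage sketch of the new call styles (hedged: FakeData stands in for a real validation DataFlow, ScalarStats for any Inferencer, and the import paths are assumed from the tensorpack layout of this era):

    from tensorpack.callbacks import InferenceRunner, ScalarStats
    from tensorpack.dataflow import FakeData
    from tensorpack.train.input_source import FeedInput

    ds_val = FakeData([[4, 28, 28, 3], [4]], size=100)  # stand-in validation DataFlow
    # Both call styles are accepted after this change:
    cb1 = InferenceRunner(ds_val, [ScalarStats('cost')])             # DataFlow, wrapped into FeedInput
    cb2 = InferenceRunner(FeedInput(ds_val), [ScalarStats('cost')])  # explicit FeedInput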
......@@ -158,7 +158,6 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
Args:
input (FeedfreeInput): the input to use. Must have ``size()``.
infs (list): list of :class:`Inferencer` to run.
input_names (list[str]): same as in :class:`InferenceRunnerBase`.
prefix(str): a prefix used to build the tower. Must be set
differently if more than one :class:`FeedfreeInferenceRunner` is used.
"""
......@@ -180,11 +179,20 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
class DataParallelInferenceRunner(InferenceRunnerBase):
def __init__(self, ds, infs, gpus, input_names=None):
self._tower_names = [TowerContext.get_predict_tower_name(k)
for k in range(len(gpus))]
input = DataParallelFeedInput(
ds, self._tower_names, input_names=input_names)
"""
Not tested. Don't use.
"""
# TODO some scripts to test
def __init__(self, input, infs, gpus):
"""
Args:
input (DataParallelFeedInput or DataFlow): the input to run the inferencers on.
"""
if isinstance(input, DataFlow):
tower_names = [TowerContext.get_predict_tower_name(k) for k in range(len(gpus))]
input = DataParallelFeedInput(input, tower_names)
assert isinstance(input, InputSource), input
super(DataParallelInferenceRunner, self).__init__(input, infs)
self._gpus = gpus
......
......@@ -14,7 +14,7 @@ from abc import ABCMeta, abstractmethod
import six
from six.moves import range, zip
from .utils import get_placeholders_by_names, get_tensors_inputs
from .utils import get_sublist_by_names, get_tensors_inputs
from ..dataflow import DataFlow, RepeatedData
from ..tfutils.summary import add_moving_summary
from ..tfutils import get_op_tensor_name
......@@ -30,7 +30,7 @@ __all__ = ['InputSource',
'FeedfreeInput',
'QueueInput', 'BatchQueueInput',
'ZMQInput', 'DummyConstantInput', 'TensorInput',
'StagingInputWrapper', 'ReorderInputSource']
'StagingInputWrapper', 'remap_input_source']
@six.add_metaclass(ABCMeta)
......@@ -82,29 +82,19 @@ class InputSource(object):
class FeedInput(InputSource):
""" Input by iterating over a DataFlow and feed datapoints. """
def __init__(self, ds, input_names=None):
def __init__(self, ds):
"""
Args:
ds (DataFlow): the input DataFlow.
input_names (list[str]): input names this DataFlow maps to
"""
assert isinstance(ds, DataFlow), ds
if input_names is not None:
assert isinstance(input_names, (list, tuple)), input_names
self.ds = ds
self._input_names = input_names
def size(self):
return self.ds.size()
def setup(self, inputs):
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
if self._input_names is None:
self._placehdrs_to_feed = self._all_placehdrs
else:
self._placehdrs_to_feed = get_placeholders_by_names(
self._all_placehdrs, self._input_names)
self.reset_state()
def reset_state(self):
......@@ -117,37 +107,25 @@ class FeedInput(InputSource):
def next_feed(self):
dp = next(self.data_producer)
return dict(zip(self._placehdrs_to_feed, dp))
assert len(dp) == len(self._all_placehdrs), "[FeedInput] datapoints and inputs are of different length!"
return dict(zip(self._all_placehdrs, dp))
class DataParallelFeedInput(FeedInput):
"""
Input by feeding k datapoints to k copies of placeholders located on k towers.
"""
def __init__(self, ds, tower_names, input_names=None):
super(DataParallelFeedInput, self).__init__(ds, input_names)
def __init__(self, ds, tower_names):
super(DataParallelFeedInput, self).__init__(ds)
self._tower_names = tower_names
self._nr_tower = len(tower_names)
def setup(self, inputs):
self._placehdrs_per_tower = []
self._feed_placehdrs_per_tower = []
for tname in self._tower_names:
# build a list of placeholders for each tower
self._placehdrs_per_tower.append(
[v.build_placeholder(prefix=tname + '/') for v in inputs])
# apply input mapping and store results in feed_placehdrs_per_tower
if self._input_names is None:
self._feed_placehdrs_per_tower = self._placehdrs_per_tower
else:
for phdrs, tname in zip(
self._placehdrs_per_tower, self._tower_names):
input_names = [tname + '/' + n for n in self._input_names]
# input_names to be used for this specific tower
self._feed_placehdrs_per_tower.append(
get_placeholders_by_names(phdrs, input_names))
print(self._feed_placehdrs_per_tower[-1])
self.reset_state()
def get_input_tensors(self):
......@@ -165,7 +143,7 @@ class DataParallelFeedInput(FeedInput):
feed = {}
for t in range(cnt):
dp = next(self.data_producer)
f = dict(zip(self._feed_placehdrs_per_tower[t], dp))
f = dict(zip(self._placehdrs_per_tower[t], dp))
feed.update(f)
return feed
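A hedged sketch of consuming the feed built above (dp_input, fetches, and sess are assumed names; the cnt argument and its default are inferred from the loop in next_feed):

    feed = dp_input.next_feed()        # one datapoint per tower, for all towers
    feed = dp_input.next_feed(cnt=1)   # feed only the first tower, e.g. a final partial chunk
    sess.run(fetches, feed_dict=feed)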
......@@ -175,7 +153,6 @@ class FeedfreeInput(InputSource):
e.g. by queue or other operations. """
def reset_state(self):
# TODO no state to reset
pass
def next_feed(self):
......@@ -226,38 +203,31 @@ class QueueInput(FeedfreeInput):
And the model receives dequeued tensors.
"""
def __init__(self, ds, queue=None, names=None):
def __init__(self, ds, queue=None):
"""
Args:
ds(DataFlow): the input DataFlow.
queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
should match the corresponding InputDesc of the model.
Defaults to a FIFO queue of size 50.
names(list[str]): list of input names corresponding to the dataflow.
"""
assert isinstance(ds, DataFlow), ds
self.queue = queue
self.ds = ds
self._names = names
def size(self):
return self.ds.size()
# TODO use input data mapping. not all placeholders are needed
def setup(self, inputs):
logger.info("Setting up the queue for CPU prefetching ...")
self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
if self._names is None:
self._queue_feedpoint = self.input_placehdrs
else:
self._queue_feedpoint = get_placeholders_by_names(self.input_placehdrs, self._names)
assert len(self._queue_feedpoint) > 0, \
self._input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
assert len(self._input_placehdrs) > 0, \
"QueueInput has to be used with some inputs!"
if self.queue is None:
self.queue = tf.FIFOQueue(
50, [x.dtype for x in self._queue_feedpoint],
50, [x.dtype for x in self._input_placehdrs],
name='input_queue')
self.thread = EnqueueThread(self.queue, self.ds, self._queue_feedpoint)
self.thread = EnqueueThread(self.queue, self.ds, self._input_placehdrs)
def get_callbacks(self):
cb = StartProcOrThread(self.thread)
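For illustration, a hedged sketch of constructing a QueueInput after this change (ds_train is a hypothetical DataFlow; a custom queue's dtypes must match the model's InputDesc):

    import tensorflow as tf

    qinput = QueueInput(ds_train)                    # default: FIFOQueue of capacity 50
    q = tf.FIFOQueue(100, [tf.float32, tf.int32])    # dtypes matching the model inputs
    qinput = QueueInput(ds_train, queue=q)
    # Partial mappings now go through remap_input_source instead of the removed names= option.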
......@@ -269,13 +239,10 @@ class QueueInput(FeedfreeInput):
ret = self.queue.dequeue(name='input_deque')
if isinstance(ret, tf.Tensor): # only one input
ret = [ret]
assert len(ret) == len(self._queue_feedpoint)
for qv, v in zip(ret, self._queue_feedpoint):
assert len(ret) == len(self._input_placehdrs)
for qv, v in zip(ret, self._input_placehdrs):
qv.set_shape(v.get_shape())
if self._names is None:
return ret
else:
return get_tensors_inputs(self.input_placehdrs, ret, self._names)
class BatchQueueInput(QueueInput):
......@@ -351,8 +318,6 @@ class TensorInput(FeedfreeInput):
get_tensor_fn: a function which returns a list of input tensors
when called. It will be called under a TowerContext.
size(int): size of this input. Use None to leave it undefined.
input_names (list[str]): input names the tensors maps to. Defaults
to be all the inputs of the model.
"""
self.get_tensor_fn = get_tensor_fn
if size is not None:
......@@ -397,7 +362,6 @@ class DummyConstantInput(TensorInput):
self.inputs_desc = inputs
# TODO doesn't support remapping
class ZMQInput(TensorInput):
"""
Not well implemented yet. Don't use.
......@@ -511,29 +475,35 @@ class StagingInputWrapper(FeedfreeInput):
return tf.group(*all_outputs)
class ReorderInputSource(FeedfreeInput):
# TODO dynamically generate inheritance
# TODO make it a function, not a class
class remap_input_source(FeedInput, FeedfreeInput):
"""
When an InputSource only maps to a subset of the InputDesc of the model,
wrap it with :class:`ReorderInputSource`.
When you have some :class:`InputSource` which doesn't match the inputs in
your :class:`ModelDesc`, use ``remap_input_source``.
It produces placeholders for all the inputs in your model,
except that the corresponding ones are replaced with the tensors produced
by the given :class:`InputSource`.
"""
def __init__(self, input, names):
"""
Args:
input(TensorInput): a TensorInput, whose tensors will get mapped.
input(InputSource): an :class:`InputSource` whose tensors will get mapped.
names(list[str]): list of input names corresponding to the tensors
produced by ``input``.
"""
assert isinstance(input, TensorInput), input
assert isinstance(input, InputSource), input
self._input = input
assert isinstance(names, (list, tuple)), names
self._names = names
self._names = tuple(names)
def size(self):
return self._input.size()
def setup(self, inputs):
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self._input.setup(inputs)
inputs_subset = get_sublist_by_names(inputs, self._names)
self._input.setup(inputs_subset)
def get_callbacks(self):
return self._input.get_callbacks()
......@@ -541,7 +511,11 @@ class ReorderInputSource(FeedfreeInput):
def reset_state(self):
self._input.reset_state()
def next_feed(self):
return self._input.next_feed()
def get_input_tensors(self):
ret = self._input.get_input_tensors()
assert len(ret) == len(self._names)
return get_tensors_inputs(
self._all_placehdrs, ret, self._names)
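A hedged sketch of the intended use (the input names 'image'/'label' are illustrative, and inputs_desc stands for the model's full list of InputDesc):

    # The model declares inputs ['image', 'label'], but the DataFlow only yields images.
    input1 = QueueInput(ds_images)
    input2 = remap_input_source(input1, ['image'])
    input2.setup(inputs_desc)
    # input2.get_input_tensors() now returns the dequeued tensor for 'image'
    # plus a plain placeholder for 'label'.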
......@@ -8,14 +8,11 @@ from six.moves import zip
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
__all__ = ['get_tensors_inputs', 'get_placeholders_by_names']
__all__ = ['get_tensors_inputs', 'get_sublist_by_names']
def get_tensors_inputs(placeholders, tensors, names):
"""
Quite often we want to `build_graph()` with normal tensors
(rather than placeholders).
Args:
placeholders (list[Tensor]):
tensors (list[Tensor]): list of tf.Tensor
......@@ -41,19 +38,22 @@ def get_tensors_inputs(placeholders, tensors, names):
return ret
def get_placeholders_by_names(placeholders, names):
def get_sublist_by_names(lst, names):
"""
Args:
lst (list): a list of objects with a ``name`` property.
Returns:
list[Tensor]: a sublist of placeholders, matching names
list: a sublist of objects, matching names
"""
placeholder_names = [p.name for p in placeholders]
orig_names = [p.name for p in lst]
ret = []
for name in names:
tensorname = get_op_tensor_name(name)[1]
try:
idx = placeholder_names.index(tensorname)
idx = orig_names.index(name)
except ValueError:
logger.error("Name {} is not a model input!".format(tensorname))
logger.error("Name {} doesn't appear in lst {}!".format(
name, str(orig_names)))
raise
ret.append(placeholders[idx])
ret.append(lst[idx])
return ret
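A small self-contained sketch of the selection behavior (the _Named helper is purely illustrative):

    class _Named(object):
        def __init__(self, name):
            self.name = name

    lst = [_Named('image'), _Named('label'), _Named('mask')]
    sub = get_sublist_by_names(lst, ['label', 'image'])
    assert [x.name for x in sub] == ['label', 'image']   # order follows `names`, not `lst`
    # get_sublist_by_names(lst, ['bogus'])  # would log an error, then raise ValueError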