Commit adf51f22 authored by Yuxin Wu

Allow remapping on every InputSource, therefore remove the old 'names' options

parent 20d7fe7f
...@@ -16,7 +16,8 @@ from ..utils import logger, get_tqdm_kwargs ...@@ -16,7 +16,8 @@ from ..utils import logger, get_tqdm_kwargs
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..tfutils.common import get_tensors_by_names from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import TowerContext from ..tfutils.tower import TowerContext
from ..train.input_source import FeedInput, DataParallelFeedInput, FeedfreeInput from ..train.input_source import (
FeedInput, DataParallelFeedInput, FeedfreeInput, InputSource)
from ..predict import PredictorTowerBuilder from ..predict import PredictorTowerBuilder
from .base import Callback from .base import Callback
...@@ -128,16 +129,15 @@ class InferenceRunner(InferenceRunnerBase): ...@@ -128,16 +129,15 @@ class InferenceRunner(InferenceRunnerBase):
:class:`DataFlow`. :class:`DataFlow`.
""" """
def __init__(self, ds, infs, input_names=None, extra_hooks=None): def __init__(self, input, infs, extra_hooks=None):
""" """
Args: Args:
ds (DataFlow): the DataFlow to run inferencer on. input (FeedInput or DataFlow): the FeedInput, or the DataFlow to run inferencer on.
infs (list): a list of `Inferencer` instances. infs (list): a list of `Inferencer` instances.
input_names (list[str]): list of names to match ``input``, must be a subset of the names in
InputDesc of the model. Defaults to be all the inputs of the model.
""" """
assert isinstance(ds, DataFlow), ds if isinstance(input, DataFlow):
input = FeedInput(ds, input_names) input = FeedInput(input)
assert isinstance(input, FeedInput), input
super(InferenceRunner, self).__init__( super(InferenceRunner, self).__init__(
input, infs, prefix='', extra_hooks=extra_hooks) input, infs, prefix='', extra_hooks=extra_hooks)
...@@ -158,7 +158,6 @@ class FeedfreeInferenceRunner(InferenceRunnerBase): ...@@ -158,7 +158,6 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
Args: Args:
input (FeedfreeInput): the input to use. Must have ``size()``. input (FeedfreeInput): the input to use. Must have ``size()``.
infs (list): list of :class:`Inferencer` to run. infs (list): list of :class:`Inferencer` to run.
input_names (list[str]): same as in :class:`InferenceRunnerBase`.
prefix(str): a prefix used to build the tower. Must be set prefix(str): a prefix used to build the tower. Must be set
differently if more than one :class:`FeedfreeInferenceRunner` is used. differently if more than one :class:`FeedfreeInferenceRunner` is used.
""" """
...@@ -180,11 +179,20 @@ class FeedfreeInferenceRunner(InferenceRunnerBase): ...@@ -180,11 +179,20 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
class DataParallelInferenceRunner(InferenceRunnerBase): class DataParallelInferenceRunner(InferenceRunnerBase):
def __init__(self, ds, infs, gpus, input_names=None): """
self._tower_names = [TowerContext.get_predict_tower_name(k) Not tested. Don't use.
for k in range(len(gpus))] """
input = DataParallelFeedInput( # TODO some scripts to test
ds, self._tower_names, input_names=input_names) def __init__(self, input, infs, gpus):
"""
Args:
input (DataParallelFeedInput or DataFlow)
"""
if isinstance(input, DataFlow):
tower_names = [TowerContext.get_predict_tower_name(k) for k in range(len(gpus))]
input = DataParallelFeedInput(input, tower_names)
assert isinstance(input, InputSource), input
super(DataParallelInferenceRunner, self).__init__(input, infs) super(DataParallelInferenceRunner, self).__init__(input, infs)
self._gpus = gpus self._gpus = gpus
......
...@@ -14,7 +14,7 @@ from abc import ABCMeta, abstractmethod ...@@ -14,7 +14,7 @@ from abc import ABCMeta, abstractmethod
import six import six
from six.moves import range, zip from six.moves import range, zip
from .utils import get_placeholders_by_names, get_tensors_inputs from .utils import get_sublist_by_names, get_tensors_inputs
from ..dataflow import DataFlow, RepeatedData from ..dataflow import DataFlow, RepeatedData
from ..tfutils.summary import add_moving_summary from ..tfutils.summary import add_moving_summary
from ..tfutils import get_op_tensor_name from ..tfutils import get_op_tensor_name
...@@ -30,7 +30,7 @@ __all__ = ['InputSource', ...@@ -30,7 +30,7 @@ __all__ = ['InputSource',
'FeedfreeInput', 'FeedfreeInput',
'QueueInput', 'BatchQueueInput', 'QueueInput', 'BatchQueueInput',
'ZMQInput', 'DummyConstantInput', 'TensorInput', 'ZMQInput', 'DummyConstantInput', 'TensorInput',
'StagingInputWrapper', 'ReorderInputSource'] 'StagingInputWrapper', 'remap_input_source']
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
...@@ -82,29 +82,19 @@ class InputSource(object): ...@@ -82,29 +82,19 @@ class InputSource(object):
class FeedInput(InputSource): class FeedInput(InputSource):
""" Input by iterating over a DataFlow and feed datapoints. """ """ Input by iterating over a DataFlow and feed datapoints. """
def __init__(self, ds, input_names=None): def __init__(self, ds):
""" """
Args: Args:
ds (DataFlow): the input DataFlow. ds (DataFlow): the input DataFlow.
input_names (list[str]): input names this DataFlow maps to
""" """
assert isinstance(ds, DataFlow), ds assert isinstance(ds, DataFlow), ds
if input_names is not None:
assert isinstance(input_names, (list, tuple)), input_names
self.ds = ds self.ds = ds
self._input_names = input_names
def size(self): def size(self):
return self.ds.size() return self.ds.size()
def setup(self, inputs): def setup(self, inputs):
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs] self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
if self._input_names is None:
self._placehdrs_to_feed = self._all_placehdrs
else:
self._placehdrs_to_feed = get_placeholders_by_names(
self._all_placehdrs, self._input_names)
self.reset_state() self.reset_state()
def reset_state(self): def reset_state(self):
...@@ -117,37 +107,25 @@ class FeedInput(InputSource): ...@@ -117,37 +107,25 @@ class FeedInput(InputSource):
def next_feed(self): def next_feed(self):
dp = next(self.data_producer) dp = next(self.data_producer)
return dict(zip(self._placehdrs_to_feed, dp)) assert len(dp) == len(self._all_placehdrs), "[FeedInput] datapoints and inputs are of different length!"
return dict(zip(self._all_placehdrs, dp))
class DataParallelFeedInput(FeedInput): class DataParallelFeedInput(FeedInput):
""" """
Input by feeding k datapoints to k copies of placeholders located on k towers. Input by feeding k datapoints to k copies of placeholders located on k towers.
""" """
def __init__(self, ds, tower_names, input_names=None): def __init__(self, ds, tower_names):
super(DataParallelFeedInput, self).__init__(ds, input_names) super(DataParallelFeedInput, self).__init__(ds)
self._tower_names = tower_names self._tower_names = tower_names
self._nr_tower = len(tower_names) self._nr_tower = len(tower_names)
def setup(self, inputs): def setup(self, inputs):
self._placehdrs_per_tower = [] self._placehdrs_per_tower = []
self._feed_placehdrs_per_tower = []
for tname in self._tower_names: for tname in self._tower_names:
# build a list of placeholders for each tower # build a list of placeholders for each tower
self._placehdrs_per_tower.append( self._placehdrs_per_tower.append(
[v.build_placeholder(prefix=tname + '/') for v in inputs]) [v.build_placeholder(prefix=tname + '/') for v in inputs])
# apply input mapping and store results in feed_placehdrs_per_tower
if self._input_names is None:
self._feed_placehdrs_per_tower = self._placehdrs_per_tower
else:
for phdrs, tname in zip(
self._placehdrs_per_tower, self._tower_names):
input_names = [tname + '/' + n for n in self._input_names]
# input_names to be used for this specific tower
self._feed_placehdrs_per_tower.append(
get_placeholders_by_names(phdrs, input_names))
print(self._feed_placehdrs_per_tower[-1])
self.reset_state() self.reset_state()
def get_input_tensors(self): def get_input_tensors(self):
...@@ -165,7 +143,7 @@ class DataParallelFeedInput(FeedInput): ...@@ -165,7 +143,7 @@ class DataParallelFeedInput(FeedInput):
feed = {} feed = {}
for t in range(cnt): for t in range(cnt):
dp = next(self.data_producer) dp = next(self.data_producer)
f = dict(zip(self._feed_placehdrs_per_tower[t], dp)) f = dict(zip(self._placehdrs_per_tower[t], dp))
feed.update(f) feed.update(f)
return feed return feed
...@@ -175,7 +153,6 @@ class FeedfreeInput(InputSource): ...@@ -175,7 +153,6 @@ class FeedfreeInput(InputSource):
e.g. by queue or other operations. """ e.g. by queue or other operations. """
def reset_state(self): def reset_state(self):
# TODO no state to reset
pass pass
def next_feed(self): def next_feed(self):
...@@ -226,38 +203,31 @@ class QueueInput(FeedfreeInput): ...@@ -226,38 +203,31 @@ class QueueInput(FeedfreeInput):
And the model receives dequeued tensors. And the model receives dequeued tensors.
""" """
def __init__(self, ds, queue=None, names=None): def __init__(self, ds, queue=None):
""" """
Args: Args:
ds(DataFlow): the input DataFlow. ds(DataFlow): the input DataFlow.
queue (tf.QueueBase): A :class:`tf.QueueBase` whose type queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
should match the corresponding InputDesc of the model. should match the corresponding InputDesc of the model.
Defaults to a FIFO queue of size 50. Defaults to a FIFO queue of size 50.
names(list[str]): list of input names corresponding to the dataflow.
""" """
assert isinstance(ds, DataFlow), ds assert isinstance(ds, DataFlow), ds
self.queue = queue self.queue = queue
self.ds = ds self.ds = ds
self._names = names
def size(self): def size(self):
return self.ds.size() return self.ds.size()
# TODO use input data mapping. not all placeholders are needed
def setup(self, inputs): def setup(self, inputs):
logger.info("Setting up the queue for CPU prefetching ...") logger.info("Setting up the queue for CPU prefetching ...")
self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs] self._input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
if self._names is None: assert len(self._input_placehdrs) > 0, \
self._queue_feedpoint = self.input_placehdrs
else:
self._queue_feedpoint = get_placeholders_by_names(self.input_placehdrs, self._names)
assert len(self._queue_feedpoint) > 0, \
"QueueInput has to be used with some inputs!" "QueueInput has to be used with some inputs!"
if self.queue is None: if self.queue is None:
self.queue = tf.FIFOQueue( self.queue = tf.FIFOQueue(
50, [x.dtype for x in self._queue_feedpoint], 50, [x.dtype for x in self._input_placehdrs],
name='input_queue') name='input_queue')
self.thread = EnqueueThread(self.queue, self.ds, self._queue_feedpoint) self.thread = EnqueueThread(self.queue, self.ds, self._input_placehdrs)
def get_callbacks(self): def get_callbacks(self):
cb = StartProcOrThread(self.thread) cb = StartProcOrThread(self.thread)
...@@ -269,13 +239,10 @@ class QueueInput(FeedfreeInput): ...@@ -269,13 +239,10 @@ class QueueInput(FeedfreeInput):
ret = self.queue.dequeue(name='input_deque') ret = self.queue.dequeue(name='input_deque')
if isinstance(ret, tf.Tensor): # only one input if isinstance(ret, tf.Tensor): # only one input
ret = [ret] ret = [ret]
assert len(ret) == len(self._queue_feedpoint) assert len(ret) == len(self._input_placehdrs)
for qv, v in zip(ret, self._queue_feedpoint): for qv, v in zip(ret, self._input_placehdrs):
qv.set_shape(v.get_shape()) qv.set_shape(v.get_shape())
if self._names is None:
return ret return ret
else:
return get_tensors_inputs(self.input_placehdrs, ret, self._names)
class BatchQueueInput(QueueInput): class BatchQueueInput(QueueInput):
...@@ -351,8 +318,6 @@ class TensorInput(FeedfreeInput): ...@@ -351,8 +318,6 @@ class TensorInput(FeedfreeInput):
get_tensor_fn: a function which returns a list of input tensors get_tensor_fn: a function which returns a list of input tensors
when called. It will be called under a TowerContext. when called. It will be called under a TowerContext.
size(int): size of this input. Use None to leave it undefined. size(int): size of this input. Use None to leave it undefined.
input_names (list[str]): input names the tensors maps to. Defaults
to be all the inputs of the model.
""" """
self.get_tensor_fn = get_tensor_fn self.get_tensor_fn = get_tensor_fn
if size is not None: if size is not None:
...@@ -397,7 +362,6 @@ class DummyConstantInput(TensorInput): ...@@ -397,7 +362,6 @@ class DummyConstantInput(TensorInput):
self.inputs_desc = inputs self.inputs_desc = inputs
# TODO doesn't support remapping
class ZMQInput(TensorInput): class ZMQInput(TensorInput):
""" """
Not well implemented yet. Don't use. Not well implemented yet. Don't use.
...@@ -511,29 +475,35 @@ class StagingInputWrapper(FeedfreeInput): ...@@ -511,29 +475,35 @@ class StagingInputWrapper(FeedfreeInput):
return tf.group(*all_outputs) return tf.group(*all_outputs)
class ReorderInputSource(FeedfreeInput): # TODO dynamically generate inheritance
# TODO make it a function, not a class
class remap_input_source(FeedInput, FeedfreeInput):
""" """
When an InputSource only maps to a subset of the InputDesc of the model, When you have some :class:`InputSource` which doesn't match the inputs in
wrap it with :class:`ReorderInputSource`. your :class:`ModelDesc`, use :class:`remap_input_source`.
It produces placeholders for all the inputs in your model,
except that the corresponding ones are replaced with the tensor produced
by the given :class:`InputSource`.
""" """
def __init__(self, input, names): def __init__(self, input, names):
""" """
Args: Args:
input(TensorInput): a TensorInput, whose tensors will get mapped. input(InputSource): a :class:`InputSource`, whose tensors will get mapped.
names(list[str]): list of input names corresponding to the tensors names(list[str]): list of input names corresponding to the tensors
produced by ``input``. produced by ``input``.
""" """
assert isinstance(input, TensorInput), input assert isinstance(input, InputSource), input
self._input = input self._input = input
assert isinstance(names, (list, tuple)), names assert isinstance(names, (list, tuple)), names
self._names = names self._names = tuple(names)
def size(self): def size(self):
return self._input.size() return self._input.size()
def setup(self, inputs): def setup(self, inputs):
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs] self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self._input.setup(inputs) inputs_subset = get_sublist_by_names(inputs, self._names)
self._input.setup(inputs_subset)
def get_callbacks(self): def get_callbacks(self):
return self._input.get_callbacks() return self._input.get_callbacks()
...@@ -541,7 +511,11 @@ class ReorderInputSource(FeedfreeInput): ...@@ -541,7 +511,11 @@ class ReorderInputSource(FeedfreeInput):
def reset_state(self): def reset_state(self):
self._input.reset_state() self._input.reset_state()
def next_feed(self):
return self._input.next_feed()
def get_input_tensors(self): def get_input_tensors(self):
ret = self._input.get_input_tensors() ret = self._input.get_input_tensors()
assert len(ret) == len(self._names)
return get_tensors_inputs( return get_tensors_inputs(
self._all_placehdrs, ret, self._names) self._all_placehdrs, ret, self._names)
...@@ -8,14 +8,11 @@ from six.moves import zip ...@@ -8,14 +8,11 @@ from six.moves import zip
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
from ..utils import logger from ..utils import logger
__all__ = ['get_tensors_inputs', 'get_placeholders_by_names'] __all__ = ['get_tensors_inputs', 'get_sublist_by_names']
def get_tensors_inputs(placeholders, tensors, names): def get_tensors_inputs(placeholders, tensors, names):
""" """
Quite often we want to `build_graph()` with normal tensors
(rather than placeholders).
Args: Args:
placeholders (list[Tensor]): placeholders (list[Tensor]):
tensors (list[Tensor]): list of tf.Tensor tensors (list[Tensor]): list of tf.Tensor
...@@ -41,19 +38,22 @@ def get_tensors_inputs(placeholders, tensors, names): ...@@ -41,19 +38,22 @@ def get_tensors_inputs(placeholders, tensors, names):
return ret return ret
def get_placeholders_by_names(placeholders, names): def get_sublist_by_names(lst, names):
""" """
Args:
lst (list): list of objects with "name" property.
Returns: Returns:
list[Tensor]: a sublist of placeholders, matching names list: a sublist of objects, matching names
""" """
placeholder_names = [p.name for p in placeholders] orig_names = [p.name for p in lst]
ret = [] ret = []
for name in names: for name in names:
tensorname = get_op_tensor_name(name)[1]
try: try:
idx = placeholder_names.index(tensorname) idx = orig_names.index(name)
except ValueError: except ValueError:
logger.error("Name {} is not a model input!".format(tensorname)) logger.error("Name {} doesn't appear in lst {}!".format(
name, str(orig_names)))
raise raise
ret.append(placeholders[idx]) ret.append(lst[idx])
return ret return ret
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment