Commit 1e7fa5f9 authored by Yuxin Wu

make InputSource subclass methods private

parent ef9e27a6
......@@ -270,7 +270,8 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
total -= nr_tower
# take care of the rest
while total > 0:
feed = self._input_source.next_feed(cnt=1)
# TODO XXX doesn't support remap
feed = self._input_source._next_feed(cnt=1)
self._hooked_sess.run(fetches=[], feed_dict=feed)
pbar.update(1)
total -= 1
......
......@@ -36,7 +36,6 @@ __all__ = ['InputSource',
class InputSource(object):
""" Base class for the abstract InputSource. """
@abstractmethod
def get_input_tensors(self):
"""
Returns:
......@@ -44,12 +43,20 @@ class InputSource(object):
used as input of :func:`build_graph`.
For non-placeholder tensors, should always create and return new tensors when called.
"""
return self._get_input_tensors()
@abstractmethod
def _get_input_tensors(self):
pass
def setup(self, inputs_desc):
"""
Args:
inputs_desc (list[InputDesc]): list of input desc
"""
self._setup(inputs_desc)
def _setup(self, inputs_desc):
pass
def get_callbacks(self):
......@@ -57,21 +64,31 @@ class InputSource(object):
Returns:
list[Callback]: extra callbacks required by this InputSource.
"""
return self._get_callbacks()
def _get_callbacks(self):
return []
@abstractmethod
def reset_state(self):
"""
Semantics of this method has not been well defined.
"""
pass
# TODO
self._reset_state()
@abstractmethod
def _reset_state(self):
pass
def next_feed(self):
"""
Returns:
a feed_dict of {Tensor: data}, to be used to run the steps
"""
return self._next_feed()
@abstractmethod
def _next_feed(self):
pass
def size(self):
......@@ -79,7 +96,10 @@ class InputSource(object):
Returns:
int: epoch size of the InputSource
"""
return NotImplementedError()
return self._size()
def _size(self):
raise NotImplementedError()
class ProxyInputSource(InputSource):
......@@ -90,22 +110,22 @@ class ProxyInputSource(InputSource):
assert isinstance(input, InputSource), input
self._input = input
def get_input_tensors(self):
def _get_input_tensors(self):
return self._input.get_input_tensors()
def setup(self, inputs_desc):
def _setup(self, inputs_desc):
self._input.setup(inputs_desc)
def get_callbacks(self):
def _get_callbacks(self):
return self._input.get_callbacks()
def size(self):
def _size(self):
return self._input.size()
def next_feed(self):
def _next_feed(self):
return self._input.next_feed()
def reset_state(self):
def _reset_state(self):
self._input.reset_state()
......@@ -119,22 +139,22 @@ class FeedInput(InputSource):
assert isinstance(ds, DataFlow), ds
self.ds = ds
def size(self):
def _size(self):
return self.ds.size()
def setup(self, inputs):
def _setup(self, inputs):
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self.reset_state()
def reset_state(self):
def _reset_state(self):
rds = RepeatedData(self.ds, -1)
rds.reset_state()
self.data_producer = rds.get_data()
def get_input_tensors(self):
def _get_input_tensors(self):
return self._all_placehdrs
def next_feed(self):
def _next_feed(self):
dp = next(self.data_producer)
assert len(dp) == len(self._all_placehdrs), "[FeedInput] datapoints and inputs are of different length!"
return dict(zip(self._all_placehdrs, dp))
......@@ -149,7 +169,7 @@ class DataParallelFeedInput(FeedInput):
self._tower_names = tower_names
self._nr_tower = len(tower_names)
def setup(self, inputs):
def _setup(self, inputs):
self._placehdrs_per_tower = []
for tname in self._tower_names:
# build a list of placeholders for each tower
......@@ -157,12 +177,12 @@ class DataParallelFeedInput(FeedInput):
[v.build_placeholder(prefix=tname + '/') for v in inputs])
self.reset_state()
def get_input_tensors(self):
def _get_input_tensors(self):
# return placeholders for each tower
ctx = get_current_tower_context()
return self._placehdrs_per_tower[ctx.index]
def next_feed(self, cnt=None):
def _next_feed(self, cnt=None):
"""
Args:
cnt: how many towers to feed to. Defaults to the total number of towers
......@@ -181,10 +201,10 @@ class FeedfreeInput(InputSource):
""" Abstract base for input without feed,
e.g. by queue or other operations. """
def reset_state(self):
def _reset_state(self):
pass
def next_feed(self):
def _next_feed(self):
return {}
......@@ -244,10 +264,10 @@ class QueueInput(FeedfreeInput):
self.queue = queue
self.ds = ds
def size(self):
def _size(self):
return self.ds.size()
def setup(self, inputs):
def _setup(self, inputs):
logger.info("Setting up the queue for CPU prefetching ...")
self._input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
assert len(self._input_placehdrs) > 0, \
......@@ -258,12 +278,12 @@ class QueueInput(FeedfreeInput):
name='input_queue')
self.thread = EnqueueThread(self.queue, self.ds, self._input_placehdrs)
def get_callbacks(self):
def _get_callbacks(self):
cb = StartProcOrThread(self.thread)
cb.chief_only = False
return [cb]
def get_input_tensors(self):
def _get_input_tensors(self):
with tf.device('/cpu:0'):
ret = self.queue.dequeue(name='input_deque')
if isinstance(ret, tf.Tensor): # only one input
......@@ -291,10 +311,10 @@ class BatchQueueInput(QueueInput):
super(BatchQueueInput, self).__init__(ds, queue)
self.batch_size = int(batch_size)
def size(self):
def _size(self):
return self.ds.size() // self.batch_size
def setup(self, inputs):
def _setup(self, inputs):
logger.info("Setting up the queue for CPU prefetching ...")
self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
assert len(self.input_placehdrs) > 0, \
......@@ -325,7 +345,7 @@ class BatchQueueInput(QueueInput):
self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch)
def get_input_tensors(self):
def _get_input_tensors(self):
with tf.device('/cpu:0'):
ret = self.queue.dequeue_many(self.batch_size, name='input_deque')
if isinstance(ret, tf.Tensor): # only one input
......@@ -338,6 +358,7 @@ class BatchQueueInput(QueueInput):
return ret
# TODO tensor inputs can be drained? look at the new dataset API.
class TensorInput(FeedfreeInput):
""" Input from a list of tensors, e.g. a TF data reading pipeline. """
......@@ -352,15 +373,20 @@ class TensorInput(FeedfreeInput):
if size is not None:
size = int(size)
assert size > 0
self._size = size
self._fixed_size = size
def size(self):
if self._size is None:
def _setup(self, inputs_desc):
self._desc = inputs_desc
def _size(self):
if self._fixed_size is None:
raise NotImplementedError("size of TensorInput is undefined!")
return self._size
return self._fixed_size
def get_input_tensors(self):
return self.get_tensor_fn()
def _get_input_tensors(self):
ret = self.get_tensor_fn()
assert len(ret) == len(self._desc), "{} != {}".format(len(ret), len(self._desc))
return ret
class DummyConstantInput(TensorInput):
......@@ -387,7 +413,7 @@ class DummyConstantInput(TensorInput):
return tlist
super(DummyConstantInput, self).__init__(fn)
def setup(self, inputs):
def _setup(self, inputs):
self.inputs_desc = inputs
......@@ -410,8 +436,8 @@ class ZMQInput(TensorInput):
return ret
super(ZMQInput, self).__init__(fn)
def setup(self, inputs):
self.inputs_desc = inputs
def _setup(self, inputs_desc):
self.inputs_desc = inputs_desc
assert len(self.inputs_desc) > 0, \
"ZMQInput has to be used with InputDesc!"
......@@ -454,19 +480,19 @@ class StagingInputWrapper(FeedfreeInput):
self._stage_ops = []
self._unstage_ops = []
def setup(self, inputs):
def _setup(self, inputs):
self._input.setup(inputs)
self.setup_staging_areas()
self._setup_staging_areas()
def get_callbacks(self):
def _get_callbacks(self):
cbs = self._input.get_callbacks()
cbs.append(
StagingInputWrapper.StagingCallback(
self.get_stage_op(), self.get_unstage_op(), self._nr_stage))
self._get_stage_op(), self._get_unstage_op(), self._nr_stage))
return cbs
def setup_staging_areas(self):
def _setup_staging_areas(self):
logger.info("Setting up StagingArea for GPU prefetching ...")
for idx, device in enumerate(self._devices):
with tf.device(device):
......@@ -482,22 +508,18 @@ class StagingInputWrapper(FeedfreeInput):
vout.set_shape(vin.get_shape())
self._unstage_ops.append(outputs)
def size(self):
def _size(self):
return self._input.size()
def get_input_tensors(self):
def _get_input_tensors(self):
ctx = get_current_tower_context()
ret = self._unstage_ops[ctx.index]
return ret
@staticmethod
def get_staging_name(idx):
return 'StagingArea{}'.format(idx)
def get_stage_op(self):
def _get_stage_op(self):
return tf.group(*self._stage_ops)
def get_unstage_op(self):
def _get_unstage_op(self):
all_outputs = list(chain.from_iterable(self._unstage_ops))
return tf.group(*all_outputs)
......@@ -514,18 +536,36 @@ def remap_input_source(input, names):
input(InputSource): a :class:`InputSource`, whose tensors will get mapped.
names(list[str]): list of input names corresponding to the tensors
produced by ``input``.
Returns:
InputSource:
Examples:
.. code-block:: python
input1 = QueueInput(ds)
# assume ds produces 'image' and 'label', but the graph takes more
# inputs for some reasons, or takes inputs of a different order:
inputs_desc = [InputDesc(tf.float32, (None,10), 'score'),
InputDesc(tf.float32, (None,20,20,3), 'label'),
InputDesc(tf.int32, (None,), 'image') ]
input2 = remap_input_source(input1, ['image', 'label'])
input2.setup(inputs_desc)
# now, input2.get_input_tensors() will return a placeholder for 'score',
# plus the tensors returned by input1.get_input_tensors()
"""
def __init__(self, input, names):
ProxyInputSource.__init__(self, input)
assert isinstance(names, (list, tuple)), names
self._names = tuple(names)
def setup(self, inputs):
def _setup(self, inputs):
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
inputs_subset = get_sublist_by_names(inputs, self._names)
self._input.setup(inputs_subset)
def get_input_tensors(self):
def _get_input_tensors(self):
ret = self._input.get_input_tensors()
assert len(ret) == len(self._names)
return get_tensors_inputs(
......@@ -535,6 +575,6 @@ def remap_input_source(input, names):
# inherit oldcls so that type check in various places would work
cls = type('Remapped' + oldcls.__name__, (ProxyInputSource, oldcls), {
'__init__': __init__,
'setup': setup,
'get_input_tensors': get_input_tensors})
'_setup': _setup,
'_get_input_tensors': _get_input_tensors})
return cls(input, names)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment