Commit 1e7fa5f9 authored by Yuxin Wu

make InputSource subclass methods private

parent ef9e27a6
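The whole change applies one convention: each InputSource keeps its small public API (setup, get_input_tensors, get_callbacks, reset_state, next_feed, size) on the base class, and that API now just dispatches to an underscore-prefixed hook that subclasses override. A minimal sketch of this template-method pattern (illustrative names, Python 3 for brevity; not tensorpack code):

from abc import ABC, abstractmethod


class Base(ABC):
    """The public method is the stable interface; subclasses fill in the hook."""

    def get_input_tensors(self):
        # public entry point used by callers (trainers, wrappers, runners)
        return self._get_input_tensors()

    @abstractmethod
    def _get_input_tensors(self):
        # private hook: the only method a subclass has to implement
        pass


class Impl(Base):
    def _get_input_tensors(self):
        return ['tensor_a', 'tensor_b']      # placeholder values for the sketch


print(Impl().get_input_tensors())            # -> ['tensor_a', 'tensor_b']

The base class can later add checks, logging or remapping inside the public wrapper without touching any subclass, which is what the diff below sets up.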
@@ -270,7 +270,8 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
                     total -= nr_tower
                 # take care of the rest
                 while total > 0:
-                    feed = self._input_source.next_feed(cnt=1)
+                    # TODO XXX doesn't support remap
+                    feed = self._input_source._next_feed(cnt=1)
                     self._hooked_sess.run(fetches=[], feed_dict=feed)
                     pbar.update(1)
                     total -= 1

@@ -36,7 +36,6 @@ __all__ = ['InputSource',
 class InputSource(object):
     """ Base class for the abstract InputSource. """

-    @abstractmethod
     def get_input_tensors(self):
         """
         Returns:
@@ -44,12 +43,20 @@ class InputSource(object):
                 used as input of :func:`build_graph`.
                 For non-placeholder tensors, should always create and return new tensors when called.
         """
+        return self._get_input_tensors()
+
+    @abstractmethod
+    def _get_input_tensors(self):
+        pass

     def setup(self, inputs_desc):
         """
         Args:
             inputs_desc (list[InputDesc]): list of input desc
         """
+        self._setup(inputs_desc)
+
+    def _setup(self, inputs_desc):
         pass
@@ -57,21 +64,31 @@ class InputSource(object):
     def get_callbacks(self):
         """
         Returns:
             list[Callback]: extra callbacks required by this InputSource.
         """
+        return self._get_callbacks()
+
+    def _get_callbacks(self):
         return []

-    @abstractmethod
     def reset_state(self):
         """
         Semantics of this method has not been well defined.
         """
-        pass
+        # TODO
+        self._reset_state()

     @abstractmethod
+    def _reset_state(self):
+        pass
+
     def next_feed(self):
         """
         Returns:
             a feed_dict of {Tensor: data}, to be used to run the steps
         """
+        return self._next_feed()
+
+    @abstractmethod
+    def _next_feed(self):
         pass
@@ -79,7 +96,10 @@ class InputSource(object):
     def size(self):
         """
         Returns:
             int: epoch size of the InputSource
         """
-        return NotImplementedError()
+        return self._size()
+
+    def _size(self):
+        raise NotImplementedError()
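With the dispatching base class above, a subclass only fills in the private hooks and inherits the public wrappers. A hypothetical subclass written against this commit, modeled on the FeedInput changes further down (ListFeedInput and its attributes are illustrative, not part of tensorpack):

class ListFeedInput(InputSource):
    """ Feed a fixed, in-memory list of datapoints (sketch only). """

    def __init__(self, datapoints):
        self._datapoints = list(datapoints)
        self._idx = 0

    def _size(self):
        return len(self._datapoints)

    def _setup(self, inputs_desc):
        # same placeholder-building pattern FeedInput uses in this commit
        self._placehdrs = [v.build_placeholder_reuse() for v in inputs_desc]

    def _get_input_tensors(self):
        return self._placehdrs

    def _reset_state(self):
        self._idx = 0

    def _next_feed(self):
        dp = self._datapoints[self._idx % len(self._datapoints)]
        self._idx += 1
        return dict(zip(self._placehdrs, dp))

Callers still go through setup(), get_input_tensors() and next_feed(); only the hooks differ per implementation.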
@@ -90,22 +110,22 @@ class ProxyInputSource(InputSource):
         assert isinstance(input, InputSource), input
         self._input = input

-    def get_input_tensors(self):
+    def _get_input_tensors(self):
         return self._input.get_input_tensors()

-    def setup(self, inputs_desc):
+    def _setup(self, inputs_desc):
         self._input.setup(inputs_desc)

-    def get_callbacks(self):
+    def _get_callbacks(self):
         return self._input.get_callbacks()

-    def size(self):
+    def _size(self):
         return self._input.size()

-    def next_feed(self):
+    def _next_feed(self):
         return self._input.next_feed()

-    def reset_state(self):
+    def _reset_state(self):
         self._input.reset_state()
@@ -119,22 +139,22 @@ class FeedInput(InputSource):
         assert isinstance(ds, DataFlow), ds
         self.ds = ds

-    def size(self):
+    def _size(self):
         return self.ds.size()

-    def setup(self, inputs):
+    def _setup(self, inputs):
         self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
         self.reset_state()

-    def reset_state(self):
+    def _reset_state(self):
         rds = RepeatedData(self.ds, -1)
         rds.reset_state()
         self.data_producer = rds.get_data()

-    def get_input_tensors(self):
+    def _get_input_tensors(self):
         return self._all_placehdrs

-    def next_feed(self):
+    def _next_feed(self):
         dp = next(self.data_producer)
         assert len(dp) == len(self._all_placehdrs), "[FeedInput] datapoints and inputs are of different length!"
         return dict(zip(self._all_placehdrs, dp))
@@ -149,7 +169,7 @@ class DataParallelFeedInput(FeedInput):
         self._tower_names = tower_names
         self._nr_tower = len(tower_names)

-    def setup(self, inputs):
+    def _setup(self, inputs):
         self._placehdrs_per_tower = []
         for tname in self._tower_names:
             # build a list of placeholders for each tower
@@ -157,12 +177,12 @@ class DataParallelFeedInput(FeedInput):
                 [v.build_placeholder(prefix=tname + '/') for v in inputs])
         self.reset_state()

-    def get_input_tensors(self):
+    def _get_input_tensors(self):
         # return placeholders for each tower
         ctx = get_current_tower_context()
         return self._placehdrs_per_tower[ctx.index]

-    def next_feed(self, cnt=None):
+    def _next_feed(self, cnt=None):
         """
         Args:
             cnt: how many towers to feed to. Defaults to the total number of towers
@@ -181,10 +201,10 @@ class FeedfreeInput(InputSource):
     """ Abstract base for input without feed,
    e.g. by queue or other operations. """

-    def reset_state(self):
+    def _reset_state(self):
         pass

-    def next_feed(self):
+    def _next_feed(self):
         return {}
@@ -244,10 +264,10 @@ class QueueInput(FeedfreeInput):
             self.queue = queue
         self.ds = ds

-    def size(self):
+    def _size(self):
         return self.ds.size()

-    def setup(self, inputs):
+    def _setup(self, inputs):
         logger.info("Setting up the queue for CPU prefetching ...")
         self._input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
         assert len(self._input_placehdrs) > 0, \
@@ -258,12 +278,12 @@ class QueueInput(FeedfreeInput):
                 name='input_queue')
         self.thread = EnqueueThread(self.queue, self.ds, self._input_placehdrs)

-    def get_callbacks(self):
+    def _get_callbacks(self):
         cb = StartProcOrThread(self.thread)
         cb.chief_only = False
         return [cb]

-    def get_input_tensors(self):
+    def _get_input_tensors(self):
         with tf.device('/cpu:0'):
             ret = self.queue.dequeue(name='input_deque')
             if isinstance(ret, tf.Tensor):  # only one input
@@ -291,10 +311,10 @@ class BatchQueueInput(QueueInput):
         super(BatchQueueInput, self).__init__(ds, queue)
         self.batch_size = int(batch_size)

-    def size(self):
+    def _size(self):
         return self.ds.size() // self.batch_size

-    def setup(self, inputs):
+    def _setup(self, inputs):
         logger.info("Setting up the queue for CPU prefetching ...")
         self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
         assert len(self.input_placehdrs) > 0, \
@@ -325,7 +345,7 @@ class BatchQueueInput(QueueInput):
         self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch)

-    def get_input_tensors(self):
+    def _get_input_tensors(self):
         with tf.device('/cpu:0'):
             ret = self.queue.dequeue_many(self.batch_size, name='input_deque')
             if isinstance(ret, tf.Tensor):  # only one input
@@ -338,6 +358,7 @@ class BatchQueueInput(QueueInput):
             return ret


+# TODO tensor inputs can be drained? look at the new dataset API.
 class TensorInput(FeedfreeInput):
     """ Input from a list of tensors, e.g. a TF data reading pipeline. """
@@ -352,15 +373,20 @@ class TensorInput(FeedfreeInput):
         if size is not None:
             size = int(size)
             assert size > 0
-        self._size = size
+        self._fixed_size = size

-    def size(self):
-        if self._size is None:
+    def _setup(self, inputs_desc):
+        self._desc = inputs_desc
+
+    def _size(self):
+        if self._fixed_size is None:
             raise NotImplementedError("size of TensorInput is undefined!")
-        return self._size
+        return self._fixed_size

-    def get_input_tensors(self):
-        return self.get_tensor_fn()
+    def _get_input_tensors(self):
+        ret = self.get_tensor_fn()
+        assert len(ret) == len(self._desc), "{} != {}".format(len(ret), len(self._desc))
+        return ret
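The new TensorInput._setup records inputs_desc so that _get_input_tensors can assert that get_tensor_fn returns one tensor per declared input. A rough sketch of such a function, assuming TF 1.x and the TensorInput(get_tensor_fn, size=None) constructor shown above (the random tensors stand in for a real reader pipeline):

import tensorflow as tf

def fake_tensor_fn():
    # must return exactly one tensor per InputDesc later passed to setup()
    images = tf.random_uniform([32, 28, 28], dtype=tf.float32)
    labels = tf.random_uniform([32], maxval=10, dtype=tf.int32)
    return [images, labels]

tensor_input = TensorInput(fake_tensor_fn, size=1000)   # size is optional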
@@ -387,7 +413,7 @@ class DummyConstantInput(TensorInput):
             return tlist
         super(DummyConstantInput, self).__init__(fn)

-    def setup(self, inputs):
+    def _setup(self, inputs):
         self.inputs_desc = inputs
@@ -410,8 +436,8 @@ class ZMQInput(TensorInput):
             return ret
         super(ZMQInput, self).__init__(fn)

-    def setup(self, inputs):
-        self.inputs_desc = inputs
+    def _setup(self, inputs_desc):
+        self.inputs_desc = inputs_desc
         assert len(self.inputs_desc) > 0, \
             "ZMQInput has to be used with InputDesc!"
@@ -454,19 +480,19 @@ class StagingInputWrapper(FeedfreeInput):
         self._stage_ops = []
         self._unstage_ops = []

-    def setup(self, inputs):
+    def _setup(self, inputs):
         self._input.setup(inputs)
-        self.setup_staging_areas()
+        self._setup_staging_areas()

-    def get_callbacks(self):
+    def _get_callbacks(self):
         cbs = self._input.get_callbacks()
         cbs.append(
             StagingInputWrapper.StagingCallback(
-                self.get_stage_op(), self.get_unstage_op(), self._nr_stage))
+                self._get_stage_op(), self._get_unstage_op(), self._nr_stage))
         return cbs

-    def setup_staging_areas(self):
+    def _setup_staging_areas(self):
         logger.info("Setting up StagingArea for GPU prefetching ...")
         for idx, device in enumerate(self._devices):
             with tf.device(device):
@@ -482,22 +508,18 @@ class StagingInputWrapper(FeedfreeInput):
                     vout.set_shape(vin.get_shape())
                 self._unstage_ops.append(outputs)

-    def size(self):
+    def _size(self):
         return self._input.size()

-    def get_input_tensors(self):
+    def _get_input_tensors(self):
         ctx = get_current_tower_context()
         ret = self._unstage_ops[ctx.index]
         return ret

-    @staticmethod
-    def get_staging_name(idx):
-        return 'StagingArea{}'.format(idx)
-
-    def get_stage_op(self):
+    def _get_stage_op(self):
         return tf.group(*self._stage_ops)

-    def get_unstage_op(self):
+    def _get_unstage_op(self):
         all_outputs = list(chain.from_iterable(self._unstage_ops))
         return tf.group(*all_outputs)
@@ -514,18 +536,36 @@ def remap_input_source(input, names):
         input(InputSource): a :class:`InputSource`, whose tensors will get mapped.
         names(list[str]): list of input names corresponding to the tensors
             produced by ``input``.
+
+    Returns:
+        InputSource:
+
+    Examples:
+
+    .. code-block:: python
+
+        input1 = QueueInput(ds)
+        # assume ds produces 'image' and 'label', but the graph takes more
+        # inputs for some reasons, or takes inputs of a different order:
+        inputs_desc = [InputDesc(tf.float32, (None,10), 'score'),
+                       InputDesc(tf.float32, (None,20,20,3), 'label'),
+                       InputDesc(tf.int32, (None,), 'image') ]
+        input2 = remap_input_source(input1, ['image', 'label'])
+        input2.setup(inputs_desc)
+        # now, input2.get_input_tensors() will return a placeholder for 'score',
+        # plus the tensors returned by input1.get_input_tensors()
     """

     def __init__(self, input, names):
         ProxyInputSource.__init__(self, input)
         assert isinstance(names, (list, tuple)), names
         self._names = tuple(names)

-    def setup(self, inputs):
+    def _setup(self, inputs):
         self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
         inputs_subset = get_sublist_by_names(inputs, self._names)
         self._input.setup(inputs_subset)

-    def get_input_tensors(self):
+    def _get_input_tensors(self):
         ret = self._input.get_input_tensors()
         assert len(ret) == len(self._names)
         return get_tensors_inputs(
@@ -535,6 +575,6 @@ def remap_input_source(input, names):
     # inherit oldcls so that type check in various places would work
     cls = type('Remapped' + oldcls.__name__, (ProxyInputSource, oldcls), {
         '__init__': __init__,
-        'setup': setup,
-        'get_input_tensors': get_input_tensors})
+        '_setup': _setup,
+        '_get_input_tensors': _get_input_tensors})
     return cls(input, names)
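From the caller's side the public names are unchanged; only subclass authors are affected, and reaching into an underscore method (as DataParallelInferenceRunner does above with _next_feed(cnt=1), flagged by its TODO) is now the exception rather than the rule. A rough wiring sketch of the intended public usage, assuming ds is an existing DataFlow and using the InputDesc(type, shape, name) form from the docstring example above:

import tensorflow as tf
# InputDesc and QueueInput come from tensorpack; the exact import path depends on the version.

inputs_desc = [InputDesc(tf.float32, (None, 28, 28), 'image'),
               InputDesc(tf.int32, (None,), 'label')]

input_source = QueueInput(ds)                 # any InputSource works here
input_source.setup(inputs_desc)               # public API -> dispatches to _setup()
callbacks = input_source.get_callbacks()      # -> _get_callbacks(), e.g. the EnqueueThread starter
tensors = input_source.get_input_tensors()    # -> _get_input_tensors(), used to build the graph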