Commit c2a38de9 authored by Yuxin Wu

Let InputSource.reset always get called. Add TFDatasetInput (#397)

parent f2697f69
@@ -26,6 +26,7 @@ __all__ = ['PlaceholderInput', 'FeedInput', 'DataParallelFeedInput',
            'FeedfreeInput',
            'QueueInput', 'BatchQueueInput',
            'ZMQInput', 'DummyConstantInput', 'TensorInput',
+           'TFDatasetInput',
            'StagingInputWrapper']
@@ -86,7 +87,6 @@ class FeedInput(InputSource):
     def _setup(self, inputs):
         self._all_placehdrs = [v.build_placeholder(prefix='') for v in inputs]
         self._cb = self._FeedCallback(self._iter_ds, self._all_placehdrs)
-        self.reset_state()

     def _get_input_tensors(self):
         return self._all_placehdrs
@@ -135,7 +135,6 @@ class DataParallelFeedInput(FeedInput):
             self._placehdrs_per_tower.append(
                 [v.build_placeholder(prefix=tname + '/') for v in inputs])
         self._cb = self._DataParallelFeedCallback(self._iter_ds, self._placehdrs_per_tower)
-        self.reset_state()

     def _get_input_tensors(self):
         # return placeholders for each tower
@@ -415,6 +414,47 @@ class ZMQInput(TensorInput):
             "ZMQInput has to be used with InputDesc!"

+
+class TFDatasetInput(FeedfreeInput):
+    """
+    Use a :class:`tf.contrib.data.Dataset` instance as input.
+
+    Note:
+        In training, the dataset should be infinite (use :func:`repeat()`).
+    """
+    def __init__(self, dataset):
+        """
+        Args:
+            dataset (tf.contrib.data.Dataset):
+        """
+        self._dataset = dataset
+
+    def _setup(self, inputs_desc):
+        self._desc = inputs_desc
+        types = self._dataset.output_types
+        desc_types = tuple([k.type for k in inputs_desc])
+        assert len(types) == len(desc_types), \
+            "Dataset and InputDesc have different lengths! {} != {}".format(
+                len(types), len(desc_types))
+        assert types == desc_types, \
+            "Types of dataset and InputDesc don't match! {} != {}".format(
+                str(types), str(desc_types))
+        shapes = self._dataset.output_shapes
+        desc_shapes = [k.shape for k in inputs_desc]
+        for idx, (s1, s2) in enumerate(zip(shapes, desc_shapes)):
+            s2 = tf.TensorShape(s2)
+            assert s2.is_compatible_with(s1), \
+                "InputDesc '{}' has incompatible shape with dataset! {} vs {}".format(
+                    inputs_desc[idx].name, s2, s1)
+        self._iterator = self._dataset.make_initializable_iterator()
+        self._init_op = self._iterator.initializer
+
+    def _reset_state(self):
+        self._init_op.run()
+
+    def _get_input_tensors(self):
+        return self._iterator.get_next()
+
+
 class StagingInputWrapper(FeedfreeInput):
     """
     A wrapper around a feedfree input, to prefetch it in StagingArea (usually on GPUs).
...
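For context, a minimal usage sketch of the new class (not part of this commit; the top-level import path and the `TrainConfig(data=...)` argument are assumptions based on the tensorpack API of that era):

# A minimal sketch of driving training from a tf.contrib.data.Dataset
# (TF 1.3-era API) via the new TFDatasetInput. Import path is assumed.
import numpy as np
import tensorflow as tf
from tensorpack import TFDatasetInput

images = np.random.rand(256, 28, 28).astype('float32')
labels = np.random.randint(0, 10, size=(256,)).astype('int32')

# As the class docstring notes, the dataset should be infinite for training.
dataset = tf.contrib.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.shuffle(256).repeat().batch(32)

input_source = TFDatasetInput(dataset)
# The InputSource is then handed to the trainer, e.g. (assumed signature):
# config = TrainConfig(data=input_source, model=MyModel(), steps_per_epoch=100)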
@@ -9,6 +9,7 @@ import tensorflow as tf

 from ..utils.argtools import memoized
 from ._utils import get_sublist_by_names, get_tensors_inputs
+from ..callbacks.base import CallbackFactory

 __all__ = ['InputSource', 'remap_input_source']
@@ -56,14 +57,18 @@ class InputSource(object):
         Returns:
             list[Callback]: extra callbacks needed by this InputSource.
         """
-        return self._get_callbacks()
+        return [CallbackFactory(
+            before_train=lambda _: self.reset_state())] + self._get_callbacks()

     def _get_callbacks(self):
         return []

     def reset_state(self):
         """
-        Reinitialize this InputSource.
+        Initialize/reinitialize this InputSource.
+        For training, it will get called by the trainer in `before_train` callbacks.
+        For inference, the :class:`InferenceRunner` will call it each time it is triggered.
         """
         self._reset_state()
...
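This base-class change is what makes the earlier deletions safe: `get_callbacks()` now always prepends a `CallbackFactory` whose `before_train` hook calls `reset_state()`, so `FeedInput._setup` and `DataParallelFeedInput._setup` no longer need to call it themselves. A minimal sketch of the guarantee (hypothetical subclass; import path assumed):

# A minimal sketch: a subclass only implements _reset_state(); the base
# class wires it into training via a before_train callback automatically.
from tensorpack.train.input_source_base import InputSource  # path assumed

class CountingInput(InputSource):  # hypothetical, for illustration only
    def __init__(self):
        self.resets = 0

    def _reset_state(self):
        self.resets += 1

src = CountingInput()
callbacks = src.get_callbacks()
# callbacks[0] is the CallbackFactory prepended above; when the trainer
# runs its before_train hooks, src.reset_state() fires exactly once.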
@@ -125,7 +125,8 @@ class TrainConfig(object):
             else:
                 raise NotImplementedError()
         except NotImplementedError:
-            logger.exception("You must set `TrainConfig(steps_per_epoch)` if data.size() is not available.")
+            logger.error("You must set `TrainConfig(steps_per_epoch)` if data.size() is not available.")
+            raise
         else:
             steps_per_epoch = int(steps_per_epoch)
         self.steps_per_epoch = steps_per_epoch
...
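The `TrainConfig` change makes the failure fast: when the input's `size()` is unavailable (as with an infinite dataset), the `NotImplementedError` is now re-raised after the error message instead of being swallowed and leaving `steps_per_epoch` undefined. A sketch of the caller's fix (assumed `TrainConfig` keyword arguments):

# A minimal sketch: an infinite dataset has no size(), so steps_per_epoch
# must be passed explicitly; otherwise TrainConfig now raises.
config = TrainConfig(
    data=TFDatasetInput(dataset),  # dataset from the earlier sketch
    model=MyModel(),               # hypothetical model definition
    steps_per_epoch=1000,          # required when data.size() is unavailable
)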