Commit 20d7fe7f authored by Yuxin Wu's avatar Yuxin Wu

Improve the use of size() of InputSource

parent e2f9798a
......@@ -116,7 +116,7 @@ class InferenceRunnerBase(Callback):
# iterate over the data, and run the hooked session
self._input_source.reset_state()
for _ in tqdm.trange(self._input_source.size(), **get_tqdm_kwargs()):
for _ in tqdm.trange(self._size, **get_tqdm_kwargs()):
feed = self._input_source.next_feed()
self._hooked_sess.run(fetches=[], feed_dict=feed)
summary_inferencer(self.trainer, self.infs)
......@@ -252,7 +252,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
inf.before_inference()
self._input_source.reset_state()
total = self._input_source.size()
total = self._size
nr_tower = len(self._gpus)
with tqdm.tqdm(total=total, **get_tqdm_kwargs()) as pbar:
while total >= nr_tower:
......
......@@ -65,7 +65,7 @@ class TrainConfig(object):
# process data
if 'dataset' in kwargs:
dataflow = kwargs.pop('dataset')
log_deprecated("TrainConfig.dataset", "Use TrainConfig.dataflow instead.")
log_deprecated("TrainConfig.dataset", "Use TrainConfig.dataflow instead.", "2017-09-11")
if dataflow is not None:
assert data is None, "dataflow and data cannot be both presented in TrainConfig!"
self.dataflow = dataflow
......@@ -113,10 +113,6 @@ class TrainConfig(object):
assert session_config is None, "Cannot set both session_creator and session_config!"
self.session_config = session_config
if steps_per_epoch is None:
steps_per_epoch = kwargs.pop('step_per_epoch', None)
if steps_per_epoch is not None:
log_deprecated("step_per_epoch", "Use steps_per_epoch instead!", "2017-03-27")
if steps_per_epoch is None:
try:
if dataflow is not None:
......
......@@ -72,6 +72,13 @@ class InputSource(object):
"""
pass
def size(self):
    """Return the number of datapoints in one epoch of this InputSource.

    Returns:
        int: epoch size of the InputSource

    Raises:
        NotImplementedError: subclasses that support epoch-based iteration
            must override this method.
    """
    # Bug fix: the original `return NotImplementedError()` handed callers an
    # exception *instance* instead of signalling the error — `raise` is required.
    raise NotImplementedError()
class FeedInput(InputSource):
""" Input by iterating over a DataFlow and feed datapoints. """
......@@ -271,7 +278,7 @@ class QueueInput(FeedfreeInput):
return get_tensors_inputs(self.input_placehdrs, ret, self._names)
class BatchQueueInput(FeedfreeInput):
class BatchQueueInput(QueueInput):
""" Enqueue datapoints from a DataFlow to a TF queue.
And the model receives batches formed by concatenating
dequeued tensors.
......@@ -285,9 +292,7 @@ class BatchQueueInput(FeedfreeInput):
should match the corresponding InputDesc of the model.
Defaults to a FIFO queue of size 3000.
"""
assert isinstance(ds, DataFlow), ds
self.queue = queue
self.ds = ds
super(BatchQueueInput, self).__init__(ds, queue)
self.batch_size = int(batch_size)
def size(self):
......@@ -324,11 +329,6 @@ class BatchQueueInput(FeedfreeInput):
self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch)
def get_callbacks(self):
cb = StartProcOrThread(self.thread)
cb.chief_only = False
return [cb]
def get_input_tensors(self):
with tf.device('/cpu:0'):
ret = self.queue.dequeue_many(self.batch_size, name='input_deque')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment