Commit 20d7fe7f authored by Yuxin Wu

Improve the use of size() of InputSource

parent e2f9798a
...@@ -116,7 +116,7 @@ class InferenceRunnerBase(Callback): ...@@ -116,7 +116,7 @@ class InferenceRunnerBase(Callback):
# iterate over the data, and run the hooked session # iterate over the data, and run the hooked session
self._input_source.reset_state() self._input_source.reset_state()
for _ in tqdm.trange(self._input_source.size(), **get_tqdm_kwargs()): for _ in tqdm.trange(self._size, **get_tqdm_kwargs()):
feed = self._input_source.next_feed() feed = self._input_source.next_feed()
self._hooked_sess.run(fetches=[], feed_dict=feed) self._hooked_sess.run(fetches=[], feed_dict=feed)
summary_inferencer(self.trainer, self.infs) summary_inferencer(self.trainer, self.infs)
...@@ -252,7 +252,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -252,7 +252,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
inf.before_inference() inf.before_inference()
self._input_source.reset_state() self._input_source.reset_state()
total = self._input_source.size() total = self._size
nr_tower = len(self._gpus) nr_tower = len(self._gpus)
with tqdm.tqdm(total=total, **get_tqdm_kwargs()) as pbar: with tqdm.tqdm(total=total, **get_tqdm_kwargs()) as pbar:
while total >= nr_tower: while total >= nr_tower:
......
...@@ -65,7 +65,7 @@ class TrainConfig(object): ...@@ -65,7 +65,7 @@ class TrainConfig(object):
# process data # process data
if 'dataset' in kwargs: if 'dataset' in kwargs:
dataflow = kwargs.pop('dataset') dataflow = kwargs.pop('dataset')
log_deprecated("TrainConfig.dataset", "Use TrainConfig.dataflow instead.") log_deprecated("TrainConfig.dataset", "Use TrainConfig.dataflow instead.", "2017-09-11")
if dataflow is not None: if dataflow is not None:
assert data is None, "dataflow and data cannot be both presented in TrainConfig!" assert data is None, "dataflow and data cannot be both presented in TrainConfig!"
self.dataflow = dataflow self.dataflow = dataflow
...@@ -113,10 +113,6 @@ class TrainConfig(object): ...@@ -113,10 +113,6 @@ class TrainConfig(object):
assert session_config is None, "Cannot set both session_creator and session_config!" assert session_config is None, "Cannot set both session_creator and session_config!"
self.session_config = session_config self.session_config = session_config
if steps_per_epoch is None:
steps_per_epoch = kwargs.pop('step_per_epoch', None)
if steps_per_epoch is not None:
log_deprecated("step_per_epoch", "Use steps_per_epoch instead!", "2017-03-27")
if steps_per_epoch is None: if steps_per_epoch is None:
try: try:
if dataflow is not None: if dataflow is not None:
......
...@@ -72,6 +72,13 @@ class InputSource(object): ...@@ -72,6 +72,13 @@ class InputSource(object):
""" """
pass pass
def size(self):
    """
    Returns:
        int: epoch size of the InputSource

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    # Raise (not return) the exception: returning the instance would hand
    # callers an exception object where they expect an int.
    raise NotImplementedError()
class FeedInput(InputSource): class FeedInput(InputSource):
""" Input by iterating over a DataFlow and feed datapoints. """ """ Input by iterating over a DataFlow and feed datapoints. """
...@@ -271,7 +278,7 @@ class QueueInput(FeedfreeInput): ...@@ -271,7 +278,7 @@ class QueueInput(FeedfreeInput):
return get_tensors_inputs(self.input_placehdrs, ret, self._names) return get_tensors_inputs(self.input_placehdrs, ret, self._names)
class BatchQueueInput(FeedfreeInput): class BatchQueueInput(QueueInput):
""" Enqueue datapoints from a DataFlow to a TF queue. """ Enqueue datapoints from a DataFlow to a TF queue.
And the model receives batches formed by concatenating And the model receives batches formed by concatenating
dequeued tensors. dequeued tensors.
...@@ -285,9 +292,7 @@ class BatchQueueInput(FeedfreeInput): ...@@ -285,9 +292,7 @@ class BatchQueueInput(FeedfreeInput):
should match the corresponding InputDesc of the model. should match the corresponding InputDesc of the model.
Defaults to a FIFO queue of size 3000. Defaults to a FIFO queue of size 3000.
""" """
assert isinstance(ds, DataFlow), ds super(BatchQueueInput, self).__init__(ds, queue)
self.queue = queue
self.ds = ds
self.batch_size = int(batch_size) self.batch_size = int(batch_size)
def size(self): def size(self):
...@@ -324,11 +329,6 @@ class BatchQueueInput(FeedfreeInput): ...@@ -324,11 +329,6 @@ class BatchQueueInput(FeedfreeInput):
self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch) self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch)
def get_callbacks(self):
cb = StartProcOrThread(self.thread)
cb.chief_only = False
return [cb]
def get_input_tensors(self): def get_input_tensors(self):
with tf.device('/cpu:0'): with tf.device('/cpu:0'):
ret = self.queue.dequeue_many(self.batch_size, name='input_deque') ret = self.queue.dequeue_many(self.batch_size, name='input_deque')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment