Commit e2f9798a authored by Yuxin Wu's avatar Yuxin Wu

replace setup_training by get_callbacks in InputSource

parent 14c564cc
......@@ -84,6 +84,8 @@ class InferenceRunnerBase(Callback):
def _setup_graph(self):
self._input_source.setup(self.trainer.model.get_inputs_desc())
assert len(self._input_source.get_callbacks()) == 0, \
"InferenceRunner doesn't support any InputSource which requires callbacks!"
# Use predict_tower in train config. either gpuid or -1
self._predict_tower_id = self.trainer.config.predict_tower[0]
......@@ -189,6 +191,8 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
def _setup_graph(self):
model = self.trainer.model
self._input_source.setup(model.get_inputs_desc())
assert len(self._input_source.get_callbacks()) == 0, \
"InferenceRunner doesn't support any InputSource which requires callbacks!"
# build graph
def build_tower(k):
......
......@@ -37,7 +37,10 @@ class FeedfreeTrainerBase(Trainer):
def _setup(self):
assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
self._input_source.setup_training(self)
self._input_source.setup(self.model.get_inputs_desc())
input_callbacks = self._input_source.get_callbacks()
for cb in input_callbacks:
self.register_callback(cb)
def run_step(self):
""" Simply run ``self.train_op``."""
......
......@@ -53,8 +53,12 @@ class InputSource(object):
"""
pass
def setup_training(self, trainer):
self.setup(trainer.model.get_inputs_desc())
def get_callbacks(self):
"""
Returns:
list[Callback]: extra callbacks required by this InputSource.
"""
return []
@abstractmethod
def reset_state(self):
......@@ -248,11 +252,10 @@ class QueueInput(FeedfreeInput):
name='input_queue')
self.thread = EnqueueThread(self.queue, self.ds, self._queue_feedpoint)
def setup_training(self, trainer):
super(QueueInput, self).setup_training(trainer)
def get_callbacks(self):
cb = StartProcOrThread(self.thread)
cb.chief_only = False
trainer.register_callback(cb)
return [cb]
def get_input_tensors(self):
with tf.device('/cpu:0'):
......@@ -321,9 +324,10 @@ class BatchQueueInput(FeedfreeInput):
self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch)
def setup_training(self, trainer):
super(BatchQueueInput, self).setup_training(trainer)
trainer.register_callback(StartProcOrThread(self.thread))
def get_callbacks(self):
cb = StartProcOrThread(self.thread)
cb.chief_only = False
return [cb]
def get_input_tensors(self):
with tf.device('/cpu:0'):
......@@ -461,13 +465,13 @@ class StagingInputWrapper(FeedfreeInput):
self._input.setup(inputs)
self.setup_staging_areas()
def setup_training(self, trainer):
self._input.setup_training(trainer)
self.setup_staging_areas()
def get_callbacks(self):
cbs = self._input.get_callbacks()
trainer.register_callback(
cbs.append(
StagingInputWrapper.StagingCallback(
self.get_stage_op(), self.get_unstage_op(), self._nr_stage))
return cbs
def setup_staging_areas(self):
logger.info("Setting up StagingArea for GPU prefetching ...")
......@@ -531,10 +535,8 @@ class ReorderInputSource(FeedfreeInput):
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self._input.setup(inputs)
def setup_training(self, trainer):
inputs = trainer.model.get_inputs_desc()
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self._input.setup_training(trainer)
def get_callbacks(self):
return self._input.get_callbacks()
def reset_state(self):
self._input.reset_state()
......
......@@ -35,8 +35,10 @@ class SimpleTrainer(Trainer):
self.hooked_sess.run(self.train_op, feed_dict=feed)
def _setup(self):
self._input_source.setup_training(self)
model = self.model
self._input_source.setup(model.get_inputs_desc())
cbs = self._input_source.get_callbacks()
assert len(cbs) == 0, "Feedinput has no callbacks!"
self.inputs = self._input_source.get_input_tensors()
with TowerContext('', is_training=True):
model.build_graph(self.inputs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment