Commit e2f9798a authored by Yuxin Wu's avatar Yuxin Wu

replace setup_training by get_callbacks in InputSource

parent 14c564cc
...@@ -84,6 +84,8 @@ class InferenceRunnerBase(Callback): ...@@ -84,6 +84,8 @@ class InferenceRunnerBase(Callback):
def _setup_graph(self): def _setup_graph(self):
self._input_source.setup(self.trainer.model.get_inputs_desc()) self._input_source.setup(self.trainer.model.get_inputs_desc())
assert len(self._input_source.get_callbacks()) == 0, \
"InferenceRunner doesn't support any InputSource which requires callbacks!"
# Use predict_tower in train config. either gpuid or -1 # Use predict_tower in train config. either gpuid or -1
self._predict_tower_id = self.trainer.config.predict_tower[0] self._predict_tower_id = self.trainer.config.predict_tower[0]
...@@ -189,6 +191,8 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -189,6 +191,8 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
def _setup_graph(self): def _setup_graph(self):
model = self.trainer.model model = self.trainer.model
self._input_source.setup(model.get_inputs_desc()) self._input_source.setup(model.get_inputs_desc())
assert len(self._input_source.get_callbacks()) == 0, \
"InferenceRunner doesn't support any InputSource which requires callbacks!"
# build graph # build graph
def build_tower(k): def build_tower(k):
......
...@@ -37,7 +37,10 @@ class FeedfreeTrainerBase(Trainer): ...@@ -37,7 +37,10 @@ class FeedfreeTrainerBase(Trainer):
def _setup(self): def _setup(self):
assert isinstance(self._input_source, FeedfreeInput), type(self._input_source) assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
self._input_source.setup_training(self) self._input_source.setup(self.model.get_inputs_desc())
input_callbacks = self._input_source.get_callbacks()
for cb in input_callbacks:
self.register_callback(cb)
def run_step(self): def run_step(self):
""" Simply run ``self.train_op``.""" """ Simply run ``self.train_op``."""
......
...@@ -53,8 +53,12 @@ class InputSource(object): ...@@ -53,8 +53,12 @@ class InputSource(object):
""" """
pass pass
def setup_training(self, trainer): def get_callbacks(self):
self.setup(trainer.model.get_inputs_desc()) """
Returns:
list[Callback]: extra callbacks required by this InputSource.
"""
return []
@abstractmethod @abstractmethod
def reset_state(self): def reset_state(self):
...@@ -248,11 +252,10 @@ class QueueInput(FeedfreeInput): ...@@ -248,11 +252,10 @@ class QueueInput(FeedfreeInput):
name='input_queue') name='input_queue')
self.thread = EnqueueThread(self.queue, self.ds, self._queue_feedpoint) self.thread = EnqueueThread(self.queue, self.ds, self._queue_feedpoint)
def setup_training(self, trainer): def get_callbacks(self):
super(QueueInput, self).setup_training(trainer)
cb = StartProcOrThread(self.thread) cb = StartProcOrThread(self.thread)
cb.chief_only = False cb.chief_only = False
trainer.register_callback(cb) return [cb]
def get_input_tensors(self): def get_input_tensors(self):
with tf.device('/cpu:0'): with tf.device('/cpu:0'):
...@@ -321,9 +324,10 @@ class BatchQueueInput(FeedfreeInput): ...@@ -321,9 +324,10 @@ class BatchQueueInput(FeedfreeInput):
self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch) self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch)
def setup_training(self, trainer): def get_callbacks(self):
super(BatchQueueInput, self).setup_training(trainer) cb = StartProcOrThread(self.thread)
trainer.register_callback(StartProcOrThread(self.thread)) cb.chief_only = False
return [cb]
def get_input_tensors(self): def get_input_tensors(self):
with tf.device('/cpu:0'): with tf.device('/cpu:0'):
...@@ -461,13 +465,13 @@ class StagingInputWrapper(FeedfreeInput): ...@@ -461,13 +465,13 @@ class StagingInputWrapper(FeedfreeInput):
self._input.setup(inputs) self._input.setup(inputs)
self.setup_staging_areas() self.setup_staging_areas()
def setup_training(self, trainer): def get_callbacks(self):
self._input.setup_training(trainer) cbs = self._input.get_callbacks()
self.setup_staging_areas()
trainer.register_callback( cbs.append(
StagingInputWrapper.StagingCallback( StagingInputWrapper.StagingCallback(
self.get_stage_op(), self.get_unstage_op(), self._nr_stage)) self.get_stage_op(), self.get_unstage_op(), self._nr_stage))
return cbs
def setup_staging_areas(self): def setup_staging_areas(self):
logger.info("Setting up StagingArea for GPU prefetching ...") logger.info("Setting up StagingArea for GPU prefetching ...")
...@@ -531,10 +535,8 @@ class ReorderInputSource(FeedfreeInput): ...@@ -531,10 +535,8 @@ class ReorderInputSource(FeedfreeInput):
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs] self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self._input.setup(inputs) self._input.setup(inputs)
def setup_training(self, trainer): def get_callbacks(self):
inputs = trainer.model.get_inputs_desc() return self._input.get_callbacks()
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self._input.setup_training(trainer)
def reset_state(self): def reset_state(self):
self._input.reset_state() self._input.reset_state()
......
...@@ -35,8 +35,10 @@ class SimpleTrainer(Trainer): ...@@ -35,8 +35,10 @@ class SimpleTrainer(Trainer):
self.hooked_sess.run(self.train_op, feed_dict=feed) self.hooked_sess.run(self.train_op, feed_dict=feed)
def _setup(self): def _setup(self):
self._input_source.setup_training(self)
model = self.model model = self.model
self._input_source.setup(model.get_inputs_desc())
cbs = self._input_source.get_callbacks()
assert len(cbs) == 0, "Feedinput has no callbacks!"
self.inputs = self._input_source.get_input_tensors() self.inputs = self._input_source.get_input_tensors()
with TowerContext('', is_training=True): with TowerContext('', is_training=True):
model.build_graph(self.inputs) model.build_graph(self.inputs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment