Commit 14c564cc authored by Yuxin Wu's avatar Yuxin Wu

Pass InputDesc instead of ModelDesc to InputSource.setup()

parent 8b879cb9
......@@ -83,7 +83,7 @@ class InferenceRunnerBase(Callback):
self._extra_hooks = extra_hooks
def _setup_graph(self):
self._input_source.setup(self.trainer.model)
self._input_source.setup(self.trainer.model.get_inputs_desc())
# Use predict_tower in train config. either gpuid or -1
self._predict_tower_id = self.trainer.config.predict_tower[0]
......@@ -188,7 +188,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
def _setup_graph(self):
model = self.trainer.model
self._input_source.setup(model)
self._input_source.setup(model.get_inputs_desc())
# build graph
def build_tower(k):
......
......@@ -46,11 +46,15 @@ class InputSource(object):
For non-placeholder tensors, should always create and return new tensors when called.
"""
def setup(self, model):
def setup(self, inputs_desc):
"""
Args:
inputs_desc (list[InputDesc]): list of input desc
"""
pass
def setup_training(self, trainer):
self.setup(trainer.model)
self.setup(trainer.model.get_inputs_desc())
@abstractmethod
def reset_state(self):
......@@ -82,8 +86,7 @@ class FeedInput(InputSource):
def size(self):
return self.ds.size()
def setup(self, model):
inputs = model.get_inputs_desc()
def setup(self, inputs):
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
if self._input_names is None:
self._placehdrs_to_feed = self._all_placehdrs
......@@ -115,8 +118,7 @@ class DataParallelFeedInput(FeedInput):
self._tower_names = tower_names
self._nr_tower = len(tower_names)
def setup(self, model):
inputs = model.get_inputs_desc()
def setup(self, inputs):
self._placehdrs_per_tower = []
self._feed_placehdrs_per_tower = []
for tname in self._tower_names:
......@@ -231,9 +233,8 @@ class QueueInput(FeedfreeInput):
return self.ds.size()
# TODO use input data mapping. not all placeholders are needed
def setup(self, model):
def setup(self, inputs):
logger.info("Setting up the queue for CPU prefetching ...")
inputs = model.get_inputs_desc()
self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
if self._names is None:
self._queue_feedpoint = self.input_placehdrs
......@@ -289,9 +290,8 @@ class BatchQueueInput(FeedfreeInput):
def size(self):
return self.ds.size() // self.batch_size
def setup(self, model):
def setup(self, inputs):
logger.info("Setting up the queue for CPU prefetching ...")
inputs = model.get_inputs_desc()
self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
assert len(self.input_placehdrs) > 0, \
"BatchQueueInput has to be used with some InputDesc!"
......@@ -389,8 +389,8 @@ class DummyConstantInput(TensorInput):
return tlist
super(DummyConstantInput, self).__init__(fn)
def setup(self, model):
self.inputs_desc = model.get_inputs_desc()
def setup(self, inputs):
self.inputs_desc = inputs
# TODO doesn't support remapping
......@@ -413,8 +413,8 @@ class ZMQInput(TensorInput):
return ret
super(ZMQInput, self).__init__(fn)
def setup(self, model):
self.inputs_desc = model.get_inputs_desc()
def setup(self, inputs):
self.inputs_desc = inputs
assert len(self.inputs_desc) > 0, \
"ZMQInput has to be used with InputDesc!"
......@@ -457,8 +457,8 @@ class StagingInputWrapper(FeedfreeInput):
self._stage_ops = []
self._unstage_ops = []
def setup(self, model):
self._input.setup(model)
def setup(self, inputs):
self._input.setup(inputs)
self.setup_staging_areas()
def setup_training(self, trainer):
......@@ -527,10 +527,9 @@ class ReorderInputSource(FeedfreeInput):
def size(self):
return self._input.size()
def setup(self, model):
inputs = model.get_inputs_desc()
def setup(self, inputs):
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self._input.setup(model)
self._input.setup(inputs)
def setup_training(self, trainer):
inputs = trainer.model.get_inputs_desc()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment