Commit 14c564cc authored by Yuxin Wu's avatar Yuxin Wu

Pass InputDesc instead of ModelDesc to InputSource.setup()

parent 8b879cb9
...@@ -83,7 +83,7 @@ class InferenceRunnerBase(Callback): ...@@ -83,7 +83,7 @@ class InferenceRunnerBase(Callback):
self._extra_hooks = extra_hooks self._extra_hooks = extra_hooks
def _setup_graph(self): def _setup_graph(self):
self._input_source.setup(self.trainer.model) self._input_source.setup(self.trainer.model.get_inputs_desc())
# Use predict_tower in train config. either gpuid or -1 # Use predict_tower in train config. either gpuid or -1
self._predict_tower_id = self.trainer.config.predict_tower[0] self._predict_tower_id = self.trainer.config.predict_tower[0]
...@@ -188,7 +188,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -188,7 +188,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
def _setup_graph(self): def _setup_graph(self):
model = self.trainer.model model = self.trainer.model
self._input_source.setup(model) self._input_source.setup(model.get_inputs_desc())
# build graph # build graph
def build_tower(k): def build_tower(k):
......
...@@ -46,11 +46,15 @@ class InputSource(object): ...@@ -46,11 +46,15 @@ class InputSource(object):
For non-placeholder tensors, should always create and return new tensors when called. For non-placeholder tensors, should always create and return new tensors when called.
""" """
def setup(self, model): def setup(self, inputs_desc):
"""
Args:
inputs_desc (list[InputDesc]): list of input desc
"""
pass pass
def setup_training(self, trainer): def setup_training(self, trainer):
self.setup(trainer.model) self.setup(trainer.model.get_inputs_desc())
@abstractmethod @abstractmethod
def reset_state(self): def reset_state(self):
...@@ -82,8 +86,7 @@ class FeedInput(InputSource): ...@@ -82,8 +86,7 @@ class FeedInput(InputSource):
def size(self): def size(self):
return self.ds.size() return self.ds.size()
def setup(self, model): def setup(self, inputs):
inputs = model.get_inputs_desc()
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs] self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
if self._input_names is None: if self._input_names is None:
self._placehdrs_to_feed = self._all_placehdrs self._placehdrs_to_feed = self._all_placehdrs
...@@ -115,8 +118,7 @@ class DataParallelFeedInput(FeedInput): ...@@ -115,8 +118,7 @@ class DataParallelFeedInput(FeedInput):
self._tower_names = tower_names self._tower_names = tower_names
self._nr_tower = len(tower_names) self._nr_tower = len(tower_names)
def setup(self, model): def setup(self, inputs):
inputs = model.get_inputs_desc()
self._placehdrs_per_tower = [] self._placehdrs_per_tower = []
self._feed_placehdrs_per_tower = [] self._feed_placehdrs_per_tower = []
for tname in self._tower_names: for tname in self._tower_names:
...@@ -231,9 +233,8 @@ class QueueInput(FeedfreeInput): ...@@ -231,9 +233,8 @@ class QueueInput(FeedfreeInput):
return self.ds.size() return self.ds.size()
# TODO use input data mapping. not all placeholders are needed # TODO use input data mapping. not all placeholders are needed
def setup(self, model): def setup(self, inputs):
logger.info("Setting up the queue for CPU prefetching ...") logger.info("Setting up the queue for CPU prefetching ...")
inputs = model.get_inputs_desc()
self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs] self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
if self._names is None: if self._names is None:
self._queue_feedpoint = self.input_placehdrs self._queue_feedpoint = self.input_placehdrs
...@@ -289,9 +290,8 @@ class BatchQueueInput(FeedfreeInput): ...@@ -289,9 +290,8 @@ class BatchQueueInput(FeedfreeInput):
def size(self): def size(self):
return self.ds.size() // self.batch_size return self.ds.size() // self.batch_size
def setup(self, model): def setup(self, inputs):
logger.info("Setting up the queue for CPU prefetching ...") logger.info("Setting up the queue for CPU prefetching ...")
inputs = model.get_inputs_desc()
self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs] self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
assert len(self.input_placehdrs) > 0, \ assert len(self.input_placehdrs) > 0, \
"BatchQueueInput has to be used with some InputDesc!" "BatchQueueInput has to be used with some InputDesc!"
...@@ -389,8 +389,8 @@ class DummyConstantInput(TensorInput): ...@@ -389,8 +389,8 @@ class DummyConstantInput(TensorInput):
return tlist return tlist
super(DummyConstantInput, self).__init__(fn) super(DummyConstantInput, self).__init__(fn)
def setup(self, model): def setup(self, inputs):
self.inputs_desc = model.get_inputs_desc() self.inputs_desc = inputs
# TODO doesn't support remapping # TODO doesn't support remapping
...@@ -413,8 +413,8 @@ class ZMQInput(TensorInput): ...@@ -413,8 +413,8 @@ class ZMQInput(TensorInput):
return ret return ret
super(ZMQInput, self).__init__(fn) super(ZMQInput, self).__init__(fn)
def setup(self, model): def setup(self, inputs):
self.inputs_desc = model.get_inputs_desc() self.inputs_desc = inputs
assert len(self.inputs_desc) > 0, \ assert len(self.inputs_desc) > 0, \
"ZMQInput has to be used with InputDesc!" "ZMQInput has to be used with InputDesc!"
...@@ -457,8 +457,8 @@ class StagingInputWrapper(FeedfreeInput): ...@@ -457,8 +457,8 @@ class StagingInputWrapper(FeedfreeInput):
self._stage_ops = [] self._stage_ops = []
self._unstage_ops = [] self._unstage_ops = []
def setup(self, model): def setup(self, inputs):
self._input.setup(model) self._input.setup(inputs)
self.setup_staging_areas() self.setup_staging_areas()
def setup_training(self, trainer): def setup_training(self, trainer):
...@@ -527,10 +527,9 @@ class ReorderInputSource(FeedfreeInput): ...@@ -527,10 +527,9 @@ class ReorderInputSource(FeedfreeInput):
def size(self): def size(self):
return self._input.size() return self._input.size()
def setup(self, model): def setup(self, inputs):
inputs = model.get_inputs_desc()
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs] self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self._input.setup(model) self._input.setup(inputs)
def setup_training(self, trainer): def setup_training(self, trainer):
inputs = trainer.model.get_inputs_desc() inputs = trainer.model.get_inputs_desc()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment