Commit 0be066fe authored by Yuxin Wu's avatar Yuxin Wu

small fix in staging

parent 3f48ed30
...@@ -24,8 +24,8 @@ class FeedfreeTrainerBase(Trainer): ...@@ -24,8 +24,8 @@ class FeedfreeTrainerBase(Trainer):
Get input tensors from `self.input_method` and build the graph. Get input tensors from `self.input_method` and build the graph.
""" """
def f(): def f():
inputs = self._input_method.get_input_tensors() self._input_tensors = self._input_method.get_input_tensors()
self.model.build_graph(inputs) self.model.build_graph(self._input_tensors)
ctx = get_current_tower_context() ctx = get_current_tower_context()
if ctx is None: if ctx is None:
with TowerContext(''): with TowerContext(''):
...@@ -98,7 +98,7 @@ class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer): ...@@ -98,7 +98,7 @@ class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer):
cost, grads = self._get_cost_and_grad() cost, grads = self._get_cost_and_grad()
self.train_op = self.config.optimizer.apply_gradients(grads, name='min_op') self.train_op = self.config.optimizer.apply_gradients(grads, name='min_op')
# skip training # skip training
# self.train_op = tf.group(*self.dequed_inputs) # self.train_op = tf.group(*self._input_tensors)
def QueueInputTrainer(config, input_queue=None, predict_tower=None): def QueueInputTrainer(config, input_queue=None, predict_tower=None):
...@@ -117,9 +117,9 @@ def QueueInputTrainer(config, input_queue=None, predict_tower=None): ...@@ -117,9 +117,9 @@ def QueueInputTrainer(config, input_queue=None, predict_tower=None):
else: else:
assert isinstance(config.data, QueueInput), config.data assert isinstance(config.data, QueueInput), config.data
# from tensorpack.train.input_data import QueueInput, FeedfreeInput, StagingInputWrapper, DummyConstantInput # from tensorpack.train.input_data import StagingInputWrapper, DummyConstantInput
# config.data = StagingInputWrapper(config.data, ['/gpu:0']) # config.data = StagingInputWrapper(config.data, ['/gpu:0'])
# config.data = DummyConstantInput([[64,224,224,3], [64]]) # config.data = DummyConstantInput([[128,224,224,3], [128]])
if predict_tower is not None: if predict_tower is not None:
log_deprecated("Argument `predict_tower` in trainer", "Use TrainConfig(predict_tower=...) instead!") log_deprecated("Argument `predict_tower` in trainer", "Use TrainConfig(predict_tower=...) instead!")
......
...@@ -169,7 +169,6 @@ class QueueInput(FeedfreeInput): ...@@ -169,7 +169,6 @@ class QueueInput(FeedfreeInput):
def get_input_tensors(self): def get_input_tensors(self):
ret = self.queue.dequeue(name='input_deque') ret = self.queue.dequeue(name='input_deque')
#ret[0]= tf.Print(ret[0], [tf.reduce_mean(ret[0])], "asdf")
if isinstance(ret, tf.Tensor): # only one input if isinstance(ret, tf.Tensor): # only one input
ret = [ret] ret = [ret]
assert len(ret) == len(self.input_placehdrs) assert len(ret) == len(self.input_placehdrs)
...@@ -326,7 +325,7 @@ class StagingInputWrapper(FeedfreeInput): ...@@ -326,7 +325,7 @@ class StagingInputWrapper(FeedfreeInput):
self.stage_op = stage_op self.stage_op = stage_op
# TODO make sure both stage/unstage are run, to avoid OOM # TODO make sure both stage/unstage are run, to avoid OOM
self.fetches = tf.train.SessionRunArgs( self.fetches = tf.train.SessionRunArgs(
fetches=[stage_op]) fetches=[stage_op, unstage_op])
def _before_train(self): def _before_train(self):
# pre-fill the staging area # pre-fill the staging area
...@@ -350,8 +349,8 @@ class StagingInputWrapper(FeedfreeInput): ...@@ -350,8 +349,8 @@ class StagingInputWrapper(FeedfreeInput):
self.setup_staging_areas() self.setup_staging_areas()
def setup_training(self, trainer): def setup_training(self, trainer):
super(StagingInputWrapper, self).setup_training(trainer)
self._input.setup_training(trainer) self._input.setup_training(trainer)
self.setup_staging_areas()
trainer.register_callback( trainer.register_callback(
StagingInputWrapper.StagingCallback( StagingInputWrapper.StagingCallback(
...@@ -359,11 +358,10 @@ class StagingInputWrapper(FeedfreeInput): ...@@ -359,11 +358,10 @@ class StagingInputWrapper(FeedfreeInput):
def setup_staging_areas(self): def setup_staging_areas(self):
for idx, device in enumerate(self._devices): for idx, device in enumerate(self._devices):
inputs = self._input.get_input_tensors()
dtypes = [x.dtype for x in inputs]
with tf.device(device): with tf.device(device):
stage = StagingArea( inputs = self._input.get_input_tensors()
dtypes, shapes=None) dtypes = [x.dtype for x in inputs]
stage = StagingArea(dtypes, shapes=None)
self._stage_ops.append(stage.put(inputs)) self._stage_ops.append(stage.put(inputs))
self._areas.append(stage) self._areas.append(stage)
outputs = stage.get() outputs = stage.get()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment