Commit 9d0b28a0 authored by Yuxin Wu

Add staging input (#140). Didn't see improvement

parent dabebf69
@@ -117,6 +117,10 @@ def QueueInputTrainer(config, input_queue=None, predict_tower=None):
    else:
        assert isinstance(config.data, QueueInput), config.data

    # Experimental alternatives, left commented out (see commit message):
    # from tensorpack.train.input_data import QueueInput, FeedfreeInput, StagingInputWrapper, DummyConstantInput
    # config.data = StagingInputWrapper(config.data, ['/gpu:0'])
    # config.data = DummyConstantInput([[64,224,224,3], [64]])
    if predict_tower is not None:
        log_deprecated("Argument `predict_tower` in trainer", "Use TrainConfig(predict_tower=...) instead!")
        config.predict_tower = predict_tower
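
For reference, a minimal sketch (not part of this commit; `dataflow` and the GPU list are illustrative) of how the commented-out wrapper would be enabled by hand:

    from tensorpack.train.input_data import QueueInput, StagingInputWrapper

    config.data = QueueInput(dataflow)
    # wrap the queue input so each step's batch is staged onto the training GPU
    config.data = StagingInputWrapper(config.data, ['/gpu:0'])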
tensorpack/train/input_data.py
@@ -4,20 +4,26 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from tensorflow.contrib.staging import StagingArea
from itertools import chain
from abc import ABCMeta, abstractmethod
import six
from six.moves import range
from ..dataflow import DataFlow, RepeatedData
from ..tfutils.summary import add_moving_summary
from ..tfutils import get_op_tensor_name
from ..utils import logger
from ..utils.argtools import memoized
from ..utils.concurrency import ShareSessionThread
from ..callbacks.concurrency import StartProcOrThread
from ..callbacks.base import Callback
__all__ = ['InputData', 'FeedfreeInput',
           'QueueInput', 'BatchQueueInput',
           'ZMQInput',
-          'DummyConstantInput', 'TensorInput']
+          'DummyConstantInput', 'TensorInput', 'StagingInputWrapper']
@six.add_metaclass(ABCMeta)
@@ -160,6 +166,7 @@ class QueueInput(FeedfreeInput):
    def get_input_tensors(self):
        ret = self.queue.dequeue(name='input_deque')
        # ret[0] = tf.Print(ret[0], [tf.reduce_mean(ret[0])], "asdf")
        if isinstance(ret, tf.Tensor):  # only one input
            ret = [ret]
        assert len(ret) == len(self.input_placehdrs)
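
The `isinstance` check is needed because `tf.FIFOQueue.dequeue()` returns a lone Tensor when the queue holds a single component, and a list otherwise. A quick standalone illustration (TF 1.x):

    import tensorflow as tf

    q1 = tf.FIFOQueue(10, dtypes=[tf.float32])
    print(type(q1.dequeue()))    # a single tf.Tensor

    q2 = tf.FIFOQueue(10, dtypes=[tf.float32, tf.int32])
    print(type(q2.dequeue()))    # a list of two Tensors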
@@ -306,3 +313,80 @@ class ZMQInput(FeedfreeInput):
        for qv, v in zip(ret, self.input_placehdrs):
            qv.set_shape(v.get_shape())
        return ret
class StagingInputWrapper(FeedfreeInput):
    """ Wrap a FeedfreeInput with one StagingArea per device, so that each
        step consumes a batch that was copied to the device ahead of time. """

    class StagingCallback(Callback):
        """ Pre-fills the staging areas before training, and runs the stage
            op alongside every training step. """
        def __init__(self, stage_op, unstage_op, nr_stage):
            self.nr_stage = nr_stage
            self.stage_op = stage_op
            # TODO make sure both stage/unstage are run, to avoid OOM
            self.fetches = tf.train.SessionRunArgs(
                fetches=[stage_op])

        def _before_train(self):
            # pre-fill the staging area
            for k in range(self.nr_stage):
                self.stage_op.run()

        def _before_run(self, ctx):
            # run the stage op together with every training step
            return self.fetches

    def __init__(self, input, devices):
        self._input = input
        assert isinstance(input, FeedfreeInput)
        self._devices = devices
        self._areas = []
        self._stage_ops = []
        self._unstage_ops = []
        self._cnt_unstage = 0

    def setup(self, model):
        self._input.setup(model)
        self.setup_staging_areas()

    def setup_training(self, trainer):
        super(StagingInputWrapper, self).setup_training(trainer)
        self._input.setup_training(trainer)
        # keep 5 batches in flight inside the staging areas
        trainer.register_callback(
            StagingInputWrapper.StagingCallback(
                self.get_stage_op(), self.get_unstage_op(), 5))

    def setup_staging_areas(self):
        # one StagingArea per device: put() feeds it from the underlying
        # input, get() yields the tensors the tower on that device consumes
        for idx, device in enumerate(self._devices):
            inputs = self._input.get_input_tensors()
            dtypes = [x.dtype for x in inputs]
            with tf.device(device):
                stage = StagingArea(
                    dtypes, shapes=None)
                self._stage_ops.append(stage.put(inputs))
                self._areas.append(stage)
                outputs = stage.get()
                # StagingArea.get() loses static shapes; restore them
                for vin, vout in zip(inputs, outputs):
                    vout.set_shape(vin.get_shape())
                self._unstage_ops.append(outputs)

    def size(self):
        return self._input.size()

    def get_input_tensors(self):
        # each call hands the next device its own unstaged tensors
        assert self._cnt_unstage < len(self._areas)
        assert len(self._areas) == len(self._devices)
        ret = self._unstage_ops[self._cnt_unstage]
        self._cnt_unstage += 1
        return ret

    @staticmethod
    def get_staging_name(idx):
        return 'StagingArea{}'.format(idx)

    @memoized
    def get_stage_op(self):
        return tf.group(*self._stage_ops)

    @memoized
    def get_unstage_op(self):
        all_outputs = list(chain.from_iterable(self._unstage_ops))
        return tf.group(*all_outputs)
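
For context, a minimal standalone sketch (TF 1.x; names are illustrative, not from this commit) of the put/get cycle the wrapper builds: pre-fill once, then every step stages the next batch while consuming the previous one.

    import numpy as np
    import tensorflow as tf
    from tensorflow.contrib.staging import StagingArea

    x = tf.placeholder(tf.float32, [None, 4])
    y = tf.placeholder(tf.int32, [None])
    area = StagingArea(dtypes=[tf.float32, tf.int32])
    stage = area.put([x, y])      # copy the next batch toward the device
    gx, gy = area.get()           # retrieve a previously staged batch
    loss = tf.reduce_mean(gx)

    with tf.Session() as sess:
        feed = {x: np.zeros([8, 4], np.float32), y: np.zeros([8], np.int32)}
        sess.run(stage, feed_dict=feed)     # pre-fill, like _before_train
        for _ in range(3):
            # each step consumes one staged batch while staging the next
            _, val = sess.run([stage, loss], feed_dict=feed)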