Commit 3e9de2ae authored by Yuxin Wu

a faster DummyInput

parent 2b4f7b14
......@@ -27,7 +27,7 @@ import tensorflow as tf
def _read_words(filename):
    """Read *filename* and return its contents as a list of word tokens.

    The file is opened in binary mode and decoded as UTF-8 explicitly
    (binary + explicit decode avoids platform-dependent text-mode
    behavior); each newline is replaced by an explicit '<eos>' token
    before whitespace-splitting.

    NOTE(review): the diff showed both the old ('"r"') and new ('"rb"')
    open lines interleaved; this keeps only the post-change '"rb"' line.
    """
    with tf.gfile.GFile(filename, "rb") as f:
        return f.read().decode("utf-8").replace("\n", "<eos>").split()
......
......@@ -42,6 +42,12 @@ class TowerContext(object):
def name(self):
    """str: this tower's name ('' is allowed; the empty name maps to index 0)."""
    return self._name
@property
def index(self):
    """int: index of this tower, parsed from the trailing digits of its name.

    The empty name denotes the default tower and maps to index 0.

    Raises:
        ValueError: if a non-empty name has no trailing digits.
    """
    if self._name == '':
        return 0
    # Parse the full trailing digit run rather than only the last
    # character, so a name ending in '12' yields 12 instead of 2.
    digits = ''
    for ch in reversed(self._name):
        if not ch.isdigit():
            break
        digits = ch + digits
    if not digits:
        raise ValueError("Cannot parse tower index from name: {!r}".format(self._name))
    return int(digits)
def get_variable_on_tower(self, *args, **kwargs):
"""
Get a variable for this tower specifically, without reusing, even if
......
......@@ -254,18 +254,35 @@ class DummyConstantInput(FeedfreeInput):
"""
self.shapes = shapes
logger.warn("Using dummy input for debug!")
self._cnt = 0
def setup(self, model):
    # Reuse the model's placeholders as the nominal input tensors.
    self.input_placehdrs = model.get_reused_placehdrs()

def setup_training(self, trainer):
    """Build one list of dummy variables per tower.

    Each tower gets its own non-trainable variables (deliberately NOT
    shared across towers) placed on that tower's GPU, so every tower
    reads local constant data instead of feeding real input.

    NOTE(review): the diff interleaved the removed single-tower build
    code with the added per-tower code; this is the post-change version.
    """
    super(DummyConstantInput, self).setup_training(trainer)
    nr_tower = trainer.config.nr_tower
    placehdrs = self.input_placehdrs
    assert len(self.shapes) == len(placehdrs)
    self.tensors = []
    # don't share variables
    for tower in range(nr_tower):
        tlist = []
        # TODO. keep device info in tower
        with tf.device('/gpu:{}'.format(tower)):
            for idx, p in enumerate(placehdrs):
                tlist.append(tf.get_variable(
                    'dummy-{}-{}'.format(p.op.name, tower), shape=self.shapes[idx],
                    dtype=p.dtype, trainable=False))
        self.tensors.append(tlist)
def get_input_tensors(self):
    """Return the next tower's list of dummy tensors, sequentially by call order."""
    # TODO XXX call with tower index
    tower_idx = self._cnt
    self._cnt = tower_idx + 1
    return self.tensors[tower_idx]
......@@ -318,12 +335,10 @@ class ZMQInput(FeedfreeInput):
class StagingInputWrapper(FeedfreeInput):
class StagingCallback(Callback):
def __init__(self, stage_op, unstage_op, nr_stage):
    """Remember the staging ops and bundle both into one SessionRunArgs."""
    self.nr_stage = nr_stage
    self.stage_op = stage_op
    # TODO make sure both stage/unstage are run, to avoid OOM
    paired_ops = [stage_op, unstage_op]
    self.fetches = tf.train.SessionRunArgs(fetches=paired_ops)
......@@ -335,13 +350,15 @@ class StagingInputWrapper(FeedfreeInput):
def _before_run(self, ctx):
    # Ask the session to run the bundled stage/unstage ops with every step.
    return self.fetches
def __init__(self, input, devices, nr_stage=5):
    """
    Args:
        input (FeedfreeInput): the underlying input to wrap.
        devices (list): devices to build staging areas on — presumably
            one per device; confirm against setup_staging_areas.
        nr_stage (int): number of staging steps to pre-run. Defaults to 5,
            matching the value previously hard-coded at the call site.

    NOTE(review): the diff showed both the old two-argument signature
    and the new three-argument one; this keeps the post-change signature,
    which is backward-compatible via the nr_stage default.
    """
    self._input = input
    assert isinstance(input, FeedfreeInput)
    self._devices = devices
    self._nr_stage = nr_stage
    self._areas = []
    self._stage_ops = []
    self._unstage_ops = []
    self._cnt_unstage = 0
def setup(self, model):
......@@ -354,7 +371,7 @@ class StagingInputWrapper(FeedfreeInput):
trainer.register_callback(
StagingInputWrapper.StagingCallback(
self.get_stage_op(), self.get_unstage_op(), 5))
self.get_stage_op(), self.get_unstage_op(), self._nr_stage))
def setup_staging_areas(self):
for idx, device in enumerate(self._devices):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment