Commit 4692e325 authored by Yuxin Wu

Let StagingInput figure out the device by itself.

parent 46991853
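For users, the visible change is only the constructor call: the device list argument disappears. A minimal before/after sketch, assuming a `FeedfreeInput` named `input` and `nr_gpu` GPUs:

```python
# Before this commit: the towers to prefetch on had to be listed explicitly.
input = StagingInput(input, list(range(nr_gpu)))

# After: StagingInput builds its staging areas lazily, under whatever
# device context is active when each tower requests its input tensors.
input = StagingInput(input)
```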
@@ -141,7 +141,7 @@ class MultiGPUGANTrainer(TowerTrainer):
         raw_devices = ['/gpu:{}'.format(k) for k in range(nr_gpu)]
 
         # Setup input
-        input = StagingInput(input, list(range(nr_gpu)))
+        input = StagingInput(input)
         cbs = input.setup(model.get_inputs_desc())
         self.register_callback(cbs)
...
@@ -21,7 +21,7 @@ class ModelSaver(Callback):
     def __init__(self, max_to_keep=10,
                  keep_checkpoint_every_n_hours=0.5,
                  checkpoint_dir=None,
-                 var_collections=tf.GraphKeys.GLOBAL_VARIABLES):
+                 var_collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.MODEL_VARIABLES]):
         """
         Args:
             max_to_keep (int): the same as in ``tf.train.Saver``.
...
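The `var_collections` default widens from a single collection to a list of two, so variables registered under `tf.GraphKeys.MODEL_VARIABLES` are now checkpointed by default as well. A hedged sketch of both call styles; `ModelSaver` is assumed importable from `tensorpack.callbacks`:

```python
import tensorflow as tf
from tensorpack.callbacks import ModelSaver

# New default: variables from either collection are saved.
saver = ModelSaver()

# To keep the old behavior, pass the single collection explicitly
# (the parameter now takes a list).
saver = ModelSaver(var_collections=[tf.GraphKeys.GLOBAL_VARIABLES])
```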
@@ -483,13 +483,18 @@ class StagingInput(FeedfreeInput):
         A callback registered by this input source, to make sure stage/unstage
         is run at each step.
         """
-        def __init__(self, stage_op, unstage_op, nr_stage):
+        def __init__(self, stage_op_fn, unstage_op_fn, nr_stage):
             self.nr_stage = nr_stage
-            self.stage_op = stage_op
-            self.fetches = tf.train.SessionRunArgs(
-                fetches=[stage_op, unstage_op])
+            self.stage_op_fn = stage_op_fn
+            self.unstage_op_fn = unstage_op_fn
+            self._initialized = False
+
+        def _setup_graph(self):
+            self.stage_op = self.stage_op_fn()
+            unstage_op = self.unstage_op_fn()
+            self.fetches = tf.train.SessionRunArgs(
+                fetches=[self.stage_op, unstage_op])
 
         def _prefill(self):
             logger.info("Pre-filling staging area ...")
             for k in range(self.nr_stage):
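The callback now receives zero-argument factories instead of ready-made ops, and calls them only in `_setup_graph`, once the towers (and therefore the stage/unstage ops) can exist. A minimal, framework-free sketch of the same deferral pattern; all names here are illustrative:

```python
class DeferredOpCallback:
    """Store op factories now; build the actual ops at graph-setup time."""

    def __init__(self, stage_op_fn, unstage_op_fn):
        self._stage_op_fn = stage_op_fn        # zero-argument callables
        self._unstage_op_fn = unstage_op_fn

    def setup_graph(self):
        # Called once the graph is fully built, so the factories may
        # reference ops that did not exist at __init__ time.
        self.stage_op = self._stage_op_fn()
        self.unstage_op = self._unstage_op_fn()
```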
@@ -502,21 +507,17 @@ class StagingInput(FeedfreeInput):
                 self._prefill()
             return self.fetches
 
-    def __init__(self, input, towers, nr_stage=5):
+    def __init__(self, input, towers=None, nr_stage=5):
         """
         Args:
             input (FeedfreeInput):
-            towers ([int]): list of GPU ids to prefetch on.
             nr_stage: number of elements to prefetch on each GPU.
+            towers: deprecated
         """
         assert isinstance(input, FeedfreeInput), input
         self._input = input
-        if not isinstance(towers[0], int):
-            # API changed
-            log_deprecated("StagingInput(devices=)", "Use (towers=) instead!", "2018-01-31")
-            self._devices = towers
-        else:
-            self._devices = ['/gpu:{}'.format(k) for k in towers]
+        if towers is not None:
+            log_deprecated("StagingInput(towers=) has no effect! Devices are handled automatically.")
 
         self._nr_stage = nr_stage
         self._areas = []
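`towers` survives only as a deprecated no-op, so old call sites keep running while new ones drop the argument; `my_input` below is a placeholder:

```python
# Old style: still accepted, but logs a deprecation warning; the
# argument is ignored.
staged = StagingInput(my_input, towers=[0, 1])

# New style: no device list; devices are picked up per tower.
staged = StagingInput(my_input)
```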
@@ -525,21 +526,21 @@ class StagingInput(FeedfreeInput):
     def _setup(self, inputs):
         self._input.setup(inputs)
-        self._setup_staging_areas()
 
     def _get_callbacks(self):
         cbs = self._input.get_callbacks()
+        # Pass a lambda to be called later, because stage ops have not been built
         cbs.append(
             StagingInput.StagingCallback(
-                self._get_stage_op(), self._get_unstage_op(), self._nr_stage))
+                lambda: self._get_stage_op(), lambda: self._get_unstage_op(), self._nr_stage))
         return cbs
 
-    def _setup_staging_areas(self):
-        logger.info("Setting up StagingArea for GPU prefetching ...")
+    def _size(self):
+        return self._input.size()
+
+    def _get_input_tensors(self):
         with self.cached_name_scope():
-            for idx, device in enumerate(self._devices):
-                with tf.device(device):
-                    inputs = self._input.get_input_tensors()
+            inputs = self._input.get_input_tensors()
 
-                    # Putting variables to stagingarea will cause trouble
+            # Putting variables to stagingarea will cause trouble
@@ -559,14 +560,7 @@ class StagingInput(FeedfreeInput):
             for vin, vout in zip(inputs, outputs):
                 vout.set_shape(vin.get_shape())
             self._unstage_ops.append(outputs)
-
-    def _size(self):
-        return self._input.size()
-
-    def _get_input_tensors(self):
-        ctx = get_current_tower_context()
-        ret = self._unstage_ops[ctx.index]
-        return ret
+            return outputs
 
     def _get_stage_op(self):
         with self.cached_name_scope():
...
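The rewritten `_get_input_tensors` creates one `StagingArea` per call, under whatever `tf.device` context the trainer has already entered, which is what makes the explicit device list unnecessary. A minimal TF 1.x sketch of that put/get pattern, assuming a single float32 input (`/gpu:0` stands in for the caller's device):

```python
import tensorflow as tf
from tensorflow.python.ops.data_flow_ops import StagingArea

x = tf.random_normal([128, 32])           # stand-in for a real input tensor
with tf.device('/gpu:0'):                 # normally entered by the trainer
    area = StagingArea(dtypes=[tf.float32], shapes=None)
    stage_op = area.put([x])              # enqueue; run once per step (plus pre-fill)
    outputs = area.get()                  # dequeue the prefetched tensors
    if isinstance(outputs, tf.Tensor):    # TF returns a bare tensor when size == 1
        outputs = [outputs]
    outputs[0].set_shape(x.get_shape())   # get() drops static shapes; restore them
```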
@@ -136,7 +136,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
             assert get_tf_version_number() >= 1.4, \
                 "Fine tuning a BatchNorm model with fixed statistics is only " \
                 "supported after https://github.com/tensorflow/tensorflow/pull/12580 "
-            if ctx.index == 0:      # only warn in first tower
+            if ctx.is_main_training_tower:      # only warn in first tower
                 logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
             # Using moving_mean/moving_variance in training, which means we
             # loaded a pre-trained BN and only fine-tuning the affine part.
...
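`is_main_training_tower` states the intent ("warn once, from the main tower") directly instead of relying on the tower's position. The guard in isolation, with the assumed tensorpack import paths:

```python
from tensorpack.tfutils.tower import get_current_tower_context
from tensorpack.utils import logger

ctx = get_current_tower_context()
if ctx.is_main_training_tower:   # was: ctx.index == 0
    logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
```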
@@ -87,6 +87,7 @@ class TowerContext(object):
         """
         return self._collection_guard.get_collection_in_tower(key)
 
+    # TODO currently only used in StagingInput
     @property
     def index(self):
         return self._index
...
@@ -39,7 +39,7 @@ def apply_default_prefetch(input_source_or_dataflow, trainer):
         assert tf.test.is_gpu_available()
 
         if not isinstance(input, (StagingInput, DummyConstantInput)):
-            input = StagingInput(input, towers)
+            input = StagingInput(input)
     return input
...
@@ -44,7 +44,7 @@ def apply_prefetch_policy(config, gpu_prefetch=True):
         # seem to only improve on >1 GPUs
         if not isinstance(config.data, (StagingInput, DummyConstantInput)):
-            config.data = StagingInput(config.data, config.tower)
+            config.data = StagingInput(config.data)
 
 
 class SyncMultiGPUTrainerParameterServer(Trainer):
...