Commit 4692e325 authored by Yuxin Wu

Let StagingInput figure out the device by itself.

parent 46991853
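In short, call sites no longer pass a list of GPU ids; StagingInput now builds its staging areas lazily inside whichever tower asks for input tensors. A hedged before/after sketch of a call site (the `QueueInput`/`df` names are illustrative placeholders, not part of this commit):

```python
from tensorpack import QueueInput, StagingInput

df = ...  # any DataFlow producing training batches (placeholder)

# Before: the caller listed the GPUs to prefetch on.
# input = StagingInput(QueueInput(df), list(range(nr_gpu)))

# After: StagingInput picks the device from the tower that requests tensors.
input = StagingInput(QueueInput(df), nr_stage=5)
```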
@@ -141,7 +141,7 @@ class MultiGPUGANTrainer(TowerTrainer):
         raw_devices = ['/gpu:{}'.format(k) for k in range(nr_gpu)]

         # Setup input
-        input = StagingInput(input, list(range(nr_gpu)))
+        input = StagingInput(input)
         cbs = input.setup(model.get_inputs_desc())
         self.register_callback(cbs)
...
@@ -21,7 +21,7 @@ class ModelSaver(Callback):
     def __init__(self, max_to_keep=10,
                  keep_checkpoint_every_n_hours=0.5,
                  checkpoint_dir=None,
-                 var_collections=tf.GraphKeys.GLOBAL_VARIABLES):
+                 var_collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.MODEL_VARIABLES]):
         """
         Args:
             max_to_keep (int): the same as in ``tf.train.Saver``.
...
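The ModelSaver change widens the default so that variables registered only under `tf.GraphKeys.MODEL_VARIABLES` (and not under `GLOBAL_VARIABLES`) are also written to checkpoints. A small hedged illustration (the variable name and directory are made up):

```python
import tensorflow as tf
from tensorpack.callbacks import ModelSaver

# A non-trainable statistic kept only in MODEL_VARIABLES would previously be
# skipped by the default ModelSaver; with the new default it gets saved too.
stat = tf.get_variable('running_stat', shape=[], trainable=False,
                       collections=[tf.GraphKeys.MODEL_VARIABLES])

saver_cb = ModelSaver(checkpoint_dir='train_log/example')
```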
@@ -483,13 +483,18 @@ class StagingInput(FeedfreeInput):
         A callback registered by this input source, to make sure stage/unstage
         is run at each step.
         """
-        def __init__(self, stage_op, unstage_op, nr_stage):
+        def __init__(self, stage_op_fn, unstage_op_fn, nr_stage):
             self.nr_stage = nr_stage
-            self.stage_op = stage_op
-            self.fetches = tf.train.SessionRunArgs(
-                fetches=[stage_op, unstage_op])
+            self.stage_op_fn = stage_op_fn
+            self.unstage_op_fn = unstage_op_fn
             self._initialized = False

+        def _setup_graph(self):
+            self.stage_op = self.stage_op_fn()
+            unstage_op = self.unstage_op_fn()
+            self.fetches = tf.train.SessionRunArgs(
+                fetches=[self.stage_op, unstage_op])
+
         def _prefill(self):
             logger.info("Pre-filling staging area ...")
             for k in range(self.nr_stage):
@@ -502,21 +507,17 @@ class StagingInput(FeedfreeInput):
             self._prefill()
             return self.fetches

-    def __init__(self, input, towers, nr_stage=5):
+    def __init__(self, input, towers=None, nr_stage=5):
         """
         Args:
             input (FeedfreeInput):
-            towers ([int]): list of GPU ids to prefetch on.
             nr_stage: number of elements to prefetch on each GPU.
+            towers: deprecated
         """
         assert isinstance(input, FeedfreeInput), input
         self._input = input
-        if not isinstance(towers[0], int):
-            # API changed
-            log_deprecated("StagingInput(devices=)", "Use (towers=) instead!", "2018-01-31")
-            self._devices = towers
-        else:
-            self._devices = ['/gpu:{}'.format(k) for k in towers]
+        if towers is not None:
+            log_deprecated("StagingInput(towers=) has no effect! Devices are handled automatically.")
         self._nr_stage = nr_stage
         self._areas = []
@@ -525,48 +526,41 @@ class StagingInput(FeedfreeInput):
     def _setup(self, inputs):
         self._input.setup(inputs)
-        self._setup_staging_areas()

     def _get_callbacks(self):
         cbs = self._input.get_callbacks()
+
+        # Pass a lambda to be called later, because stage ops have not been built
         cbs.append(
             StagingInput.StagingCallback(
-                self._get_stage_op(), self._get_unstage_op(), self._nr_stage))
+                lambda: self._get_stage_op(), lambda: self._get_unstage_op(), self._nr_stage))
         return cbs

-    def _setup_staging_areas(self):
-        logger.info("Setting up StagingArea for GPU prefetching ...")
-        with self.cached_name_scope():
-            for idx, device in enumerate(self._devices):
-                with tf.device(device):
-                    inputs = self._input.get_input_tensors()
-
-                    # Putting variables to stagingarea will cause trouble
-                    dtypes = []
-                    for idx in range(len(inputs)):
-                        dtype = inputs[idx].dtype
-                        if dtype.base_dtype != dtype:   # is reference type
-                            inputs[idx] = tf.identity(inputs[idx])
-                        dtypes.append(dtype.base_dtype)
-
-                    stage = StagingArea(dtypes, shapes=None)
-                    self._stage_ops.append(stage.put(inputs))
-                    self._areas.append(stage)
-                    outputs = stage.get()
-                    if isinstance(outputs, tf.Tensor):  # when size=1, TF doesn't return a list
-                        outputs = [outputs]
-                    for vin, vout in zip(inputs, outputs):
-                        vout.set_shape(vin.get_shape())
-                    self._unstage_ops.append(outputs)
-
     def _size(self):
         return self._input.size()

     def _get_input_tensors(self):
-        ctx = get_current_tower_context()
-        ret = self._unstage_ops[ctx.index]
-        return ret
+        with self.cached_name_scope():
+            inputs = self._input.get_input_tensors()
+
+            # Putting variables to stagingarea will cause trouble
+            dtypes = []
+            for idx in range(len(inputs)):
+                dtype = inputs[idx].dtype
+                if dtype.base_dtype != dtype:   # is reference type
+                    inputs[idx] = tf.identity(inputs[idx])
+                dtypes.append(dtype.base_dtype)
+
+            stage = StagingArea(dtypes, shapes=None)
+            self._stage_ops.append(stage.put(inputs))
+            self._areas.append(stage)
+            outputs = stage.get()
+            if isinstance(outputs, tf.Tensor):  # when size=1, TF doesn't return a list
+                outputs = [outputs]
+            for vin, vout in zip(inputs, outputs):
+                vout.set_shape(vin.get_shape())
+            self._unstage_ops.append(outputs)
+            return outputs

     def _get_stage_op(self):
         with self.cached_name_scope():
...
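For readers unfamiliar with the mechanism: the refactor moves StagingArea construction out of a device loop in `_setup_staging_areas` and into `_get_input_tensors`, which already runs inside each tower's device scope, so no explicit device list is needed; the callback receives lambdas because the stage/unstage ops do not exist yet when callbacks are created. A minimal, standalone sketch of the underlying TF 1.x pattern (import path, tensor shapes and names are assumptions made for illustration):

```python
import tensorflow as tf
from tensorflow.python.ops.data_flow_ops import StagingArea

# Stand-ins for the tensors produced by the wrapped input source.
next_batch = [tf.random_normal([32, 224, 224, 3]),
              tf.random_uniform([32], maxval=10, dtype=tf.int32)]

with tf.device('/gpu:0'):  # in tensorpack this scope comes from the tower
    area = StagingArea([t.dtype for t in next_batch], shapes=None)
    stage_op = area.put(next_batch)   # run a few times up-front to pre-fill
    images, labels = area.get()       # dequeues one pre-staged element
```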
@@ -136,7 +136,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         assert get_tf_version_number() >= 1.4, \
             "Fine tuning a BatchNorm model with fixed statistics is only " \
             "supported after https://github.com/tensorflow/tensorflow/pull/12580 "
-        if ctx.index == 0:      # only warn in first tower
+        if ctx.is_main_training_tower:      # only warn in first tower
             logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
         # Using moving_mean/moving_variance in training, which means we
         # loaded a pre-trained BN and only fine-tuning the affine part.
...
@@ -87,6 +87,7 @@ class TowerContext(object):
         """
         return self._collection_guard.get_collection_in_tower(key)

+    # TODO currently only used in StagingInput
     @property
     def index(self):
         return self._index
...
@@ -39,7 +39,7 @@ def apply_default_prefetch(input_source_or_dataflow, trainer):
         assert tf.test.is_gpu_available()
         if not isinstance(input, (StagingInput, DummyConstantInput)):
-            input = StagingInput(input, towers)
+            input = StagingInput(input)
     return input
...
@@ -44,7 +44,7 @@ def apply_prefetch_policy(config, gpu_prefetch=True):
         # seem to only improve on >1 GPUs
         if not isinstance(config.data, (StagingInput, DummyConstantInput)):
-            config.data = StagingInput(config.data, config.tower)
+            config.data = StagingInput(config.data)

 class SyncMultiGPUTrainerParameterServer(Trainer):
...