Commit ee1af311 authored by Yuxin Wu

comments & fix lint

parent f1e3b3ae
......@@ -6,7 +6,6 @@ import tensorflow as tf
from six.moves import range
from ..utils import logger
from .input_source import StagingInputWrapper
from .feedfree import SingleCostFeedfreeTrainer
from .multigpu import MultiGPUTrainerBase
from ..callbacks import RunOp
......@@ -19,18 +18,14 @@ __all__ = ['DistributedReplicatedTrainer']
PS_SHADOW_VAR_PREFIX = 'ps_var'
# To be used with custom_getter on tf.get_variable. Ensures the created variable
# is in LOCAL_VARIABLES and not GLOBAL_VARIABLES collection.
class OverrideToLocalVariableIfNotPsVar(object):
# args and kwargs come from the custom_getter interface for Tensorflow
# variables, and match tf.get_variable's signature, with the addition of
# 'getter' at the beginning.
"""
Ensures the created variable
is in LOCAL_VARIABLES and not GLOBAL_VARIABLES collection.
"""
def __call__(self, getter, name, *args, **kwargs):
if name.startswith(PS_SHADOW_VAR_PREFIX):
return getter(*args, **kwargs)
logger.info("CustomGetter-{}".format(name))
if 'collections' in kwargs:
collections = kwargs['collections']
if not collections:
......@@ -50,7 +45,8 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
self.cluster = tf.train.ClusterSpec(server_def.cluster)
self.job_name = server_def.job_name
self.task_index = server_def.task_index
assert self.job_name in ['ps', 'worker'], job_name
assert self.job_name in ['ps', 'worker'], self.job_name
assert tf.test.is_gpu_available
self._input_source = config.data
self.is_chief = (self.task_index == 0 and self.job_name == 'worker')
......@@ -71,14 +67,6 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(self.num_ps)]
self.sync_queue_counter = 0
if self.nr_gpu > 1:
assert tf.test.is_gpu_available()
# TODO staging doesn't work with dummy (require context)
# seem to only improve on >1 GPUs
#if not isinstance(self._input_source, StagingInputWrapper):
#self._input_source = StagingInputWrapper(self._input_source, self.raw_devices)
@staticmethod
def _average_grads(tower_grads, devices):
"""
......@@ -134,7 +122,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
list of copy ops
"""
# TODO do this for each variable separately?
opt = self.model.get_optimizer() # TODO ensure it in global scope, not local
opt = self.model.get_optimizer()
var_update_ops = []
for vid, (g, v) in enumerate(ps_var_grads):
apply_gradient_op = opt.apply_gradients([(g, v)])
......@@ -175,7 +163,8 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
var_update_ops = self._apply_gradients_and_copy(grad_list, ps_var_grads)
main_fetch = tf.group(*var_update_ops, name='main_fetches')
self.train_op = self.add_sync_queues_and_barrier('sync_queues_step_end', [main_fetch])
self.train_op = self.add_sync_queues_and_barrier(
'post_copy_barrier', [main_fetch])
self.register_callback(RunOp(
self.get_post_init_ops, run_before=True, run_as_trigger=False))
......@@ -189,6 +178,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
"Cannot set session_creator or session_config for distributed training! "
"To use a custom session config, pass it to the tf.train.Server constructor.")
# TODO use scaffold
class SupervisedSessionCreator(tf.train.SessionCreator):
def __init__(self, is_chief, target):
self.is_chief = is_chief
......@@ -224,14 +214,11 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
shared_name='%s%s' % (name_prefix, i))
for i in range(self.num_worker)]
queue_ops = []
# For each other worker, add an entry in a queue, signaling that it can
# finish this step.
# For each other worker, add an entry in a queue, signaling that it can finish this step.
token = tf.constant(False)
with tf.control_dependencies(enqueue_after_list):
for i, q in enumerate(sync_queues):
if i == self.task_index:
queue_ops.append(tf.no_op())
else:
if i != self.task_index:
queue_ops.append(q.enqueue(token))
# Drain tokens off queue for this worker, one for each other worker.
......@@ -256,7 +243,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
v.name[len(PS_SHADOW_VAR_PREFIX + '/'):])
for i in range(self.nr_gpu):
if i == 0:
name = prefix
name = prefix # no prefix for tower0
else:
name = 'tower%s/%s' % (i, prefix)
if name in local_var_by_name:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment