Commit ee1af311 authored by Yuxin Wu's avatar Yuxin Wu

comments & fix lint

parent f1e3b3ae
...@@ -6,7 +6,6 @@ import tensorflow as tf ...@@ -6,7 +6,6 @@ import tensorflow as tf
from six.moves import range from six.moves import range
from ..utils import logger from ..utils import logger
from .input_source import StagingInputWrapper
from .feedfree import SingleCostFeedfreeTrainer from .feedfree import SingleCostFeedfreeTrainer
from .multigpu import MultiGPUTrainerBase from .multigpu import MultiGPUTrainerBase
from ..callbacks import RunOp from ..callbacks import RunOp
...@@ -19,18 +18,14 @@ __all__ = ['DistributedReplicatedTrainer'] ...@@ -19,18 +18,14 @@ __all__ = ['DistributedReplicatedTrainer']
PS_SHADOW_VAR_PREFIX = 'ps_var' PS_SHADOW_VAR_PREFIX = 'ps_var'
# To be used with custom_getter on tf.get_variable. Ensures the created variable
# is in LOCAL_VARIABLES and not GLOBAL_VARIBLES collection.
class OverrideToLocalVariableIfNotPsVar(object): class OverrideToLocalVariableIfNotPsVar(object):
"""
# args and kwargs come from the custom_getter interface for Tensorflow Ensures the created variable
# variables, and matches tf.get_variable's signature, with the addition of is in LOCAL_VARIABLES and not GLOBAL_VARIBLES collection.
# 'getter' at the beginning. """
def __call__(self, getter, name, *args, **kwargs): def __call__(self, getter, name, *args, **kwargs):
if name.startswith(PS_SHADOW_VAR_PREFIX): if name.startswith(PS_SHADOW_VAR_PREFIX):
return getter(*args, **kwargs) return getter(*args, **kwargs)
logger.info("CustomGetter-{}".format(name))
if 'collections' in kwargs: if 'collections' in kwargs:
collections = kwargs['collections'] collections = kwargs['collections']
if not collections: if not collections:
...@@ -50,7 +45,8 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer): ...@@ -50,7 +45,8 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
self.cluster = tf.train.ClusterSpec(server_def.cluster) self.cluster = tf.train.ClusterSpec(server_def.cluster)
self.job_name = server_def.job_name self.job_name = server_def.job_name
self.task_index = server_def.task_index self.task_index = server_def.task_index
assert self.job_name in ['ps', 'worker'], job_name assert self.job_name in ['ps', 'worker'], self.job_name
assert tf.test.is_gpu_available
self._input_source = config.data self._input_source = config.data
self.is_chief = (self.task_index == 0 and self.job_name == 'worker') self.is_chief = (self.task_index == 0 and self.job_name == 'worker')
...@@ -71,14 +67,6 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer): ...@@ -71,14 +67,6 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(self.num_ps)] self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(self.num_ps)]
self.sync_queue_counter = 0 self.sync_queue_counter = 0
if self.nr_gpu > 1:
assert tf.test.is_gpu_available()
# TODO staging doesn't work with dummy (require context)
# seem to only improve on >1 GPUs
#if not isinstance(self._input_source, StagingInputWrapper):
#self._input_source = StagingInputWrapper(self._input_source, self.raw_devices)
@staticmethod @staticmethod
def _average_grads(tower_grads, devices): def _average_grads(tower_grads, devices):
""" """
...@@ -134,7 +122,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer): ...@@ -134,7 +122,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
list of copy ops list of copy ops
""" """
# TODO do this for each variable separately? # TODO do this for each variable separately?
opt = self.model.get_optimizer() # TODO ensure it in global scope, not local opt = self.model.get_optimizer()
var_update_ops = [] var_update_ops = []
for vid, (g, v) in enumerate(ps_var_grads): for vid, (g, v) in enumerate(ps_var_grads):
apply_gradient_op = opt.apply_gradients([(g, v)]) apply_gradient_op = opt.apply_gradients([(g, v)])
...@@ -175,7 +163,8 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer): ...@@ -175,7 +163,8 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
var_update_ops = self._apply_gradients_and_copy(grad_list, ps_var_grads) var_update_ops = self._apply_gradients_and_copy(grad_list, ps_var_grads)
main_fetch = tf.group(*var_update_ops, name='main_fetches') main_fetch = tf.group(*var_update_ops, name='main_fetches')
self.train_op = self.add_sync_queues_and_barrier('sync_queues_step_end', [main_fetch]) self.train_op = self.add_sync_queues_and_barrier(
'post_copy_barrier', [main_fetch])
self.register_callback(RunOp( self.register_callback(RunOp(
self.get_post_init_ops, run_before=True, run_as_trigger=False)) self.get_post_init_ops, run_before=True, run_as_trigger=False))
...@@ -189,6 +178,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer): ...@@ -189,6 +178,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
"Cannot set session_creator or session_config for distributed training! " "Cannot set session_creator or session_config for distributed training! "
"To use a custom session config, pass it to the tf.train.Server constructor.") "To use a custom session config, pass it to the tf.train.Server constructor.")
# TODO use scaffold
class SupervisedSessionCreator(tf.train.SessionCreator): class SupervisedSessionCreator(tf.train.SessionCreator):
def __init__(self, is_chief, target): def __init__(self, is_chief, target):
self.is_chief = is_chief self.is_chief = is_chief
...@@ -224,14 +214,11 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer): ...@@ -224,14 +214,11 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
shared_name='%s%s' % (name_prefix, i)) shared_name='%s%s' % (name_prefix, i))
for i in range(self.num_worker)] for i in range(self.num_worker)]
queue_ops = [] queue_ops = []
# For each other worker, add an entry in a queue, signaling that it can # For each other worker, add an entry in a queue, signaling that it can finish this step.
# finish this step.
token = tf.constant(False) token = tf.constant(False)
with tf.control_dependencies(enqueue_after_list): with tf.control_dependencies(enqueue_after_list):
for i, q in enumerate(sync_queues): for i, q in enumerate(sync_queues):
if i == self.task_index: if i != self.task_index:
queue_ops.append(tf.no_op())
else:
queue_ops.append(q.enqueue(token)) queue_ops.append(q.enqueue(token))
# Drain tokens off queue for this worker, one for each other worker. # Drain tokens off queue for this worker, one for each other worker.
...@@ -256,7 +243,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer): ...@@ -256,7 +243,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
v.name[len(PS_SHADOW_VAR_PREFIX + '/'):]) v.name[len(PS_SHADOW_VAR_PREFIX + '/'):])
for i in range(self.nr_gpu): for i in range(self.nr_gpu):
if i == 0: if i == 0:
name = prefix name = prefix # no prefix for tower0
else: else:
name = 'tower%s/%s' % (i, prefix) name = 'tower%s/%s' % (i, prefix)
if name in local_var_by_name: if name in local_var_by_name:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment