Commit afa11399 authored by Yuxin Wu

Fix distributed trainer (fix #671)

parent 5626e04d
@@ -104,8 +104,11 @@ class Monitors(Callback):
     You should use `trainer.monitors` for logging and it will dispatch your
     logs to each sub-monitor.
     """
+
+    _chief_only = False
+
     def __init__(self, monitors):
-        self._scalar_history = ScalarHistory()
+        self._scalar_history = ScalarHistory().set_chief_only(False)
         self._monitors = monitors + [self._scalar_history]
         for m in self._monitors:
             assert isinstance(m, TrainingMonitor), m
@@ -325,6 +328,9 @@ class ScalarPrinter(TrainingMonitor):
     """
     Print scalar data into terminal.
     """
+
+    _chief_only = False
+
     def __init__(self, enable_step=False, enable_epoch=True,
                  whitelist=None, blacklist=None):
         """
......
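The two hunks above make the monitor machinery run on every worker, not only the chief. Below is a minimal, self-contained sketch of the chief_only mechanism they rely on; the classes are simplified stand-ins that mirror the behavior shown in this diff (`set_chief_only`, the class-level `_chief_only` flag, and the skip rule in `_register_callback`), not the real tensorpack implementations.

```python
# Toy model of the chief_only dispatch logic this commit relies on.
# These classes are simplified stand-ins, not the real tensorpack classes.

class Callback:
    _chief_only = True          # default: run only on the chief worker

    @property
    def chief_only(self):
        return self._chief_only

    def set_chief_only(self, v=True):
        self._chief_only = v
        return self              # chainable, e.g. ScalarHistory().set_chief_only(False)


class ScalarHistory(Callback):
    pass


class Monitors(Callback):
    _chief_only = False          # after this commit: the dispatcher runs on every worker

    def __init__(self, monitors):
        self._scalar_history = ScalarHistory().set_chief_only(False)
        self._monitors = monitors + [self._scalar_history]


def register(callbacks, cb, is_chief):
    """Mimics Trainer._register_callback: skip chief-only callbacks on non-chief workers."""
    if not is_chief and cb.chief_only:
        return False
    callbacks.append(cb)
    return True


if __name__ == "__main__":
    cbs = []
    assert register(cbs, Monitors([]), is_chief=False)      # kept on a non-chief worker
    assert not register(cbs, Callback(), is_chief=False)    # default callbacks are still skipped
```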
@@ -6,6 +6,7 @@ import tensorflow as tf
 import re
 from six.moves import range
 
+from ..utils import logger
 from ..utils.argtools import memoized
 from ..tfutils.common import get_op_tensor_name, get_global_step_var
@@ -230,19 +231,26 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
         Returns:
             list of (shadow_model_var, local_model_var) used for syncing.
         """
+        G = tf.get_default_graph()
         curr_shadow_vars = set([v.name for v in shadow_vars])
         model_vars = tf.model_variables()
         shadow_model_vars = []
         for v in model_vars:
-            assert v.name.startswith('tower'), "Found some MODEL_VARIABLES created outside of the model!"
-            stripped_name = get_op_tensor_name(re.sub('tower[0-9]+/', '', v.name))[0]
-            if stripped_name in curr_shadow_vars:
+            assert v.name.startswith('tower'), "Found some MODEL_VARIABLES created outside of the tower function!"
+            stripped_op_name, stripped_var_name = get_op_tensor_name(re.sub('^tower[0-9]+/', '', v.name))
+            if stripped_op_name in curr_shadow_vars:
                 continue
-            new_v = tf.get_variable(stripped_name, dtype=v.dtype.base_dtype,
+            try:
+                G.get_tensor_by_name(stripped_var_name)
+                logger.warn("Model Variable {} also appears in other collections.".format(stripped_var_name))
+                continue
+            except KeyError:
+                pass
+            new_v = tf.get_variable(stripped_op_name, dtype=v.dtype.base_dtype,
                                     initializer=v.initial_value,
                                     trainable=False)
-            curr_shadow_vars.add(stripped_name)  # avoid duplicated shadow_model_vars
+            curr_shadow_vars.add(stripped_op_name)  # avoid duplicated shadow_model_vars
             shadow_vars.append(new_v)
             shadow_model_vars.append((new_v, v))  # only need to sync model_var from one tower
         return shadow_model_vars
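For reference, a small sketch of the variable-name handling in the hunk above. `get_op_tensor_name` in tensorpack splits a name into (op_name, tensor_name), with the tensor name carrying a `:0` suffix; the helper below is a local stand-in reproducing that convention so the example is self-contained, and the variable names are illustrative only.

```python
import re

def get_op_tensor_name(name):
    # Local stand-in for tensorpack.tfutils.common.get_op_tensor_name:
    # returns (op_name, tensor_name), where tensor_name ends with ':0'.
    if len(name) >= 3 and name[-2] == ':':
        return name[:-2], name
    return name, name + ':0'

def shadow_name(model_var_name):
    # Strip the leading "towerN/" prefix (anchored, as in the new code)
    # so every tower maps to the same shadow variable name on the PS.
    stripped = re.sub('^tower[0-9]+/', '', model_var_name)
    return get_op_tensor_name(stripped)

# Both towers' copies of the same EMA statistic resolve to one shadow name:
print(shadow_name('tower0/bn/mean/EMA:0'))   # ('bn/mean/EMA', 'bn/mean/EMA:0')
print(shadow_name('tower1/bn/mean/EMA:0'))   # ('bn/mean/EMA', 'bn/mean/EMA:0')
```

With the anchored `^tower[0-9]+/` pattern and the new `G.get_tensor_by_name` probe, every tower's copy of a model variable maps to one shadow name, and names that already exist in the graph are skipped with a warning instead of failing on a duplicate `tf.get_variable` call.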
@@ -279,7 +287,8 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
             use_vs=[True] * len(self.towers))  # open vs at each tower
         DataParallelBuilder._check_grad_list(grad_list)
 
-        avg_grads = average_grads(grad_list, devices=self.raw_devices)
+        avg_grads = average_grads(
+            grad_list, colocation=False, devices=self.raw_devices)
         with tf.device(self.param_server_device):
             ps_var_grads = DistributedReplicatedBuilder._apply_shadow_vars(avg_grads)
             var_update_ops = self._apply_gradients_and_copy(
......
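The change above passes `colocation=False`, so gradient averaging is placed on the listed raw devices rather than colocated with the (remote, parameter-server) variables. The sketch below is a simplified, generic stand-in for that kind of per-tower gradient averaging, not tensorpack's actual `average_grads`.

```python
import tensorflow as tf  # TF1-style graph mode assumed

def average_grads_sketch(all_grads, devices):
    """Average a list (one entry per tower) of [(grad, var), ...] lists.

    Simplified stand-in: with colocation disabled, the averaging ops are
    pinned round-robin to the given devices instead of being colocated
    with each variable.
    """
    nr_tower = len(all_grads)
    if nr_tower == 1:
        return all_grads[0]
    result = []
    for idx, grad_and_vars in enumerate(zip(*all_grads)):
        # grad_and_vars: the same variable's (grad, var) pair from every tower
        v = grad_and_vars[0][1]
        with tf.device(devices[idx % len(devices)]):
            grads = [g for g, _ in grad_and_vars]
            avg = tf.multiply(tf.add_n(grads), 1.0 / nr_tower)
        result.append((avg, v))
    return result
```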
@@ -56,7 +56,7 @@ def override_to_local_variable(enable=True):
         ns = orig_vs.original_name_scope
         with tf.variable_scope(
                 orig_vs, custom_getter=custom_getter):
-            with tf.name_scope(ns + '/'):
+            with tf.name_scope(ns + '/' if ns else ''):
                 yield
     else:
         yield
......
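The one-line fix above guards against an empty `original_name_scope`: at the root variable scope `ns` is `''`, so the old expression produced `tf.name_scope('/')`, which TensorFlow rejects as an invalid scope name at the graph root. A small TF1-style sketch of the fixed expression (graph mode assumed; under TF2 the same calls live in `tf.compat.v1`):

```python
import tensorflow as tf  # TF1 graph mode assumed; use tf.compat.v1 under TF2

g = tf.Graph()
with g.as_default():
    vs = tf.get_variable_scope()        # root variable scope
    ns = vs.original_name_scope         # '' at the root
    # Old code: tf.name_scope(ns + '/') becomes tf.name_scope('/') when ns is
    # empty, which raises "'/' is not a valid scope name" at the root scope.
    with tf.name_scope(ns + '/' if ns else ''):   # '' keeps the top-level name scope
        x = tf.add(1, 2, name='sum')
print(x.name)   # 'sum:0' -- no spurious scope prefix is added
```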
@@ -143,6 +143,9 @@ class Trainer(object):
         Args:
             cb (Callback or [Callback]): a callback or a list of callbacks
+
+        Returns:
+            bool: whether the callback was registered (False means it was skipped).
         """
         if isinstance(cb, (list, tuple)):
             for x in cb:
@@ -153,8 +156,10 @@
                 "Cannot register more callbacks after trainer was setup!"
             if not self.is_chief and cb.chief_only:
                 logger.warn("Callback {} is chief-only, skipped.".format(str(cb)))
+                return False
             else:
                 self._callbacks.append(cb)
+                return True
 
     register_callback = _register_callback
@@ -188,9 +193,11 @@
             self.register_callback(cb)
         for cb in self._callbacks:
             assert not isinstance(cb, TrainingMonitor), "Monitor cannot be pre-registered for now!"
+        registered_monitors = []
         for m in monitors:
-            self.register_callback(m)
-        self.monitors = Monitors(monitors)
+            if self.register_callback(m):
+                registered_monitors.append(m)
+        self.monitors = Monitors(registered_monitors)
         self.register_callback(self.monitors)  # monitors is also a callback
 
         # some final operations that might modify the graph
......
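Taken together, the `register_callback` return value and the `registered_monitors` list above ensure a non-chief worker only builds `Monitors` from callbacks that were actually kept. A condensed, hypothetical sketch of that control flow (the classes below are stand-ins, not the real Trainer or TrainingMonitor):

```python
class FakeMonitor:
    """Hypothetical stand-in for a TrainingMonitor; chief_only mimics the Callback flag."""
    def __init__(self, name, chief_only=True):
        self.name = name
        self.chief_only = chief_only


def setup_monitors(monitors, is_chief):
    """Mirrors the fixed setup logic: only monitors that were actually registered
    end up inside Monitors, so a non-chief worker never dispatches logs to a
    monitor that was skipped."""
    callbacks, registered = [], []
    for m in monitors:
        if not is_chief and m.chief_only:
            print("skipped (chief-only):", m.name)
            continue            # register_callback() returned False
        callbacks.append(m)
        registered.append(m)    # register_callback() returned True
    return callbacks, registered


# On a non-chief worker, only the non-chief-only monitor survives:
cbs, reg = setup_monitors(
    [FakeMonitor("TFEventWriter"), FakeMonitor("ScalarPrinter", chief_only=False)],
    is_chief=False)
assert [m.name for m in reg] == ["ScalarPrinter"]
```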
@@ -214,7 +214,7 @@ class DistributedTrainerParameterServer(DistributedTrainerBase):
         return []
 
 
-class DistributedTrainerReplicated(SingleCostTrainer):
+class DistributedTrainerReplicated(DistributedTrainerBase):
 
     __doc__ = DistributedReplicatedBuilder.__doc__
......