Commit afa11399 authored by Yuxin Wu

Fix distributed trainer (fix #671)

parent 5626e04d
@@ -104,8 +104,11 @@ class Monitors(Callback):
     You should use `trainer.monitors` for logging and it will dispatch your
     logs to each sub-monitor.
     """
+
+    _chief_only = False
+
     def __init__(self, monitors):
-        self._scalar_history = ScalarHistory()
+        self._scalar_history = ScalarHistory().set_chief_only(False)
         self._monitors = monitors + [self._scalar_history]
         for m in self._monitors:
             assert isinstance(m, TrainingMonitor), m
@@ -325,6 +328,9 @@ class ScalarPrinter(TrainingMonitor):
     """
     Print scalar data into terminal.
     """
+
+    _chief_only = False
+
     def __init__(self, enable_step=False, enable_epoch=True,
                  whitelist=None, blacklist=None):
         """
......
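For context, a minimal standalone sketch (plain Python with simplified stand-ins, not tensorpack's real classes) of how a class-level `_chief_only` default and the chainable `set_chief_only()` seen above interact with chief-only filtering at registration time:

```python
# Illustrative sketch only: simplified stand-ins, not tensorpack's real classes.
class CallbackSketch(object):
    _chief_only = True              # class-level default: run on the chief worker only

    def set_chief_only(self, v=True):
        self._chief_only = v        # per-instance override
        return self                 # chainable, as in ScalarHistory().set_chief_only(False)

    @property
    def chief_only(self):
        return self._chief_only


class MonitorsSketch(CallbackSketch):
    _chief_only = False             # monitors should also run on non-chief workers


def register(cb, is_chief):
    """Mimic the trainer's filtering: chief-only callbacks are skipped on workers."""
    if not is_chief and cb.chief_only:
        return False
    return True


assert register(MonitorsSketch(), is_chief=False)        # kept on every worker
assert not register(CallbackSketch(), is_chief=False)    # skipped on a non-chief worker
```

Marking `Monitors` and `ScalarPrinter` with `_chief_only = False` (and the internal `ScalarHistory` via `set_chief_only(False)`) means they are no longer skipped on non-chief workers of a distributed run.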
@@ -6,6 +6,7 @@ import tensorflow as tf
 import re
 from six.moves import range
+from ..utils import logger
 from ..utils.argtools import memoized
 from ..tfutils.common import get_op_tensor_name, get_global_step_var
@@ -230,19 +231,26 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
         Returns:
             list of (shadow_model_var, local_model_var) used for syncing.
         """
+        G = tf.get_default_graph()
         curr_shadow_vars = set([v.name for v in shadow_vars])
         model_vars = tf.model_variables()
         shadow_model_vars = []
         for v in model_vars:
-            assert v.name.startswith('tower'), "Found some MODEL_VARIABLES created outside of the model!"
-            stripped_name = get_op_tensor_name(re.sub('tower[0-9]+/', '', v.name))[0]
-            if stripped_name in curr_shadow_vars:
+            assert v.name.startswith('tower'), "Found some MODEL_VARIABLES created outside of the tower function!"
+            stripped_op_name, stripped_var_name = get_op_tensor_name(re.sub('^tower[0-9]+/', '', v.name))
+            if stripped_op_name in curr_shadow_vars:
                 continue
-            new_v = tf.get_variable(stripped_name, dtype=v.dtype.base_dtype,
+            try:
+                G.get_tensor_by_name(stripped_var_name)
+                logger.warn("Model Variable {} also appears in other collections.".format(stripped_var_name))
+                continue
+            except KeyError:
+                pass
+            new_v = tf.get_variable(stripped_op_name, dtype=v.dtype.base_dtype,
                                     initializer=v.initial_value,
                                     trainable=False)
-            curr_shadow_vars.add(stripped_name)   # avoid duplicated shadow_model_vars
+            curr_shadow_vars.add(stripped_op_name)   # avoid duplicated shadow_model_vars
             shadow_vars.append(new_v)
             shadow_model_vars.append((new_v, v))  # only need to sync model_var from one tower
         return shadow_model_vars
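A small standalone illustration of the name handling above: the regex now anchors at the start of the name, so only a leading `towerN/` prefix is stripped, and the op name and tensor name differ only by the `:0` output suffix. `op_tensor_name` below is a simplified stand-in for tensorpack's `get_op_tensor_name`, and the variable name is made up for the example:

```python
import re


def op_tensor_name(name):
    # Simplified stand-in for tensorpack's get_op_tensor_name:
    # returns (op_name, tensor_name), where tensor_name keeps the ':0' output suffix.
    if ':' in name:
        return name[:name.rfind(':')], name
    return name, name + ':0'


v_name = 'tower0/conv1/bn/mean/EMA:0'             # made-up variable name for illustration
stripped = re.sub('^tower[0-9]+/', '', v_name)    # '^' anchor: strip only the leading tower prefix
op_name, var_name = op_tensor_name(stripped)
print(op_name)    # conv1/bn/mean/EMA     -> used as the shadow variable's name
print(var_name)   # conv1/bn/mean/EMA:0   -> looked up in the graph to detect duplicates
```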
@@ -279,7 +287,8 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
             use_vs=[True] * len(self.towers))  # open vs at each tower
         DataParallelBuilder._check_grad_list(grad_list)

-        avg_grads = average_grads(grad_list, devices=self.raw_devices)
+        avg_grads = average_grads(
+            grad_list, colocation=False, devices=self.raw_devices)
         with tf.device(self.param_server_device):
             ps_var_grads = DistributedReplicatedBuilder._apply_shadow_vars(avg_grads)
             var_update_ops = self._apply_gradients_and_copy(
......
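The `colocation=False` argument places the gradient averaging on the explicitly supplied `raw_devices` instead of colocating it with the per-tower gradient ops. A rough sketch of that placement idea (not tensorpack's actual `average_grads`):

```python
import tensorflow as tf


def average_grads_sketch(grad_list, devices):
    """Rough sketch: average the i-th gradient across towers on an explicitly chosen
    device, rather than colocating the reduction with the gradient op itself."""
    nr_tower = len(grad_list)
    result = []
    for i, grad_and_vars in enumerate(zip(*grad_list)):
        v = grad_and_vars[0][1]                        # the variable is shared across towers
        grads = [g for g, _ in grad_and_vars]
        with tf.device(devices[i % len(devices)]):     # explicit placement, no colocation
            avg = tf.multiply(tf.add_n(grads), 1.0 / nr_tower)
        result.append((avg, v))
    return result
```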
@@ -56,7 +56,7 @@ def override_to_local_variable(enable=True):
         ns = orig_vs.original_name_scope
         with tf.variable_scope(
                 orig_vs, custom_getter=custom_getter):
-            with tf.name_scope(ns + '/'):
+            with tf.name_scope(ns + '/' if ns else ''):
                 yield
     else:
         yield
......
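The `ns + '/' if ns else ''` guard matters when `original_name_scope` is empty (the root variable scope): `'' + '/'` is `'/'`, which TF rejects as a scope name, while an empty string simply resets to the top-level name scope. A tiny graph-mode illustration (TF1-style API, as used throughout this file):

```python
import tensorflow as tf


def reenter_name_scope(ns):
    # A name ending in '/' re-enters an existing scope verbatim; an empty string
    # resets to the top-level scope, which is the right target when ns is '' (root).
    return tf.name_scope(ns + '/' if ns else '')


g = tf.Graph()
with g.as_default():
    with tf.name_scope('tower0'):
        pass
    with reenter_name_scope('tower0'):
        a = tf.constant(0, name='a')   # lands in tower0/, no 'tower0_1' uniquification
    with reenter_name_scope(''):
        b = tf.constant(0, name='b')   # tf.name_scope('/') would raise ValueError here

print(a.name)   # tower0/a:0
print(b.name)   # b:0
```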
@@ -143,6 +143,9 @@ class Trainer(object):
         Args:
             cb (Callback or [Callback]): a callback or a list of callbacks
+
+        Returns:
+            succeed or not
         """
         if isinstance(cb, (list, tuple)):
             for x in cb:
@@ -153,8 +156,10 @@ class Trainer(object):
             "Cannot register more callbacks after trainer was setup!"
         if not self.is_chief and cb.chief_only:
             logger.warn("Callback {} is chief-only, skipped.".format(str(cb)))
+            return False
         else:
             self._callbacks.append(cb)
+            return True

     register_callback = _register_callback
@@ -188,9 +193,11 @@ class Trainer(object):
             self.register_callback(cb)
         for cb in self._callbacks:
             assert not isinstance(cb, TrainingMonitor), "Monitor cannot be pre-registered for now!"
+        registered_monitors = []
         for m in monitors:
-            self.register_callback(m)
-        self.monitors = Monitors(monitors)
+            if self.register_callback(m):
+                registered_monitors.append(m)
+        self.monitors = Monitors(registered_monitors)
         self.register_callback(self.monitors)  # monitors is also a callback

         # some final operations that might modify the graph
......
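Putting the two changes together: because `_register_callback` now reports whether the callback was actually kept, the setup code above wraps only the surviving monitors into `Monitors(...)`, so a monitor that was skipped (e.g. a chief-only one on a non-chief worker) is not included in the dispatch list. A condensed plain-Python sketch of that control flow (illustrative names only, not the trainer's real code):

```python
# Condensed, illustrative sketch -- not the trainer's real code.
class ChiefOnlyMonitor(object):
    chief_only = True

class PerWorkerMonitor(object):
    chief_only = False


def register_callback(cb, callbacks, is_chief):
    if not is_chief and cb.chief_only:
        return False                 # skipped: not part of this worker's callbacks
    callbacks.append(cb)
    return True


callbacks = []
monitors = [ChiefOnlyMonitor(), PerWorkerMonitor()]

# Only monitors that were actually registered end up in the Monitors wrapper.
registered_monitors = [m for m in monitors
                       if register_callback(m, callbacks, is_chief=False)]
assert len(registered_monitors) == 1 and len(callbacks) == 1
```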
@@ -214,7 +214,7 @@ class DistributedTrainerParameterServer(DistributedTrainerBase):
         return []


-class DistributedTrainerReplicated(SingleCostTrainer):
+class DistributedTrainerReplicated(DistributedTrainerBase):

     __doc__ = DistributedReplicatedBuilder.__doc__
......