Commit 2cba9434 authored by Yuxin Wu

fix distributed trainer

parent 7780c64b
......@@ -79,7 +79,7 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(self.num_ps)]
self.sync_queue_counter = 0
super(DistributedReplicatedTrainer, self).__init__(config)
super(DistributedTrainerReplicated, self).__init__(config)
@staticmethod
def _average_grads(tower_grads, devices):
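Note on the rename above: the first argument of super() must be the class's current name, which is why the __init__ call had to change along with the class. A minimal sketch of the failure the fix avoids (not part of the commit; the Base class here is illustrative):

class Base(object):
    def __init__(self, config):
        self.config = config

class DistributedTrainerReplicated(Base):
    def __init__(self, config):
        # super(DistributedReplicatedTrainer, self).__init__(config)   # NameError: old class name no longer exists
        super(DistributedTrainerReplicated, self).__init__(config)     # must reference the current class name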
......@@ -187,8 +187,9 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
with tf.device(self.param_server_device):
gs = get_global_step_var()
assert gs.device, gs.device
# do this before super.setup because input_source may need global step
super(DistributedReplicatedTrainer, self)._setup()
# do this before input_source.setup because input_source may need the global step
cbs = self._input_source.setup(self.model.get_inputs_desc())
self.config.callbacks.extend(cbs)
with tf.variable_scope(
tf.get_variable_scope(),
......@@ -199,15 +200,15 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
lambda: MultiGPUTrainerBase._build_graph_get_grads(
self.model, self._input_source),
devices=self.raw_devices,
vs_names=[True] * self.config.nr_tower) # open vs at each tower
use_vs=[True] * self.config.nr_tower) # open vs at each tower
MultiGPUTrainerBase._check_grad_list(grad_list)
avg_grads = DistributedReplicatedTrainer._average_grads(grad_list, self.raw_devices)
avg_grads = DistributedTrainerReplicated._average_grads(grad_list, self.raw_devices)
with tf.device(self.param_server_device):
ps_var_grads = DistributedReplicatedTrainer._apply_shadow_vars(avg_grads)
ps_var_grads = DistributedTrainerReplicated._apply_shadow_vars(avg_grads)
var_update_ops = self._apply_gradients_and_copy(grad_list, ps_var_grads)
self._shadow_vars = [v for (_, v) in ps_var_grads]
self._shadow_model_vars = DistributedReplicatedTrainer._shadow_model_variables(self._shadow_vars)
self._shadow_model_vars = DistributedTrainerReplicated._shadow_model_variables(self._shadow_vars)
# TODO add options to synchronize less
main_fetch = tf.group(*var_update_ops, name='main_fetches')
......
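For context on the _average_grads call above: it averages each variable's gradient across the towers and spreads the reduction ops over the given devices. A minimal sketch of that pattern (not the commit's implementation; assumes TensorFlow 1.x, names are illustrative):

import tensorflow as tf

def average_grads(tower_grads, devices):
    # tower_grads: one [(grad, var), ...] list per tower, all in the same variable order
    if len(tower_grads) == 1:
        return tower_grads[0]
    averaged = []
    for idx, grad_and_vars in enumerate(zip(*tower_grads)):
        var = grad_and_vars[0][1]
        with tf.device(devices[idx % len(devices)]):   # round-robin the reduction ops over the devices
            grads = [g for g, _ in grad_and_vars]
            avg = tf.add_n(grads) / float(len(grads))
        averaged.append((avg, var))
    return averaged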
......@@ -28,7 +28,7 @@ class FeedfreeTrainerBase(Trainer):
def _setup(self):
assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
cbs = self._setup_input_source.setup(self.model.get_inputs_desc())
cbs = self._input_source.setup(self.model.get_inputs_desc())
self.config.callbacks.extend(cbs)
......
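The corrected line relies on the InputSource contract: setup() builds the input ops for the model's input signature and returns callbacks that must be registered with the training config. A minimal usage sketch under that assumption (QueueInput is imported in the multigpu diff below; df, model and config are illustrative):

from tensorpack.graph_builder.input_source import QueueInput

input_source = QueueInput(df)                        # wrap a DataFlow in a queue-based InputSource
cbs = input_source.setup(model.get_inputs_desc())    # build the input ops, collect the callbacks it needs
config.callbacks.extend(cbs)                         # register them with the trainer (e.g. the enqueue thread)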
......@@ -16,7 +16,7 @@ from ..tfutils.gradproc import ScaleGradient
from ..callbacks.graph import RunOp
from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyConstantInput
from .feedfree import FeedfreeTrainerBase
from .base import Trainer
__all__ = ['MultiGPUTrainerBase', 'LeastLoadedDeviceSetter',
'SyncMultiGPUTrainerReplicated',
......@@ -45,7 +45,7 @@ def apply_prefetch_policy(config, gpu_prefetch=True):
config.data = StagingInputWrapper(config.data, devices)
class MultiGPUTrainerBase(FeedfreeTrainerBase):
class MultiGPUTrainerBase(Trainer):
""" Base class for multi-gpu training"""
@staticmethod
def build_on_multi_tower(
......
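The vs_names -> use_vs rename in the distributed diff above matches the keyword this base class exposes: passing use_vs=[True] * nr_tower asks the per-tower builder to open a separate variable scope for each tower, so every tower owns its own copy of the variables (the replicated setting). A minimal sketch of that behaviour (not the commit's code; assumes TensorFlow 1.x, names are illustrative):

import tensorflow as tf

def build_replicated_towers(build_fn, devices):
    grad_list = []
    for idx, device in enumerate(devices):
        # a fresh variable scope per tower: each tower keeps its own copy of the variables
        with tf.device(device), tf.variable_scope('tower{}'.format(idx)):
            grad_list.append(build_fn())   # e.g. returns this tower's [(grad, var), ...]
    return grad_list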