Commit 06bb5142 authored by Yuxin Wu's avatar Yuxin Wu

fix SimpleTrainer. expose setup_graph for AsyncMultiGPUTrainer

parent 8837d748
......@@ -332,25 +332,37 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase):
self._scale_gradient = scale_gradient
super(AsyncMultiGPUTrainer, self).__init__(config)
def _setup(self):
super(AsyncMultiGPUTrainer, self)._setup()
raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
@staticmethod
def setup_graph(model, input, scale_gradient, tower):
"""
Args:
model (ModelDesc):
input (InputSource):
scale_gradient (bool):
tower (list[int]):
Returns:
tf.Operation: the training op
[Callback]: the callbacks to be added
"""
input.setup(model.get_inputs_desc())
raw_devices = ['/gpu:{}'.format(k) for k in tower]
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower,
lambda: MultiGPUTrainerBase._build_graph_get_grads(
self.model, self._input_source), devices)
tower,
lambda: MultiGPUTrainerBase._build_graph_get_grads(model, input), devices)
MultiGPUTrainerBase._check_grad_list(grad_list)
if self._scale_gradient and self.config.nr_tower > 1:
if scale_gradient and len(tower) > 1:
# pretend to average the grads, in order to make async and
# sync have consistent effective learning rate
gradproc = ScaleGradient(('.*', 1.0 / self.config.nr_tower), verbose=False)
gradproc = ScaleGradient(('.*', 1.0 / len(tower)), verbose=False)
grad_list = [gradproc.process(gv) for gv in grad_list]
# Ngpu x Nvar x 2
train_ops = []
opt = self.model.get_optimizer()
opt = model.get_optimizer()
for i, grad_and_vars in enumerate(zip(*grad_list)):
# Ngpu x 2
v = grad_and_vars[0][1]
......@@ -358,4 +370,9 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase):
# will call apply_gradients (therefore gradproc) multiple times
train_ops.append(opt.apply_gradients(
grad_and_vars, name='apply_grad_{}'.format(i)))
self.train_op = tf.group(*train_ops, name='train_op')
return tf.group(*train_ops, name='train_op'), input.get_callbacks()
def _setup(self):
self.train_op, cbs = AsyncMultiGPUTrainer.setup_graph(
self.model, self._input_source, self._scale_gradient, self.config.tower)
self.config.callbacks.extend(cbs)
......@@ -36,7 +36,7 @@ class SimpleTrainer(Trainer):
super(SimpleTrainer, self).__init__(config)
@staticmethod
def setup_graph(self, model, input):
def setup_graph(model, input):
"""
Setup graph for simple trainer.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment