Commit 584e9cd4 authored by Yuxin Wu

Fix hyperparam and optimizer issue in distributed trainer (#431)

parent c8028236
@@ -66,7 +66,7 @@ class GraphVarParam(HyperParam):
     def setup_graph(self):
         """ Will setup the assign operator for that variable. """
-        all_vars = tf.global_variables()
+        all_vars = tf.all_variables()
         for v in all_vars:
             if v.name == self.var_name:
                 self.var = v
...
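The one-line change above only switches which variable collection setup_graph scans; the rest of the method, per its docstring, finds the variable by name and sets up an assign operator for it. A minimal standalone sketch of that pattern in TF1-style code (make_hyperparam_assign_op and its argument are illustrative names, not tensorpack's API):

import tensorflow as tf

def make_hyperparam_assign_op(var_name):
    # Find the graph variable backing the hyperparameter by name; which
    # collection gets scanned here is exactly what the diff above changes.
    matches = [v for v in tf.global_variables() if v.name == var_name]
    if not matches:
        raise ValueError("{} is not a variable in the graph!".format(var_name))
    var = matches[0]
    # Expose an assign op fed from a placeholder, so a setter callback can
    # push a new value (e.g. a learning rate) into the graph at run time.
    new_value = tf.placeholder(var.dtype.base_dtype, shape=var.get_shape())
    return var.assign(new_value), new_value

A setter callback would then run the returned assign op with feed_dict={new_value: ...} whenever the schedule changes the hyperparameter.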
@@ -204,6 +204,8 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
         cbs = self._input_source.setup(self.model.get_inputs_desc())
         self.config.callbacks.extend(cbs)
+        # build the optimizer first, before entering any tower
+        self.model.get_optimizer()
         # Ngpu * Nvar * 2
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             self.config.tower,
...
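The two added lines make sure the optimizer is instantiated once, before any per-GPU tower graph is built, presumably so that whatever the optimizer creates is not tied to a tower's scope or device. A rough sketch of that ordering under TF1 semantics (build_towers, get_cost and num_gpu are illustrative stand-ins, not the MultiGPUTrainerBase API; variable sharing between towers is elided):

import tensorflow as tf

def build_towers(model, num_gpu, get_cost):
    # Create the optimizer outside every tower, mirroring the added
    # self.model.get_optimizer() call above.
    opt = model.get_optimizer()
    grad_list = []
    for i in range(num_gpu):
        with tf.device('/gpu:%d' % i), tf.name_scope('tower%d' % i):
            cost = get_cost()
            # one list of (gradient, variable) pairs per tower: Ngpu * Nvar * 2
            grad_list.append(opt.compute_gradients(cost))
    return grad_list

The per-tower gradient lists are later aggregated and applied with the same opt via opt.apply_gradients, which is why the optimizer has to exist before this loop runs.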