Commit 5b18f8be authored by Yuxin Wu's avatar Yuxin Wu

Bug fix in the replicated-training variable scope

parent 3ab6d2b0
......@@ -43,6 +43,7 @@ class ModelSaver(Callback):
vars = []
for key in self.var_collections:
vars.extend(tf.get_collection(key))
vars = list(set(vars))
self.path = os.path.join(self.checkpoint_dir, 'model')
if get_tf_version_number() <= 1.1:
self.saver = tf.train.Saver(
......
......@@ -43,6 +43,8 @@ class TowerContext(object):
assert self._name
if vs_name is None:
self._vs_name = self._name
else:
self._vs_name = vs_name
else:
assert vs_name is None, "vs_name is only valid in 'replicated' mode!"
self._vs_name = ''
......
......@@ -259,7 +259,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
lambda: self._get_cost_and_grad()[1],
var_strategy='replicated',
# use no variable scope for the first tower
vs_names=[''] + [None] * self.config.nr_tower - 1)
vs_names=[''] + [None] * (self.config.nr_tower - 1))
grads = self._allreduce_grads(grad_list)
train_ops = []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment