Commit 5b18f8be authored by Yuxin Wu

bugfix in replicated training varscope

parent 3ab6d2b0
@@ -43,6 +43,7 @@ class ModelSaver(Callback):
         vars = []
         for key in self.var_collections:
             vars.extend(tf.get_collection(key))
+        vars = list(set(vars))
         self.path = os.path.join(self.checkpoint_dir, 'model')
         if get_tf_version_number() <= 1.1:
             self.saver = tf.train.Saver(
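A note on this first hunk: the same variable object can be registered in more than one collection, so concatenating collections produces duplicates, and `tf.train.Saver` can raise a ValueError when `var_list` names the same variable twice. A minimal sketch of the failure mode in TF 1.x graph mode (the collection names here are illustrative, not necessarily ModelSaver's defaults):

```python
import tensorflow as tf

# 'w' lands in GLOBAL_VARIABLES automatically; registering it in a second
# collection is enough to make naive concatenation collect it twice.
w = tf.get_variable('w', shape=[2])
tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, w)

vars = []
for key in [tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.MODEL_VARIABLES]:
    vars.extend(tf.get_collection(key))
assert len(vars) == 2          # the same Variable object, collected twice

vars = list(set(vars))         # the fix: deduplicate before building the Saver
assert len(vars) == 1
saver = tf.train.Saver(var_list=vars)
```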
@@ -43,6 +43,8 @@ class TowerContext(object):
             assert self._name
             if vs_name is None:
                 self._vs_name = self._name
+            else:
+                self._vs_name = vs_name
         else:
             assert vs_name is None, "vs_name is only valid in 'replicated' mode!"
             self._vs_name = ''
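Why the new `else` branch matters: the replicated trainer (third hunk, below) passes `vs_name=''` for the first tower and `None` for the rest. Since `''` is not `None`, the old code took neither assignment and an explicitly passed scope name was never stored. A stand-alone sketch of the intended resolution logic (the helper name is hypothetical):

```python
def resolve_vs_name(tower_name, vs_name):
    """Hypothetical stand-in for TowerContext's logic in 'replicated' mode."""
    if vs_name is None:
        return tower_name   # default: variable scope follows the tower name
    return vs_name          # the fix: honor the caller's choice, including ''

# First tower reuses the root scope; the other towers get their own scope.
assert resolve_vs_name('tower0', '') == ''
assert resolve_vs_name('tower1', None) == 'tower1'
```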
@@ -259,7 +259,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
             lambda: self._get_cost_and_grad()[1],
             var_strategy='replicated',
             # use no variable scope for the first tower
-            vs_names=[''] + [None] * self.config.nr_tower - 1)
+            vs_names=[''] + [None] * (self.config.nr_tower - 1))
         grads = self._allreduce_grads(grad_list)
         train_ops = []
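The one-line fix in this last hunk is operator precedence: `*` binds tighter than `-`, so the old expression ended up subtracting an integer from a list. A quick reproduction in plain Python (the tower count is made up):

```python
nr_tower = 4  # hypothetical tower count

try:
    vs_names = [''] + [None] * nr_tower - 1      # old: evaluates (list) - 1
except TypeError as e:
    print(e)  # unsupported operand type(s) for -: 'list' and 'int'

# Fixed: one '' for the first tower, None for each of the remaining towers.
vs_names = [''] + [None] * (nr_tower - 1)
assert vs_names == ['', None, None, None]
```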