Commit 86d4c589 authored by Yuxin Wu

avoid setting reuse completely to True

parent 2ef9669d
@@ -44,8 +44,10 @@ class MultiGPUTrainer(QueueInputTrainer):
                 len(self.config.tower)))
         grad_list = []
+        global_scope = tf.get_variable_scope()
         for idx, t in enumerate(self.config.tower):
             with tf.device('/gpu:{}'.format(t)), \
+                    tf.variable_scope(global_scope, reuse=idx > 0), \
                     TowerContext('tower{}'.format(idx)) as scope:
                 logger.info("Building graph for training tower {}...".format(idx))
                 model_inputs = self._get_model_inputs()    # each tower dequeue from input queue
@@ -60,7 +62,6 @@ class MultiGPUTrainer(QueueInputTrainer):
                 if idx == 0:
                     tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
-                    tf.get_variable_scope().reuse_variables()
                     # avoid repeated summary from each device
                     backup = backup_collection(SUMMARY_BACKUP_KEYS)
         restore_collection(backup)
...
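Below is a minimal TensorFlow 1.x-style sketch, separate from the commit itself, that illustrates the reuse pattern the diff switches to. The tower list and the build_tower helper are hypothetical stand-ins for tensorpack's model-building code; only tf.get_variable_scope, tf.variable_scope, tf.get_variable, and tf.device are real TF 1.x APIs.

# Sketch only: per-tower variable reuse, assuming TF 1.x graph mode.
import tensorflow as tf

def build_tower():
    # Shared weight that every tower after the first should reuse.
    return tf.get_variable('w', shape=[3, 3], initializer=tf.zeros_initializer())

towers = [0, 1]                          # hypothetical GPU ids
global_scope = tf.get_variable_scope()

for idx, t in enumerate(towers):
    # reuse=idx > 0: tower 0 creates the variables, later towers reuse them.
    # The reuse flag applies only inside this `with` block, so code that runs
    # after the loop can still create fresh variables.
    with tf.device('/gpu:{}'.format(t)), \
            tf.variable_scope(global_scope, reuse=idx > 0):
        build_tower()

# The removed pattern instead called, inside tower 0:
#     tf.get_variable_scope().reuse_variables()
# which flips reuse to True for the remainder of the scope, so any later
# tf.get_variable() for a not-yet-created variable fails -- hence the commit
# message "avoid setting reuse completely to True".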