Commit ea0342e5 authored by Yuxin Wu

Let replicated trainer sync untrainable variables as well.

parent 6d7276b8
@@ -298,24 +298,32 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
# Adopt from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py # Adopt from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
@staticmethod @staticmethod
def get_post_init_ops(): def get_post_init_ops():
# Copy initialized values for variables on GPU 0 to other GPUs. """
all_vars = tf.trainable_variables() Copy values of variables on GPU 0 to other GPUs.
all_vars.extend(tf.model_variables()) """
# literally all variables, because it's better to sync optimizer-internal variables as well
all_vars = tf.global_variables() + tf.local_variables()
var_by_name = dict([(v.name, v) for v in all_vars]) var_by_name = dict([(v.name, v) for v in all_vars])
post_init_ops = [] post_init_ops = []
for v in all_vars: for v in all_vars:
split_name = v.name.split('/')
if not v.name.startswith('tower'): if not v.name.startswith('tower'):
continue continue
if v.name.startswith('tower0'): if v.name.startswith('tower0'):
logger.warn("[SyncMultiGPUTrainerReplicated] variable "
"{} has prefix 'tower0', this is unexpected.".format(v.name))
continue # TODO some vars (EMA) may still startswith tower0 continue # TODO some vars (EMA) may still startswith tower0
# in this trainer, the master name doesn't have the towerx/ prefix # in this trainer, the master name doesn't have the towerx/ prefix
split_name = split_name[1:] split_name = v.name.split('/')
copy_from = var_by_name['/'.join(split_name)] prefix = split_name[0]
realname = '/'.join(split_name[1:])
if prefix in realname:
logger.error("[SyncMultiGPUTrainerReplicated] variable "
"{} has its prefix {} appears multiple times in its name!".format(v.name, prefix))
copy_from = var_by_name[realname]
post_init_ops.append(v.assign(copy_from.read_value())) post_init_ops.append(v.assign(copy_from.read_value()))
logger.info( logger.info(
"'sync_variables_from_tower0' includes {} operations.".format(len(post_init_ops))) "'sync_variables_from_main_tower' includes {} operations.".format(len(post_init_ops)))
return tf.group(*post_init_ops, name='sync_variables_from_tower0') return tf.group(*post_init_ops, name='sync_variables_from_main_tower')
class AsyncMultiGPUTrainer(MultiGPUTrainerBase): class AsyncMultiGPUTrainer(MultiGPUTrainerBase):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment