Commit a2da829b authored by Yuxin Wu's avatar Yuxin Wu

small change in regularizer log

parent e220b436
......@@ -51,14 +51,15 @@ def regularize_cost(regex, func, name='regularize_cost'):
G = tf.get_default_graph()
to_regularize = []
with tf.name_scope('regularize_cost'):
costs = []
for p in params:
para_name = p.name
para_name = p.op.name
if re.search(regex, para_name):
with G.colocate_with(p):
costs.append(func(p))
to_regularize.append(para_name)
to_regularize.append(p.name)
if not costs:
return tf.constant(0, dtype=tf.float32, name='empty_' + name)
......
......@@ -323,7 +323,8 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
if prefix in realname:
logger.error("[SyncMultiGPUTrainerReplicated] variable "
"{} has its prefix {} appears multiple times in its name!".format(v.name, prefix))
copy_from = var_by_name[realname]
copy_from = var_by_name.get(realname)
assert copy_from is not None, var_by_name.keys()
post_init_ops.append(v.assign(copy_from.read_value()))
logger.info(
"'sync_variables_from_main_tower' includes {} operations.".format(len(post_init_ops)))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment