Commit a2da829b authored by Yuxin Wu's avatar Yuxin Wu

small change in regularizer log

parent e220b436
...@@ -51,14 +51,15 @@ def regularize_cost(regex, func, name='regularize_cost'): ...@@ -51,14 +51,15 @@ def regularize_cost(regex, func, name='regularize_cost'):
G = tf.get_default_graph() G = tf.get_default_graph()
to_regularize = [] to_regularize = []
with tf.name_scope('regularize_cost'): with tf.name_scope('regularize_cost'):
costs = [] costs = []
for p in params: for p in params:
para_name = p.name para_name = p.op.name
if re.search(regex, para_name): if re.search(regex, para_name):
with G.colocate_with(p): with G.colocate_with(p):
costs.append(func(p)) costs.append(func(p))
to_regularize.append(para_name) to_regularize.append(p.name)
if not costs: if not costs:
return tf.constant(0, dtype=tf.float32, name='empty_' + name) return tf.constant(0, dtype=tf.float32, name='empty_' + name)
......
...@@ -323,7 +323,8 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase): ...@@ -323,7 +323,8 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
if prefix in realname: if prefix in realname:
logger.error("[SyncMultiGPUTrainerReplicated] variable " logger.error("[SyncMultiGPUTrainerReplicated] variable "
"{} has its prefix {} appears multiple times in its name!".format(v.name, prefix)) "{} has its prefix {} appears multiple times in its name!".format(v.name, prefix))
copy_from = var_by_name[realname] copy_from = var_by_name.get(realname)
assert copy_from is not None, var_by_name.keys()
post_init_ops.append(v.assign(copy_from.read_value())) post_init_ops.append(v.assign(copy_from.read_value()))
logger.info( logger.info(
"'sync_variables_from_main_tower' includes {} operations.".format(len(post_init_ops))) "'sync_variables_from_main_tower' includes {} operations.".format(len(post_init_ops)))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment