Commit f67a1aff authored by Yuxin Wu's avatar Yuxin Wu

improve logging and fix wd computation in inference

parent 12846f57
...@@ -123,3 +123,6 @@ class LeastLoadedDeviceSetter(object): ...@@ -123,3 +123,6 @@ class LeastLoadedDeviceSetter(object):
self.ps_sizes[device_index] += var_size self.ps_sizes[device_index] += var_size
return sanitize_name(device_name) return sanitize_name(device_name)
def __str__(self):
return "LeastLoadedDeviceSetter-{}".format(self.worker_device)
...@@ -42,6 +42,11 @@ def regularize_cost(regex, func, name='regularize_cost'): ...@@ -42,6 +42,11 @@ def regularize_cost(regex, func, name='regularize_cost'):
cost = cost + regularize_cost("fc.*/W", l2_regularizer(1e-5)) cost = cost + regularize_cost("fc.*/W", l2_regularizer(1e-5))
""" """
ctx = get_current_tower_context() ctx = get_current_tower_context()
if not ctx.is_training:
# Currently cannot build the wd_cost correctly at inference,
# because ths vs_name used in inference can be '', therefore the
# variable filter will fail
return tf.constant(0, dtype=tf.float32, name='empty_' + name)
params = tf.trainable_variables() params = tf.trainable_variables()
# If vars are shared, use all of them # If vars are shared, use all of them
...@@ -89,6 +94,12 @@ def regularize_cost_from_collection(name='regularize_cost'): ...@@ -89,6 +94,12 @@ def regularize_cost_from_collection(name='regularize_cost'):
""" """
regularization_losses = set(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) regularization_losses = set(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
ctx = get_current_tower_context() ctx = get_current_tower_context()
if not ctx.is_training:
# Currently cannot build the wd_cost correctly at inference,
# because ths vs_name used in inference can be '', therefore the
# variable filter will fail
return None
if len(regularization_losses) > 0: if len(regularization_losses) > 0:
# NOTE: this collection doesn't grow with towers. # NOTE: this collection doesn't grow with towers.
# It is only added with variables that are newly created. # It is only added with variables that are newly created.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment