Commit 610ffe3f authored by Yuxin Wu's avatar Yuxin Wu

freeze_variable: don't add to collection if not originally trainable

parent 9f4154e8
......@@ -110,7 +110,9 @@ def freeze_variables(stop_gradient=True, skip_collection=False):
if skip_collection:
kwargs['trainable'] = False
v = getter(*args, **kwargs)
if skip_collection:
# do not perform unnecessary changes if it's not originally trainable
# otherwise the variable may get added to MODEL_VARIABLES twice
if trainable and skip_collection:
tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
if trainable and stop_gradient:
v = tf.stop_gradient(v, name='freezed_' + name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment