Commit 108c9557 authored by Yuxin Wu's avatar Yuxin Wu

fix potential problems with varreplace.freeze_variables

parent ae2bd873
......@@ -52,8 +52,9 @@ def freeze_variables():
x = FullyConnected('fc', x, 1000) # fc/* will not be trained
"""
def custom_getter(getter, *args, **kwargs):
trainable = kwargs.get('trainable', True)
v = getter(*args, **kwargs)
if kwargs.pop('trainable', True):
if trainable:
v = tf.stop_gradient(v)
return v
return custom_getter_scope(custom_getter)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment