Commit 108c9557 authored by Yuxin Wu's avatar Yuxin Wu

fix potential problems with varreplace.freeze_variables

parent ae2bd873
...@@ -52,8 +52,9 @@ def freeze_variables(): ...@@ -52,8 +52,9 @@ def freeze_variables():
x = FullyConnected('fc', x, 1000) # fc/* will not be trained x = FullyConnected('fc', x, 1000) # fc/* will not be trained
""" """
def custom_getter(getter, *args, **kwargs): def custom_getter(getter, *args, **kwargs):
trainable = kwargs.get('trainable', True)
v = getter(*args, **kwargs) v = getter(*args, **kwargs)
if kwargs.pop('trainable', True): if trainable:
v = tf.stop_gradient(v) v = tf.stop_gradient(v)
return v return v
return custom_getter_scope(custom_getter) return custom_getter_scope(custom_getter)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment