Commit 2a1af832 authored by Yuxin Wu

only freeze trainable variables (fix #351)

parent 2d5b44e6
@@ -51,7 +51,7 @@ def remap_variables(fn):
 def freeze_variables():
     """
-    Return a context, where all variables (reused or not) returned by
+    Return a context, where all trainable variables (reused or not) returned by
     ``get_variable`` will have no gradients (they will be wrapped by ``tf.stop_gradient``).
     But they will still be in ``TRAINABLE_VARIABLES`` collections so they will get
     saved correctly. This is useful to fix certain variables for fine-tuning.
@@ -62,7 +62,12 @@ def freeze_variables():
         with varreplace.freeze_variable():
             x = FullyConnected('fc', x, 1000)  # fc/* will not be trained
     """
-    return remap_variables(lambda v: tf.stop_gradient(v))
+    def custom_getter(getter, *args, **kwargs):
+        v = getter(*args, **kwargs)
+        if kwargs.pop('trainable', True):
+            v = tf.stop_gradient(v)
+        return v
+    return custom_getter_scope(custom_getter)
 
 
 @deprecated("Renamed to remap_variables", "2017-11-06")
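For context on why the fix works: ``tf.get_variable`` passes its keyword arguments (including ``trainable``) through to the custom getter, so the getter can apply ``tf.stop_gradient`` only to variables created with ``trainable=True`` and leave non-trainable ones (e.g. BatchNorm moving averages, the case behind #351) untouched. Below is a minimal, self-contained sketch of the same idea using plain TF 1.x APIs rather than tensorpack's ``custom_getter_scope``; the getter, scope, and variable names are hypothetical, not from this commit.

import tensorflow as tf

def freeze_trainable_getter(getter, *args, **kwargs):
    # Create the variable as usual, then block gradients only if it is trainable.
    # The variable itself is unchanged, so it still appears in TRAINABLE_VARIABLES
    # and gets saved normally.
    v = getter(*args, **kwargs)
    if kwargs.get('trainable', True):
        v = tf.stop_gradient(v)
    return v

with tf.variable_scope('net', custom_getter=freeze_trainable_getter):
    w = tf.get_variable('w', shape=[3, 3])                       # trainable: gradients blocked
    mean = tf.get_variable('mean', shape=[3], trainable=False)   # non-trainable: returned as-is

With the old implementation (``remap_variables(lambda v: tf.stop_gradient(v))``), the non-trainable ``mean`` above would also have been wrapped in ``tf.stop_gradient``; the new getter skips it.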