Commit cc2f9c12 authored by Yuxin Wu's avatar Yuxin Wu

fix typo in freeze_variables

parent a77cc508
...@@ -52,14 +52,14 @@ def remap_variables(fn): ...@@ -52,14 +52,14 @@ def remap_variables(fn):
def freeze_variables(): def freeze_variables():
""" """
Return a context, where all variables (reused or not) returned by Return a context, where all variables (reused or not) returned by
``get_variable`` will have no gradients (they will be followed by ``tf.stop_gradient``). ``get_variable`` will have no gradients (they will be wrapped by ``tf.stop_gradient``).
But they will still be in ``TRAINABLE_VARIABLES`` collections so they will get But they will still be in ``TRAINABLE_VARIABLES`` collections so they will get
saved correctly. This is useful to fix certain variables for fine-tuning. saved correctly. This is useful to fix certain variables for fine-tuning.
Example: Example:
.. code-block:: python .. code-block:: python
with varreplace.freeze_get_variable(): with varreplace.freeze_variable():
x = FullyConnected('fc', x, 1000) # fc/* will not be trained x = FullyConnected('fc', x, 1000) # fc/* will not be trained
""" """
return remap_variables(lambda v: tf.stop_gradient(v)) return remap_variables(lambda v: tf.stop_gradient(v))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment