Commit 5adf1f73 authored by Yuxin Wu's avatar Yuxin Wu

freeze get_variable

parent b26a945e
# tensorpack # tensorpack
Neural Network Toolbox on TensorFlow Neural Network Toolbox on TensorFlow
See some [examples](examples) to learn about the framework: Docs & tutorials should be ready within a month. See some [examples](examples) to learn about the framework:
### Vision: ### Vision:
+ [DoReFa-Net: train binary / low-bitwidth CNN on ImageNet](examples/DoReFa-Net) + [DoReFa-Net: train binary / low-bitwidth CNN on ImageNet](examples/DoReFa-Net)
......
...@@ -7,7 +7,9 @@ import tensorflow as tf ...@@ -7,7 +7,9 @@ import tensorflow as tf
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from contextlib import contextmanager from contextlib import contextmanager
__all__ = ['replace_get_variable'] __all__ = ['replace_get_variable', 'freeze_get_variable']
_ORIG_GET_VARIABLE = tf.get_variable
@contextmanager @contextmanager
...@@ -20,3 +22,16 @@ def replace_get_variable(fn): ...@@ -20,3 +22,16 @@ def replace_get_variable(fn):
yield yield
tf.get_variable = old_getv tf.get_variable = old_getv
variable_scope.get_variable = old_vars_getv variable_scope.get_variable = old_vars_getv
def freeze_get_variable():
"""
Return a contextmanager, where all variables returned by
`get_variable` will have no gradients.
"""
old_get_variable = tf.get_variable
def fn(name, shape=None, **kwargs):
v = old_get_variable(name, shape, **kwargs)
return tf.stop_gradient(v)
return replace_get_variable(fn)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment