Commit 0774ec66 authored by Yuxin Wu's avatar Yuxin Wu

use custom_getter for varreplace, instead of hacks

parent 17126868
...@@ -57,7 +57,7 @@ class Model(ModelDesc): ...@@ -57,7 +57,7 @@ class Model(ModelDesc):
old_get_variable = tf.get_variable old_get_variable = tf.get_variable
# monkey-patch tf.get_variable to apply fw # monkey-patch tf.get_variable to apply fw
def new_get_variable(v): def binarize_weight(v):
name = v.op.name name = v.op.name
# don't binarize first and last layer # don't binarize first and last layer
if not name.endswith('W') or 'conv0' in name or 'fc' in name: if not name.endswith('W') or 'conv0' in name or 'fc' in name:
...@@ -74,7 +74,7 @@ class Model(ModelDesc): ...@@ -74,7 +74,7 @@ class Model(ModelDesc):
image = image / 256.0 image = image / 256.0
with remap_get_variable(new_get_variable), \ with remap_get_variable(binarize_weight), \
argscope(BatchNorm, decay=0.9, epsilon=1e-4), \ argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
argscope(Conv2D, use_bias=False, nl=tf.identity): argscope(Conv2D, use_bias=False, nl=tf.identity):
logits = (LinearWrap(image) logits = (LinearWrap(image)
......
...@@ -4,57 +4,54 @@ ...@@ -4,57 +4,54 @@
# Credit: Qinyao He # Credit: Qinyao He
import tensorflow as tf import tensorflow as tf
from tensorflow.python.ops import variable_scope
from contextlib import contextmanager from contextlib import contextmanager
__all__ = ['replace_get_variable', 'freeze_get_variable', 'remap_get_variable'] from ..utils.develop import deprecated
_ORIG_GET_VARIABLE = tf.get_variable __all__ = ['custom_getter_scope', 'replace_get_variable',
'freeze_variables', 'freeze_get_variable', 'remap_get_variable']
@contextmanager
def custom_getter_scope(custom_getter):
    """
    Open a variable scope under which every ``tf.get_variable`` call is
    routed through ``custom_getter``.

    Args:
        custom_getter: a function following TF's custom_getter protocol,
            i.e. ``custom_getter(getter, *args, **kwargs)``; it receives the
            original getter plus the ``get_variable`` arguments.

    Returns:
        a context manager. Inside it, variables are created/fetched through
        ``custom_getter``.
    """
    # Re-enter the *current* scope with a custom_getter attached, so the
    # getter applies without changing variable names.
    scope = tf.get_variable_scope()
    with tf.variable_scope(scope, custom_getter=custom_getter):
        yield
@deprecated("Use custom_getter_scope instead.", "2017-11-06")
def replace_get_variable(fn):
    """
    Args:
        fn: a function compatible with ``tf.get_variable``.

    Returns:
        a context with a custom getter that calls ``fn`` in place of the
        original ``tf.get_variable``.
    """
    def getter(_, *args, **kwargs):
        # Discard the original getter (first positional argument of the
        # custom_getter protocol) and delegate entirely to fn.
        return fn(*args, **kwargs)
    return custom_getter_scope(getter)
def remap_get_variable(fn):
    """
    Use ``fn`` to map the output of any variable getter.

    Args:
        fn (tf.Variable -> tf.Tensor): applied to each variable returned by
            the original getter; its result is what callers of
            ``tf.get_variable`` receive.

    Returns:
        a context where all the variables will be mapped by ``fn``.
    """
    def custom_getter(getter, *args, **kwargs):
        # Fetch the variable with the original getter, then post-process it.
        v = getter(*args, **kwargs)
        return fn(v)
    return custom_getter_scope(custom_getter)
def freeze_variables():
    """
    Return a context, where all variables (reused or not) returned by
    ``get_variable`` will have no gradients (they will be followed by
    ``tf.stop_gradient``). But they will still be in the
    ``TRAINABLE_VARIABLES`` collection so they will get saved correctly.
    This is useful to fix certain variables for fine-tuning.

    Example:
        .. code-block:: python

            with varreplace.freeze_get_variable():
                x = FullyConnected('fc', x, 1000)   # fc/* will not be trained
    """
    # stop_gradient blocks backprop through the variable without removing
    # it from any collection.
    return remap_get_variable(lambda v: tf.stop_gradient(v))
@deprecated("Renamed to freeze_variables", "2017-11-06")
def freeze_get_variable():
    """Deprecated alias for :func:`freeze_variables`."""
    return freeze_variables()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment