Commit 89b0c256 authored by Yuxin Wu's avatar Yuxin Wu

also rename remap_get_variable.

parent 0774ec66
...@@ -8,6 +8,10 @@ so you won't need to look at here very often. ...@@ -8,6 +8,10 @@ so you won't need to look at here very often.
Here are a list of things that were changed, starting from an early version. Here are a list of things that were changed, starting from an early version.
TensorFlow itself also changes API and those are not listed here. TensorFlow itself also changes API and those are not listed here.
+ [2017/05/06](https://github.com/ppwwyyxx/tensorpack/commit/0774ec66e66075486f6a36aba63cc2a151b9fec8).
`replace_get_variable` was deprecated in favor of the official `custom_getter` interface.
`{freeze,remap}_get_variable` was renamed to `{freeze,remap}_variables`.
+ [2017/04/09](https://github.com/ppwwyyxx/tensorpack/commit/5beab907895aec36bdcaed62e25b976aad7979b8). + [2017/04/09](https://github.com/ppwwyyxx/tensorpack/commit/5beab907895aec36bdcaed62e25b976aad7979b8).
`ParamRestore` was renamed to `DictRestore`. `ParamRestore` was renamed to `DictRestore`.
+ [2017/03/16](https://github.com/ppwwyyxx/tensorpack/commit/ccae46f4a3ca89dc3df901a338eef8447d19a730). + [2017/03/16](https://github.com/ppwwyyxx/tensorpack/commit/ccae46f4a3ca89dc3df901a338eef8447d19a730).
......
...@@ -15,7 +15,7 @@ import sys ...@@ -15,7 +15,7 @@ import sys
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.tfutils.varreplace import remap_get_variable from tensorpack.tfutils.varreplace import remap_variables
from dorefa import get_dorefa from dorefa import get_dorefa
""" """
...@@ -106,7 +106,7 @@ class Model(ModelDesc): ...@@ -106,7 +106,7 @@ class Model(ModelDesc):
def activate(x): def activate(x):
return fa(nonlin(x)) return fa(nonlin(x))
with remap_get_variable(new_get_variable), \ with remap_variables(new_get_variable), \
argscope(BatchNorm, decay=0.9, epsilon=1e-4), \ argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
argscope([Conv2D, FullyConnected], use_bias=False, nl=tf.identity): argscope([Conv2D, FullyConnected], use_bias=False, nl=tf.identity):
logits = (LinearWrap(image) logits = (LinearWrap(image)
......
...@@ -13,7 +13,7 @@ from tensorpack import * ...@@ -13,7 +13,7 @@ from tensorpack import *
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.utils.stats import RatioCounter from tensorpack.utils.stats import RatioCounter
from tensorpack.tfutils.varreplace import remap_get_variable from tensorpack.tfutils.varreplace import remap_variables
from dorefa import get_dorefa from dorefa import get_dorefa
""" """
...@@ -90,7 +90,7 @@ class Model(ModelDesc): ...@@ -90,7 +90,7 @@ class Model(ModelDesc):
x = resblock(x, channel, 1) x = resblock(x, channel, 1)
return x return x
with remap_get_variable(new_get_variable), \ with remap_variables(new_get_variable), \
argscope(BatchNorm, decay=0.9, epsilon=1e-4), \ argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
argscope(Conv2D, use_bias=False, nl=tf.identity): argscope(Conv2D, use_bias=False, nl=tf.identity):
logits = (LinearWrap(image) logits = (LinearWrap(image)
......
...@@ -10,7 +10,7 @@ import os ...@@ -10,7 +10,7 @@ import os
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.tfutils.varreplace import remap_get_variable from tensorpack.tfutils.varreplace import remap_variables
import tensorflow as tf import tensorflow as tf
from dorefa import get_dorefa from dorefa import get_dorefa
...@@ -74,7 +74,7 @@ class Model(ModelDesc): ...@@ -74,7 +74,7 @@ class Model(ModelDesc):
image = image / 256.0 image = image / 256.0
with remap_get_variable(binarize_weight), \ with remap_variables(binarize_weight), \
argscope(BatchNorm, decay=0.9, epsilon=1e-4), \ argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
argscope(Conv2D, use_bias=False, nl=tf.identity): argscope(Conv2D, use_bias=False, nl=tf.identity):
logits = (LinearWrap(image) logits = (LinearWrap(image)
......
...@@ -9,7 +9,8 @@ from contextlib import contextmanager ...@@ -9,7 +9,8 @@ from contextlib import contextmanager
from ..utils.develop import deprecated from ..utils.develop import deprecated
__all__ = ['custom_getter_scope', 'replace_get_variable', __all__ = ['custom_getter_scope', 'replace_get_variable',
'freeze_variables', 'freeze_get_variable', 'remap_get_variable'] 'freeze_variables', 'freeze_get_variable', 'remap_get_variable',
'remap_variables']
@contextmanager @contextmanager
...@@ -32,7 +33,7 @@ def replace_get_variable(fn): ...@@ -32,7 +33,7 @@ def replace_get_variable(fn):
return custom_getter_scope(getter) return custom_getter_scope(getter)
def remap_get_variable(fn): def remap_variables(fn):
""" """
Use fn to map the output of any variable getter. Use fn to map the output of any variable getter.
...@@ -61,7 +62,12 @@ def freeze_variables(): ...@@ -61,7 +62,12 @@ def freeze_variables():
with varreplace.freeze_get_variable(): with varreplace.freeze_get_variable():
x = FullyConnected('fc', x, 1000) # fc/* will not be trained x = FullyConnected('fc', x, 1000) # fc/* will not be trained
""" """
return remap_get_variable(lambda v: tf.stop_gradient(v)) return remap_variables(lambda v: tf.stop_gradient(v))
@deprecated("Renamed to remap_variables", "2017-11-06")
def remap_get_variable():
return remap_variables()
@deprecated("Renamed to freeze_variables", "2017-11-06") @deprecated("Renamed to freeze_variables", "2017-11-06")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment