Commit 44a2c531 authored by Yuxin Wu's avatar Yuxin Wu

add remap_get_variable

parent c59987a3
...@@ -19,7 +19,8 @@ Docs & tutorials should be ready within a month. See some [examples](examples) t ...@@ -19,7 +19,8 @@ Docs & tutorials should be ready within a month. See some [examples](examples) t
+ [Asynchronous Advantage Actor-Critic(A3C) with demos on OpenAI Gym](examples/A3C-Gym) + [Asynchronous Advantage Actor-Critic(A3C) with demos on OpenAI Gym](examples/A3C-Gym)
### Unsupervised Learning: ### Unsupervised Learning:
+ [Generative Adversarial Network(GAN) variants, including DCGAN, Image2Image, InfoGAN](examples/GAN) + [Generative Adversarial Network(GAN) variants](examples/GAN), including DCGAN, InfoGAN, Conditional GAN, Image to Image.
### Speech / NLP: ### Speech / NLP:
+ [LSTM-CTC for speech recognition](examples/CTC-TIMIT) + [LSTM-CTC for speech recognition](examples/CTC-TIMIT)
......
...@@ -15,7 +15,7 @@ import sys ...@@ -15,7 +15,7 @@ import sys
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.tfutils.varreplace import replace_get_variable from tensorpack.tfutils.varreplace import remap_get_variable
from dorefa import get_dorefa from dorefa import get_dorefa
""" """
...@@ -87,10 +87,10 @@ class Model(ModelDesc): ...@@ -87,10 +87,10 @@ class Model(ModelDesc):
old_get_variable = tf.get_variable old_get_variable = tf.get_variable
# monkey-patch tf.get_variable to apply fw # monkey-patch tf.get_variable to apply fw
def new_get_variable(name, shape=None, **kwargs): def new_get_variable(v):
v = old_get_variable(name, shape, **kwargs) name = v.op.name
# don't binarize first and last layer # don't binarize first and last layer
if name != 'W' or 'conv0' in v.op.name or 'fct' in v.op.name: if not name.endswith('W') or 'conv0' in name or 'fct' in name:
return v return v
else: else:
logger.info("Binarizing weight {}".format(v.op.name)) logger.info("Binarizing weight {}".format(v.op.name))
...@@ -104,7 +104,7 @@ class Model(ModelDesc): ...@@ -104,7 +104,7 @@ class Model(ModelDesc):
def activate(x): def activate(x):
return fa(nonlin(x)) return fa(nonlin(x))
with replace_get_variable(new_get_variable), \ with remap_get_variable(new_get_variable), \
argscope(BatchNorm, decay=0.9, epsilon=1e-4), \ argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
argscope([Conv2D, FullyConnected], use_bias=False, nl=tf.identity): argscope([Conv2D, FullyConnected], use_bias=False, nl=tf.identity):
logits = (LinearWrap(image) logits = (LinearWrap(image)
......
...@@ -13,7 +13,7 @@ from tensorpack import * ...@@ -13,7 +13,7 @@ from tensorpack import *
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.utils.stats import RatioCounter from tensorpack.utils.stats import RatioCounter
from tensorpack.tfutils.varreplace import replace_get_variable from tensorpack.tfutils.varreplace import remap_get_variable
from dorefa import get_dorefa from dorefa import get_dorefa
""" """
...@@ -44,10 +44,10 @@ class Model(ModelDesc): ...@@ -44,10 +44,10 @@ class Model(ModelDesc):
fw, fa, fg = get_dorefa(BITW, BITA, BITG) fw, fa, fg = get_dorefa(BITW, BITA, BITG)
old_get_variable = tf.get_variable old_get_variable = tf.get_variable
def new_get_variable(name, shape=None, **kwargs): def new_get_variable(v):
v = old_get_variable(name, shape, **kwargs) name = v.op.name
# don't binarize first and last layer # don't binarize first and last layer
if name != 'W' or 'conv1' in v.op.name or 'fct' in v.op.name: if not name.endswith('W') or 'conv1' in name or 'fct' in name:
return v return v
else: else:
logger.info("Binarizing weight {}".format(v.op.name)) logger.info("Binarizing weight {}".format(v.op.name))
...@@ -90,7 +90,7 @@ class Model(ModelDesc): ...@@ -90,7 +90,7 @@ class Model(ModelDesc):
x = resblock(x, channel, 1) x = resblock(x, channel, 1)
return x return x
with replace_get_variable(new_get_variable), \ with remap_get_variable(new_get_variable), \
argscope(BatchNorm, decay=0.9, epsilon=1e-4), \ argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
argscope(Conv2D, use_bias=False, nl=tf.identity): argscope(Conv2D, use_bias=False, nl=tf.identity):
logits = (LinearWrap(image) logits = (LinearWrap(image)
......
...@@ -11,7 +11,7 @@ import os ...@@ -11,7 +11,7 @@ import os
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.tfutils.varreplace import replace_get_variable from tensorpack.tfutils.varreplace import remap_get_variable
from dorefa import get_dorefa from dorefa import get_dorefa
""" """
...@@ -56,10 +56,10 @@ class Model(ModelDesc): ...@@ -56,10 +56,10 @@ class Model(ModelDesc):
old_get_variable = tf.get_variable old_get_variable = tf.get_variable
# monkey-patch tf.get_variable to apply fw # monkey-patch tf.get_variable to apply fw
def new_get_variable(name, shape=None, **kwargs): def new_get_variable(v):
v = old_get_variable(name, shape, **kwargs) name = v.op.name
# don't binarize first and last layer # don't binarize first and last layer
if name != 'W' or 'conv0' in v.op.name or 'fc' in v.op.name: if not name.endswith('W') or 'conv0' in name or 'fc' in name:
return v return v
else: else:
logger.info("Binarizing weight {}".format(v.op.name)) logger.info("Binarizing weight {}".format(v.op.name))
...@@ -73,7 +73,7 @@ class Model(ModelDesc): ...@@ -73,7 +73,7 @@ class Model(ModelDesc):
image = image / 256.0 image = image / 256.0
with replace_get_variable(new_get_variable), \ with remap_get_variable(new_get_variable), \
argscope(BatchNorm, decay=0.9, epsilon=1e-4), \ argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
argscope(Conv2D, use_bias=False, nl=tf.identity): argscope(Conv2D, use_bias=False, nl=tf.identity):
logits = (LinearWrap(image) logits = (LinearWrap(image)
......
...@@ -15,7 +15,6 @@ import cv2 ...@@ -15,7 +15,6 @@ import cv2
from tensorpack import * from tensorpack import *
from tensorpack.utils.viz import * from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.globvars import globalns as CFG, use_global_argument
import tensorpack.tfutils.symbolic_functions as symbf import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, RandomZData, GANModelDesc from GAN import GANTrainer, RandomZData, GANModelDesc
...@@ -30,14 +29,14 @@ The original code (dcgan.torch) uses kernel_shape=4, but I found the difference ...@@ -30,14 +29,14 @@ The original code (dcgan.torch) uses kernel_shape=4, but I found the difference
./DCGAN-CelebA.py --load path/to/model --sample ./DCGAN-CelebA.py --load path/to/model --sample
""" """
CFG.SHAPE = 64 SHAPE = 64
CFG.BATCH = 128 BATCH = 128
CFG.Z_DIM = 100 Z_DIM = 100
class Model(GANModelDesc): class Model(GANModelDesc):
def _get_inputs(self): def _get_inputs(self):
return [InputVar(tf.float32, (None, CFG.SHAPE, CFG.SHAPE, 3), 'input')] return [InputVar(tf.float32, (None, SHAPE, SHAPE, 3), 'input')]
def generator(self, z): def generator(self, z):
""" return a image generated from z""" """ return a image generated from z"""
...@@ -73,8 +72,8 @@ class Model(GANModelDesc): ...@@ -73,8 +72,8 @@ class Model(GANModelDesc):
image_pos = inputs[0] image_pos = inputs[0]
image_pos = image_pos / 128.0 - 1 image_pos = image_pos / 128.0 - 1
z = tf.random_uniform([CFG.BATCH, CFG.Z_DIM], -1, 1, name='z_train') z = tf.random_uniform([BATCH, Z_DIM], -1, 1, name='z_train')
z = tf.placeholder_with_default(z, [None, CFG.Z_DIM], name='z') z = tf.placeholder_with_default(z, [None, Z_DIM], name='z')
with argscope([Conv2D, Deconv2D, FullyConnected], with argscope([Conv2D, Deconv2D, FullyConnected],
W_init=tf.truncated_normal_initializer(stddev=0.02)): W_init=tf.truncated_normal_initializer(stddev=0.02)):
...@@ -91,12 +90,13 @@ class Model(GANModelDesc): ...@@ -91,12 +90,13 @@ class Model(GANModelDesc):
def get_data(): def get_data():
datadir = CFG.data global args
datadir = args.data
imgs = glob.glob(datadir + '/*.jpg') imgs = glob.glob(datadir + '/*.jpg')
ds = ImageFromFile(imgs, channel=3, shuffle=True) ds = ImageFromFile(imgs, channel=3, shuffle=True)
augs = [imgaug.CenterCrop(140), imgaug.Resize(64)] augs = [imgaug.CenterCrop(140), imgaug.Resize(64)]
ds = AugmentImageComponent(ds, augs) ds = AugmentImageComponent(ds, augs)
ds = BatchData(ds, CFG.BATCH) ds = BatchData(ds, BATCH)
ds = PrefetchDataZMQ(ds, 1) ds = PrefetchDataZMQ(ds, 1)
return ds return ds
...@@ -137,7 +137,6 @@ if __name__ == '__main__': ...@@ -137,7 +137,6 @@ if __name__ == '__main__':
parser.add_argument('--sample', action='store_true', help='run sampling') parser.add_argument('--sample', action='store_true', help='run sampling')
parser.add_argument('--data', help='`image_align_celeba` directory of the celebA dataset') parser.add_argument('--data', help='`image_align_celeba` directory of the celebA dataset')
args = parser.parse_args() args = parser.parse_args()
use_global_argument(args)
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.sample: if args.sample:
......
...@@ -7,7 +7,7 @@ import tensorflow as tf ...@@ -7,7 +7,7 @@ import tensorflow as tf
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from contextlib import contextmanager from contextlib import contextmanager
__all__ = ['replace_get_variable', 'freeze_get_variable'] __all__ = ['replace_get_variable', 'freeze_get_variable', 'remap_get_variable']
_ORIG_GET_VARIABLE = tf.get_variable _ORIG_GET_VARIABLE = tf.get_variable
...@@ -16,7 +16,7 @@ _ORIG_GET_VARIABLE = tf.get_variable ...@@ -16,7 +16,7 @@ _ORIG_GET_VARIABLE = tf.get_variable
def replace_get_variable(fn): def replace_get_variable(fn):
""" """
Args: Args:
fn: a function taking the same arguments as ``tf.get_variable``. fn: a function compatible with ``tf.get_variable``.
Returns: Returns:
a context where ``tf.get_variable`` and a context where ``tf.get_variable`` and
``variable_scope.get_variable`` are replaced with ``fn``. ``variable_scope.get_variable`` are replaced with ``fn``.
...@@ -36,6 +36,19 @@ def replace_get_variable(fn): ...@@ -36,6 +36,19 @@ def replace_get_variable(fn):
variable_scope.get_variable = old_vars_getv variable_scope.get_variable = old_vars_getv
def remap_get_variable(fn):
    """
    Similar to :func:`replace_get_variable`, but the function ``fn``
    takes the variable returned by the original ``tf.get_variable`` call
    and returns a tensor.

    Args:
        fn: a function taking the result of the original
            ``tf.get_variable`` call as its only argument. Its return
            value is what callers of ``tf.get_variable`` receive inside
            the context.
    Returns:
        the context returned by :func:`replace_get_variable` for the
        wrapping function built here.
    """
    # The getter is captured when remap_get_variable() is called, not when
    # the returned context is entered.
    old_getv = tf.get_variable

    def new_get_variable(*args, **kwargs):
        # Forward every argument untouched, so callers may pass dtype,
        # initializer, etc. positionally, exactly as tf.get_variable allows.
        v = old_getv(*args, **kwargs)
        return fn(v)
    return replace_get_variable(new_get_variable)
def freeze_get_variable(): def freeze_get_variable():
""" """
Return a context, where all variables (reused or not) returned by Return a context, where all variables (reused or not) returned by
...@@ -49,9 +62,5 @@ def freeze_get_variable(): ...@@ -49,9 +62,5 @@ def freeze_get_variable():
with varreplace.freeze_get_variable(): with varreplace.freeze_get_variable():
x = FullyConnected('fc', x, 1000) # fc/* will not be trained x = FullyConnected('fc', x, 1000) # fc/* will not be trained
""" """
old_get_variable = tf.get_variable return remap_get_variable(
lambda v: tf.stop_gradient(v))
def fn(name, shape=None, **kwargs):
v = old_get_variable(name, shape, **kwargs)
return tf.stop_gradient(v)
return replace_get_variable(fn)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment