Commit 44a2c531 authored by Yuxin Wu

add remap_get_variable

parent c59987a3
......@@ -19,7 +19,8 @@ Docs & tutorials should be ready within a month. See some [examples](examples) t
+ [Asynchronous Advantage Actor-Critic(A3C) with demos on OpenAI Gym](examples/A3C-Gym)
### Unsupervised Learning:
+ [Generative Adversarial Network(GAN) variants, including DCGAN, Image2Image, InfoGAN](examples/GAN)
+ [Generative Adversarial Network(GAN) variants](examples/GAN), including DCGAN, InfoGAN, Conditional GAN, Image to Image.
### Speech / NLP:
+ [LSTM-CTC for speech recognition](examples/CTC-TIMIT)
......
......@@ -15,7 +15,7 @@ import sys
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.tfutils.varreplace import replace_get_variable
from tensorpack.tfutils.varreplace import remap_get_variable
from dorefa import get_dorefa
"""
......@@ -87,10 +87,10 @@ class Model(ModelDesc):
old_get_variable = tf.get_variable
# monkey-patch tf.get_variable to apply fw
def new_get_variable(name, shape=None, **kwargs):
v = old_get_variable(name, shape, **kwargs)
def new_get_variable(v):
name = v.op.name
# don't binarize first and last layer
if name != 'W' or 'conv0' in v.op.name or 'fct' in v.op.name:
if not name.endswith('W') or 'conv0' in name or 'fct' in name:
return v
else:
logger.info("Binarizing weight {}".format(v.op.name))
......@@ -104,7 +104,7 @@ class Model(ModelDesc):
def activate(x):
return fa(nonlin(x))
with replace_get_variable(new_get_variable), \
with remap_get_variable(new_get_variable), \
argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
argscope([Conv2D, FullyConnected], use_bias=False, nl=tf.identity):
logits = (LinearWrap(image)
......
......@@ -13,7 +13,7 @@ from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.utils.stats import RatioCounter
from tensorpack.tfutils.varreplace import replace_get_variable
from tensorpack.tfutils.varreplace import remap_get_variable
from dorefa import get_dorefa
"""
......@@ -44,10 +44,10 @@ class Model(ModelDesc):
fw, fa, fg = get_dorefa(BITW, BITA, BITG)
old_get_variable = tf.get_variable
def new_get_variable(name, shape=None, **kwargs):
v = old_get_variable(name, shape, **kwargs)
def new_get_variable(v):
name = v.op.name
# don't binarize first and last layer
if name != 'W' or 'conv1' in v.op.name or 'fct' in v.op.name:
if not name.endswith('W') or 'conv1' in name or 'fct' in name:
return v
else:
logger.info("Binarizing weight {}".format(v.op.name))
......@@ -90,7 +90,7 @@ class Model(ModelDesc):
x = resblock(x, channel, 1)
return x
with replace_get_variable(new_get_variable), \
with remap_get_variable(new_get_variable), \
argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
argscope(Conv2D, use_bias=False, nl=tf.identity):
logits = (LinearWrap(image)
......
......@@ -11,7 +11,7 @@ import os
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.tfutils.varreplace import replace_get_variable
from tensorpack.tfutils.varreplace import remap_get_variable
from dorefa import get_dorefa
"""
......@@ -56,10 +56,10 @@ class Model(ModelDesc):
old_get_variable = tf.get_variable
# monkey-patch tf.get_variable to apply fw
def new_get_variable(name, shape=None, **kwargs):
v = old_get_variable(name, shape, **kwargs)
def new_get_variable(v):
name = v.op.name
# don't binarize first and last layer
if name != 'W' or 'conv0' in v.op.name or 'fc' in v.op.name:
if not name.endswith('W') or 'conv0' in name or 'fc' in name:
return v
else:
logger.info("Binarizing weight {}".format(v.op.name))
......@@ -73,7 +73,7 @@ class Model(ModelDesc):
image = image / 256.0
with replace_get_variable(new_get_variable), \
with remap_get_variable(new_get_variable), \
argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
argscope(Conv2D, use_bias=False, nl=tf.identity):
logits = (LinearWrap(image)
......
......@@ -15,7 +15,6 @@ import cv2
from tensorpack import *
from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.globvars import globalns as CFG, use_global_argument
import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, RandomZData, GANModelDesc
......@@ -30,14 +29,14 @@ The original code (dcgan.torch) uses kernel_shape=4, but I found the difference
./DCGAN-CelebA.py --load path/to/model --sample
"""
CFG.SHAPE = 64
CFG.BATCH = 128
CFG.Z_DIM = 100
SHAPE = 64
BATCH = 128
Z_DIM = 100
class Model(GANModelDesc):
def _get_inputs(self):
return [InputVar(tf.float32, (None, CFG.SHAPE, CFG.SHAPE, 3), 'input')]
return [InputVar(tf.float32, (None, SHAPE, SHAPE, 3), 'input')]
def generator(self, z):
""" return a image generated from z"""
......@@ -73,8 +72,8 @@ class Model(GANModelDesc):
image_pos = inputs[0]
image_pos = image_pos / 128.0 - 1
z = tf.random_uniform([CFG.BATCH, CFG.Z_DIM], -1, 1, name='z_train')
z = tf.placeholder_with_default(z, [None, CFG.Z_DIM], name='z')
z = tf.random_uniform([BATCH, Z_DIM], -1, 1, name='z_train')
z = tf.placeholder_with_default(z, [None, Z_DIM], name='z')
with argscope([Conv2D, Deconv2D, FullyConnected],
W_init=tf.truncated_normal_initializer(stddev=0.02)):
......@@ -91,12 +90,13 @@ class Model(GANModelDesc):
def get_data():
datadir = CFG.data
global args
datadir = args.data
imgs = glob.glob(datadir + '/*.jpg')
ds = ImageFromFile(imgs, channel=3, shuffle=True)
augs = [imgaug.CenterCrop(140), imgaug.Resize(64)]
ds = AugmentImageComponent(ds, augs)
ds = BatchData(ds, CFG.BATCH)
ds = BatchData(ds, BATCH)
ds = PrefetchDataZMQ(ds, 1)
return ds
......@@ -137,7 +137,6 @@ if __name__ == '__main__':
parser.add_argument('--sample', action='store_true', help='run sampling')
parser.add_argument('--data', help='`image_align_celeba` directory of the celebA dataset')
args = parser.parse_args()
use_global_argument(args)
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.sample:
......
......@@ -7,7 +7,7 @@ import tensorflow as tf
from tensorflow.python.ops import variable_scope
from contextlib import contextmanager
__all__ = ['replace_get_variable', 'freeze_get_variable']
__all__ = ['replace_get_variable', 'freeze_get_variable', 'remap_get_variable']
_ORIG_GET_VARIABLE = tf.get_variable
......@@ -16,7 +16,7 @@ _ORIG_GET_VARIABLE = tf.get_variable
def replace_get_variable(fn):
"""
Args:
fn: a function taking the same arguments as ``tf.get_variable``.
fn: a function compatible with ``tf.get_variable``.
Returns:
a context where ``tf.get_variable`` and
``variable_scope.get_variable`` are replaced with ``fn``.
......@@ -36,6 +36,19 @@ def replace_get_variable(fn):
variable_scope.get_variable = old_vars_getv
def remap_get_variable(fn):
    """
    Similar to :func:`replace_get_variable`, but ``fn`` post-processes the
    result of the original ``tf.get_variable`` call instead of replacing
    the call itself.

    Args:
        fn: a function that takes the variable returned by the original
            ``tf.get_variable`` call and returns a tensor.

    Returns:
        the context produced by :func:`replace_get_variable`, in which
        every ``tf.get_variable`` call yields ``fn(variable)``.
    """
    # Look up ``tf.get_variable`` when this function is called, so the
    # wrapper delegates to whatever getter is installed at that moment.
    old_getv = tf.get_variable

    def new_get_variable(*args, **kwargs):
        # Forward all arguments untouched: a fixed ``(name, shape=None,
        # **kwargs)`` signature would reject callers that pass further
        # positional arguments (e.g. ``dtype``) to ``tf.get_variable``.
        return fn(old_getv(*args, **kwargs))

    return replace_get_variable(new_get_variable)
def freeze_get_variable():
"""
Return a context, where all variables (reused or not) returned by
......@@ -49,9 +62,5 @@ def freeze_get_variable():
with varreplace.freeze_get_variable():
x = FullyConnected('fc', x, 1000) # fc/* will not be trained
"""
old_get_variable = tf.get_variable
def fn(name, shape=None, **kwargs):
v = old_get_variable(name, shape, **kwargs)
return tf.stop_gradient(v)
return replace_get_variable(fn)
return remap_get_variable(
lambda v: tf.stop_gradient(v))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment