Commit fbb73a8a authored by Yuxin Wu

argscope

parent 23f2ccd6
@@ -43,15 +43,18 @@ class Model(ModelDesc):
             tf.image_summary("train_image", image, 10)
         image = image / 4.0     # just to make range smaller

-        l = Conv2D('conv1.1', image, out_channel=64, kernel_shape=3)
+        l = Conv2D('conv1.1', image, out_channel=64, kernel_shape=3,
+                   nl=BNReLU(is_training), use_bias=False)
         l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=BNReLU(is_training), use_bias=False)
         l = MaxPooling('pool1', l, 3, stride=2, padding='SAME')

-        l = Conv2D('conv2.1', l, out_channel=128, kernel_shape=3)
+        l = Conv2D('conv2.1', l, out_channel=128, kernel_shape=3,
+                   nl=BNReLU(is_training), use_bias=False)
         l = Conv2D('conv2.2', l, out_channel=128, kernel_shape=3, nl=BNReLU(is_training), use_bias=False)
         l = MaxPooling('pool2', l, 3, stride=2, padding='SAME')

-        l = Conv2D('conv3.1', l, out_channel=128, kernel_shape=3, padding='VALID')
+        l = Conv2D('conv3.1', l, out_channel=128, kernel_shape=3,
+                   padding='VALID', nl=BNReLU(is_training), use_bias=False)
         l = Conv2D('conv3.2', l, out_channel=128, kernel_shape=3, padding='VALID', nl=BNReLU(is_training), use_bias=False)
         l = FullyConnected('fc0', l, 1024 + 512,
                            b_init=tf.constant_initializer(0.1))
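(For context only: the commit below also introduces an argscope context manager. A minimal sketch, not taken from this diff, of how it could factor out the keyword arguments repeated in the hunk above; Conv2D, BNReLU, image and is_training are used as in the example:)

with argscope(Conv2D, kernel_shape=3, nl=BNReLU(is_training), use_bias=False):
    l = Conv2D('conv1.1', image, out_channel=64)    # picks up kernel_shape, nl, use_bias
    l = Conv2D('conv1.2', l, out_channel=64)
    l = MaxPooling('pool1', l, 3, stride=2, padding='SAME')   # layers not listed in the scope are unaffected
    # per-call keyword arguments can still be added, and they override the scope defaults:
    l = Conv2D('conv3.1', l, out_channel=128, padding='VALID')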
@@ -80,7 +83,7 @@ class Model(ModelDesc):
                               name='regularize_loss')
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)

-        add_param_summary([('.*/W', ['histogram', 'sparsity'])])   # monitor W
+        add_param_summary([('.*/W', ['histogram'])])   # monitor W
         return tf.add_n([cost, wd_cost], name='cost')

 def get_data(train_or_test):
@@ -123,7 +126,7 @@ def get_config():
     lr = tf.train.exponential_decay(
         learning_rate=1e-2,
         global_step=get_global_step_var(),
-        decay_steps=dataset_train.size() * 30 if nr_gpu == 1 else 20,
+        decay_steps=step_per_epoch * 30 if nr_gpu == 1 else 20,
         decay_rate=0.5, staircase=True, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
@@ -138,7 +141,7 @@ def get_config():
         session_config=sess_config,
         model=Model(),
         step_per_epoch=step_per_epoch,
-        max_epoch=500,
+        max_epoch=200,
     )

 if __name__ == '__main__':
@@ -6,6 +6,7 @@ import tensorflow as tf
 from functools import wraps
 import six

+from ..tfutils import *
 from ..tfutils.modelutils import *
 from ..tfutils.summary import *
 from ..utils import logger
@@ -13,14 +14,13 @@ from ..utils import logger
 # make sure each layer is only logged once
 _layer_logged = set()

-def layer_register(summary_activation=False):
+def layer_register(summary_activation=False, log_shape=True):
     """
     Register a layer.
-    Args:
-        summary_activation:
-            Define the default behavior of whether to
-            summary the output(activation) of this layer.
-            Can be overriden when creating the layer.
+    :param summary_activation: define the default behavior of whether to
+        summarize the output (activation) of this layer.
+        Can be overridden when creating the layer.
+    :param log_shape: whether to log the input/output shape of this layer
     """
     def wrapper(func):
         @wraps(func)
@@ -29,13 +29,16 @@ def layer_register(summary_activation=False):
             assert isinstance(name, six.string_types), \
                 'name must be the first argument. Args: {}'.format(str(args))
             args = args[1:]

             do_summary = kwargs.pop(
                 'summary_activation', summary_activation)
             inputs = args[0]
+
+            actual_args = get_arg_scope()[func.__name__]
+            actual_args.update(kwargs)
             with tf.variable_scope(name) as scope:
-                outputs = func(*args, **kwargs)
-                if scope.name not in _layer_logged:
+                outputs = func(*args, **actual_args)
+                if log_shape and scope.name not in _layer_logged:
                     # log shape info and add activation
                     logger.info("{} input: {}".format(
                         scope.name, get_shape_str(inputs)))
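In short, the wrapper now resolves a layer's arguments by overlaying the call-site kwargs on top of the defaults recorded by the enclosing argscope. A minimal stand-alone sketch of that precedence (the values here are illustrative, not from this diff):

scope_defaults = {'kernel_shape': 3, 'use_bias': False}   # what get_arg_scope()[func.__name__] might hold
call_kwargs = {'kernel_shape': 5, 'out_channel': 64}      # what the caller passed explicitly

actual_args = dict(scope_defaults)
actual_args.update(call_kwargs)    # call-site kwargs win on conflicts
assert actual_args == {'kernel_shape': 5, 'use_bias': False, 'out_channel': 64}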
@@ -14,7 +14,7 @@ __all__ = ['BatchNorm']
 # TF batch_norm only works for 4D tensor right now: #804
 # decay: being too close to 1 leads to slow start-up, but ends up better
 # eps: torch: 1e-5. Lasagne: 1e-4
-@layer_register()
+@layer_register(log_shape=False)
 def BatchNorm(x, use_local_stat=True, decay=0.999, epsilon=1e-5):
     """
     Batch normalization layer as described in:
@@ -11,7 +11,7 @@ from .batch_norm import BatchNorm
 __all__ = ['Maxout', 'PReLU', 'LeakyReLU', 'BNReLU']

-@layer_register()
+@layer_register(log_shape=False)
 def Maxout(x, num_unit):
     """
     Maxout networks as in `Maxout Networks <http://arxiv.org/abs/1302.4389>`_.
@@ -27,7 +27,7 @@ def Maxout(x, num_unit):
     x = tf.reshape(x, [-1, input_shape[1], input_shape[2], ch / 3, 3])
     return tf.reduce_max(x, 4, name='output')

-@layer_register()
+@layer_register(log_shape=False)
 def PReLU(x, init=tf.constant_initializer(0.001), name=None):
     """
     Parameterized relu as in `Delving Deep into Rectifiers: Surpassing
@@ -44,7 +44,7 @@ def PReLU(x, init=tf.constant_initializer(0.001), name=None):
     else:
         return tf.mul(x, 0.5, name=name)

-@layer_register()
+@layer_register(log_shape=False)
 def LeakyReLU(x, alpha, name=None):
     """
     Leaky relu as in `Rectifier Nonlinearities Improve Neural Network Acoustic
@@ -66,9 +66,8 @@ def BNReLU(is_training):
     """
     :returns: an activation function that performs BN + ReLU (a very common combination)
     """
-    def f(x, name=None):
-        with tf.variable_scope('bn'):
-            x = BatchNorm.f(x, is_training)
+    def BNReLU(x, name=None):
+        x = BatchNorm('bn', x, is_training)
         x = tf.nn.relu(x, name=name)
         return x
-    return f
+    return BNReLU
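Functionally, the returned closure now just calls the registered BatchNorm layer (which creates its own 'bn' scope) followed by ReLU. A minimal sketch of calling it directly, with x an assumed 4D input tensor; normally a layer invokes it through its nl argument, as in the CIFAR example above:

act = BNReLU(is_training)     # build the BN + ReLU activation function
y = act(x, name='output')     # BatchNorm('bn', x, is_training), then tf.nn.relu named 'output'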
@@ -14,4 +14,5 @@ def _global_import(name):
 _global_import('sessinit')
 _global_import('common')
 _global_import('gradproc')
+_global_import('argscope')
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: argscope.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

from contextlib import contextmanager
from collections import defaultdict
import inspect
import copy
import six

__all__ = ['argscope', 'get_arg_scope']

# stack of {layer_name: {arg_name: value}} mappings; the top is the active scope
_ArgScopeStack = []

@contextmanager
def argscope(layers, **kwargs):
    """Set default arguments for the given registered layers within this context."""
    param = kwargs

    if not isinstance(layers, list):
        layers = [layers]

    def _check_args_exist(l):
        # every default must be a real argument of the underlying layer function
        args = inspect.getargspec(l).args
        for k, v in six.iteritems(param):
            assert k in args, "No argument {} in {}".format(k, l.__name__)

    for l in layers:
        assert hasattr(l, 'f'), "{} is not a registered layer".format(l.__name__)
        _check_args_exist(l.f)

    # start from a copy of the enclosing scope, then overlay the new defaults
    new_scope = copy.copy(get_arg_scope())
    for l in layers:
        new_scope[l.__name__].update(param)
    _ArgScopeStack.append(new_scope)
    yield
    del _ArgScopeStack[-1]

def get_arg_scope():
    """Return the innermost argscope mapping, or empty defaults if no scope is active."""
    if len(_ArgScopeStack) > 0:
        return _ArgScopeStack[-1]
    else:
        return defaultdict(dict)
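A minimal usage sketch of the new context manager; Conv2D stands for any function registered with layer_register (so it carries the .f attribute checked above), and image is an assumed input tensor as in the CIFAR example:

# outside any scope, every layer gets empty defaults
assert get_arg_scope()['Conv2D'] == {}

with argscope([Conv2D], kernel_shape=3, use_bias=False):
    # the registered-layer wrapper (see the layer_register change above) merges
    # these defaults into every Conv2D call made inside the block
    assert get_arg_scope()['Conv2D'] == {'kernel_shape': 3, 'use_bias': False}
    l = Conv2D('conv1', image, out_channel=64)

# the scope is popped again on exit
assert get_arg_scope()['Conv2D'] == {}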