Commit fbb73a8a authored by Yuxin Wu

argscope

parent 23f2ccd6
@@ -43,15 +43,18 @@ class Model(ModelDesc):
         tf.image_summary("train_image", image, 10)
         image = image / 4.0     # just to make range smaller
-        l = Conv2D('conv1.1', image, out_channel=64, kernel_shape=3)
+        l = Conv2D('conv1.1', image, out_channel=64, kernel_shape=3,
+                   nl=BNReLU(is_training), use_bias=False)
         l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=BNReLU(is_training), use_bias=False)
         l = MaxPooling('pool1', l, 3, stride=2, padding='SAME')
-        l = Conv2D('conv2.1', l, out_channel=128, kernel_shape=3)
+        l = Conv2D('conv2.1', l, out_channel=128, kernel_shape=3,
+                   nl=BNReLU(is_training), use_bias=False)
         l = Conv2D('conv2.2', l, out_channel=128, kernel_shape=3, nl=BNReLU(is_training), use_bias=False)
         l = MaxPooling('pool2', l, 3, stride=2, padding='SAME')
-        l = Conv2D('conv3.1', l, out_channel=128, kernel_shape=3, padding='VALID')
+        l = Conv2D('conv3.1', l, out_channel=128, kernel_shape=3,
+                   padding='VALID', nl=BNReLU(is_training), use_bias=False)
         l = Conv2D('conv3.2', l, out_channel=128, kernel_shape=3, padding='VALID', nl=BNReLU(is_training), use_bias=False)
         l = FullyConnected('fc0', l, 1024 + 512,
                            b_init=tf.constant_initializer(0.1))
@@ -80,7 +83,7 @@ class Model(ModelDesc):
                               name='regularize_loss')
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
-        add_param_summary([('.*/W', ['histogram', 'sparsity'])])   # monitor W
+        add_param_summary([('.*/W', ['histogram'])])   # monitor W
         return tf.add_n([cost, wd_cost], name='cost')

 def get_data(train_or_test):
@@ -123,7 +126,7 @@ def get_config():
     lr = tf.train.exponential_decay(
         learning_rate=1e-2,
         global_step=get_global_step_var(),
-        decay_steps=dataset_train.size() * 30 if nr_gpu == 1 else 20,
+        decay_steps=step_per_epoch * 30 if nr_gpu == 1 else 20,
         decay_rate=0.5, staircase=True, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
@@ -138,7 +141,7 @@ def get_config():
         session_config=sess_config,
         model=Model(),
         step_per_epoch=step_per_epoch,
-        max_epoch=500,
+        max_epoch=200,
     )

 if __name__ == '__main__':
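The hunks above now repeat nl=BNReLU(is_training), use_bias=False (and kernel_shape=3) on every Conv2D call. The argscope context manager introduced further down in this commit (tfutils/argscope.py) exists to factor out exactly this kind of repetition; a hypothetical rewrite of the block above, not part of the commit itself, might look like:

# Hypothetical sketch only: the keyword defaults below apply to every Conv2D
# call inside the with-block, and any of them can still be overridden per call.
with argscope(Conv2D, kernel_shape=3, nl=BNReLU(is_training), use_bias=False):
    l = Conv2D('conv1.1', image, out_channel=64)
    l = Conv2D('conv1.2', l, out_channel=64)
    l = MaxPooling('pool1', l, 3, stride=2, padding='SAME')
    l = Conv2D('conv2.1', l, out_channel=128)
    l = Conv2D('conv2.2', l, out_channel=128)
    l = MaxPooling('pool2', l, 3, stride=2, padding='SAME')
    l = Conv2D('conv3.1', l, out_channel=128, padding='VALID')
    l = Conv2D('conv3.2', l, out_channel=128, padding='VALID')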
@@ -6,6 +6,7 @@ import tensorflow as tf
 from functools import wraps
 import six

+from ..tfutils import *
 from ..tfutils.modelutils import *
 from ..tfutils.summary import *
 from ..utils import logger
@@ -13,14 +14,13 @@ from ..utils import logger
 # make sure each layer is only logged once
 _layer_logged = set()

-def layer_register(summary_activation=False):
+def layer_register(summary_activation=False, log_shape=True):
     """
     Register a layer.
-    Args:
-        summary_activation:
-            Define the default behavior of whether to
-            summary the output(activation) of this layer.
-            Can be overriden when creating the layer.
+    :param summary_activation: Define the default behavior of whether to
+        summary the output(activation) of this layer.
+        Can be overriden when creating the layer.
+    :param log_shape: log input/output shape of this layer
     """
     def wrapper(func):
         @wraps(func)
@@ -29,13 +29,16 @@ def layer_register(summary_activation=False):
             assert isinstance(name, six.string_types), \
                 'name must be the first argument. Args: {}'.format(str(args))
             args = args[1:]

             do_summary = kwargs.pop(
                 'summary_activation', summary_activation)
             inputs = args[0]
+
+            actual_args = get_arg_scope()[func.__name__]
+            actual_args.update(kwargs)
+
             with tf.variable_scope(name) as scope:
-                outputs = func(*args, **kwargs)
-                if scope.name not in _layer_logged:
+                outputs = func(*args, **actual_args)
+                if log_shape and scope.name not in _layer_logged:
                     # log shape info and add activation
                     logger.info("{} input: {}".format(
                         scope.name, get_shape_str(inputs)))
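In the merge above, the wrapper first looks up the defaults registered under the layer's __name__ in the current arg scope and then overlays the call-site kwargs, so explicit keyword arguments always take precedence over scope defaults. A minimal illustration of that precedence (the layer names and tensors here are only examples, not code from this commit):

# Assumed illustration of precedence: scope defaults < call-site kwargs.
with argscope(Conv2D, kernel_shape=3, use_bias=False):
    l = Conv2D('c1', image, out_channel=64)                  # gets kernel_shape=3, use_bias=False
    l = Conv2D('c2', l, out_channel=64, kernel_shape=5)      # explicit kernel_shape=5 wins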
@@ -14,7 +14,7 @@ __all__ = ['BatchNorm']
 # TF batch_norm only works for 4D tensor right now: #804
 # decay: being too close to 1 leads to slow start-up, but ends up better
 # eps: torch: 1e-5. Lasagne: 1e-4
-@layer_register()
+@layer_register(log_shape=False)
 def BatchNorm(x, use_local_stat=True, decay=0.999, epsilon=1e-5):
     """
     Batch normalization layer as described in:
@@ -11,7 +11,7 @@ from .batch_norm import BatchNorm

 __all__ = ['Maxout', 'PReLU', 'LeakyReLU', 'BNReLU']

-@layer_register()
+@layer_register(log_shape=False)
 def Maxout(x, num_unit):
     """
     Maxout networks as in `Maxout Networks <http://arxiv.org/abs/1302.4389>`_.
@@ -27,7 +27,7 @@ def Maxout(x, num_unit):
     x = tf.reshape(x, [-1, input_shape[1], input_shape[2], ch / 3, 3])
     return tf.reduce_max(x, 4, name='output')

-@layer_register()
+@layer_register(log_shape=False)
 def PReLU(x, init=tf.constant_initializer(0.001), name=None):
     """
     Parameterized relu as in `Delving Deep into Rectifiers: Surpassing
@@ -44,7 +44,7 @@ def PReLU(x, init=tf.constant_initializer(0.001), name=None):
     else:
         return tf.mul(x, 0.5, name=name)

-@layer_register()
+@layer_register(log_shape=False)
 def LeakyReLU(x, alpha, name=None):
     """
     Leaky relu as in `Rectifier Nonlinearities Improve Neural Network Acoustic
@@ -66,9 +66,8 @@ def BNReLU(is_training):
     """
     :returns: a activation function that performs BN + ReLU (a too common combination)
     """
-    def f(x, name=None):
-        with tf.variable_scope('bn'):
-            x = BatchNorm.f(x, is_training)
+    def BNReLU(x, name=None):
+        x = BatchNorm('bn', x, is_training)
         x = tf.nn.relu(x, name=name)
         return x
-    return f
+    return BNReLU
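The inner function is now itself named BNReLU and calls the registered BatchNorm layer directly instead of opening a manual variable_scope around BatchNorm.f. When the returned callable is passed as nl= to another layer, the batch-norm variables end up nested under that layer's scope. A brief illustrative sketch (layer names are only examples):

# Illustrative only: Conv2D invokes the nl callable inside its own variable
# scope, so the BatchNorm created here lives under e.g. 'conv1.1/bn/...'.
l = Conv2D('conv1.1', image, out_channel=64, kernel_shape=3,
           nl=BNReLU(is_training), use_bias=False)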
@@ -14,4 +14,5 @@ def _global_import(name):
 _global_import('sessinit')
 _global_import('common')
 _global_import('gradproc')
+_global_import('argscope')
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: argscope.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

from contextlib import contextmanager
from collections import defaultdict
import inspect
import copy
import six

__all__ = ['argscope', 'get_arg_scope']

_ArgScopeStack = []

@contextmanager
def argscope(layers, **kwargs):
    param = kwargs

    if not isinstance(layers, list):
        layers = [layers]

    def _check_args_exist(l):
        args = inspect.getargspec(l).args
        for k, v in six.iteritems(param):
            assert k in args, "No argument {} in {}".format(k, l.__name__)

    for l in layers:
        assert hasattr(l, 'f'), "{} is not a registered layer".format(l.__name__)
        _check_args_exist(l.f)

    new_scope = copy.copy(get_arg_scope())
    for l in layers:
        new_scope[l.__name__].update(param)
    _ArgScopeStack.append(new_scope)
    yield
    del _ArgScopeStack[-1]

def get_arg_scope():
    if len(_ArgScopeStack) > 0:
        return _ArgScopeStack[-1]
    else:
        return defaultdict(dict)
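The two assertions above make argscope fail fast when the with-block is entered: every target must be a registered layer (one that carries the .f attribute registered layers expose), and every keyword must be an actual argument of that layer. A hedged sketch of the list form and of those failure modes, assuming Conv2D and FullyConnected are registered layers with an nl argument:

# Defaults can be set for several registered layers at once by passing a list.
with argscope([Conv2D, FullyConnected], nl=tf.nn.relu):
    l = Conv2D('c1', image, out_channel=64, kernel_shape=3)
    l = FullyConnected('fc0', l, 512)

# Misuse is caught when the with-block is entered:
#   with argscope(tf.nn.relu, alpha=0.1): ...   -> AssertionError: relu is not a registered layer
#   with argscope(Conv2D, bogus_arg=1): ...     -> AssertionError: No argument bogus_arg in Conv2D

# Outside any argscope, get_arg_scope() returns an empty defaultdict(dict),
# so registered layers behave exactly as they did before this commit.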