Commit 880b767d authored by Yuxin Wu

add NCHW for Conv2D, Pooling and BN. fix #150

parent 4be41492
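The net effect of the diff below is that Conv2D, the pooling layers, and BatchNorm all accept a `data_format` argument, so a model can run channels-first end to end. A rough usage sketch modeled on the ResNet example changed in this commit (import paths abbreviated, layer arguments trimmed, and the input assumed to already be transposed to NCHW):

```python
import tensorflow as tf
from tensorpack import argscope, Conv2D, AvgPooling, BatchNorm, GlobalAvgPooling

# Assume `image` is already channels-first: (batch, 3, 32, 32).
image = tf.placeholder(tf.float32, [None, 3, 32, 32], name='image')

# Set data_format once for every layer that understands it.
with argscope([Conv2D, AvgPooling, BatchNorm, GlobalAvgPooling], data_format='NCHW'):
    l = Conv2D('conv0', image, 16, kernel_shape=3, nl=tf.nn.relu)
    l = AvgPooling('pool0', l, 2)
    l = BatchNorm('bn0', l)
    logits = GlobalAvgPooling('gap', l)     # -> (batch, 16)
```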
@@ -21,9 +21,9 @@ This implementation uses the variants proposed in:
 Identity Mappings in Deep Residual Networks, arxiv:1603.05027

 I can reproduce the results on 2 TitanX for
-n=5, about 7.1% val error after 67k steps (15 step/s)
-n=18, about 5.95% val error after 80k steps (4.2 step/s)
-n=30: a 182-layer network, about 5.6% val error after 51k steps (2.5 step/s)
+n=5, about 7.1% val error after 67k steps (20.4 step/s)
+n=18, about 5.95% val error after 80k steps (5.6 step/s)
+n=30: a 182-layer network, about 5.6% val error after 51k steps (3.4 step/s)
 This model uses the whole training set instead of a train-val split.
 To train:
@@ -47,10 +47,11 @@ class Model(ModelDesc):
     def _build_graph(self, inputs):
         image, label = inputs
         image = image / 128.0 - 1
+        image = tf.transpose(image, [0, 3, 1, 2])

         def residual(name, l, increase_dim=False, first=False):
             shape = l.get_shape().as_list()
-            in_channel = shape[3]
+            in_channel = shape[1]

             if increase_dim:
                 out_channel = in_channel * 2
@@ -65,12 +66,14 @@ class Model(ModelDesc):
             c2 = Conv2D('conv2', c1, out_channel)
             if increase_dim:
                 l = AvgPooling('pool', l, 2)
-                l = tf.pad(l, [[0, 0], [0, 0], [0, 0], [in_channel // 2, in_channel // 2]])
+                l = tf.pad(l, [[0, 0], [in_channel // 2, in_channel // 2], [0, 0], [0, 0]])

             l = c2 + l
             return l

-        with argscope(Conv2D, nl=tf.identity, use_bias=False, kernel_shape=3,
+        with argscope([Conv2D, AvgPooling, BatchNorm, GlobalAvgPooling],
+                      data_format='NCHW'), \
+                argscope(Conv2D, nl=tf.identity, use_bias=False, kernel_shape=3,
                       W_init=variance_scaling_initializer(mode='FAN_OUT')):
             l = Conv2D('conv0', image, 16, nl=BNReLU)
             l = residual('res1.0', l, first=True)
...
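The example-level changes boil down to two things: transpose the NHWC input once, and pad the channel axis (now axis 1) on the shortcut branch. A minimal plain-TF sketch of those two operations, with made-up tensor names:

```python
import numpy as np
import tensorflow as tf

# Toy NHWC input (batch, height, width, channels), e.g. CIFAR-10 images.
images_nhwc = tf.placeholder(tf.float32, [None, 32, 32, 3], name='images')

# Move channels to axis 1 once, right after preprocessing.
images_nchw = tf.transpose(images_nhwc, [0, 3, 1, 2])          # -> (batch, 3, 32, 32)

# In NCHW, zero-padding extra channels on a residual shortcut targets axis 1,
# not axis 3. `feat` stands in for the pooled shortcut branch.
in_channel = 16
feat = tf.zeros([4, in_channel, 16, 16])
shortcut = tf.pad(feat, [[0, 0], [in_channel // 2, in_channel // 2], [0, 0], [0, 0]])

with tf.Session() as sess:
    print(sess.run(tf.shape(images_nchw),
                   {images_nhwc: np.zeros((4, 32, 32, 3), np.float32)}))   # [ 4  3 32 32]
    print(sess.run(tf.shape(shortcut)))                                    # [ 4 32 16 16]
```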
@@ -52,7 +52,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
             with tf.variable_scope(tf.get_variable_scope(), reuse=False):
                 # BatchNorm in reuse scope can be tricky! Moving mean/variance are not reused
                 with tf.name_scope(None):  # https://github.com/tensorflow/tensorflow/issues/2740
-                    # TODO if reuse=True, try to find and use the existing statistics
+                    # if reuse=True, try to find and use the existing statistics
                     # how to use multiple tensors to update one EMA? seems impossible
                     ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
                     ema_apply_op = ema.apply([batch_mean, batch_var])
@@ -71,7 +71,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
         var_var_name = ema.average_name(batch_var)
         if ctx.is_main_tower:
             # main tower, but needs to use global stat. global stat must be from outside
-            # TODO when reuse=True, the desired variable name could
+            # when reuse=True, the desired variable name could
             # actually be different, because a different var is created
             # for different reuse tower
             ema_mean = tf.get_variable('mean/' + emaname, [n_out])
@@ -96,14 +96,7 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
         x, ema_mean, ema_var, beta, gamma, epsilon, 'output')


-def get_bn_variables(x, use_scale, use_bias):
-    shape = x.get_shape().as_list()
-    assert len(shape) in [2, 4]
-    n_out = shape[-1]  # channel
-    assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
-    if len(shape) == 2:
-        x = tf.reshape(x, [-1, 1, 1, n_out])
+def get_bn_variables(n_out, use_scale, use_bias):
     if use_bias:
         beta = tf.get_variable('beta', [n_out], initializer=tf.constant_initializer())
     else:
@@ -118,11 +111,10 @@ def get_bn_variables(x, use_scale, use_bias):
                                   initializer=tf.constant_initializer(), trainable=False)
     moving_var = tf.get_variable('variance/EMA', [n_out],
                                  initializer=tf.constant_initializer(), trainable=False)
-    return x, beta, gamma, moving_mean, moving_var
+    return beta, gamma, moving_mean, moving_var


 def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
-    # TODO update it later (similar to slim) might be faster?
     # TODO is there a way to use zero_debias in multi-GPU?
     update_op1 = moving_averages.assign_moving_average(
         moving_mean, batch_mean, decay, zero_debias=False,
@@ -138,7 +130,7 @@ def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
 @layer_register(log_shape=False)
 def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
-              use_scale=True, use_bias=True):
+              use_scale=True, use_bias=True, data_format='NHWC'):
     """
     Batch Normalization layer, as described in the paper:
     `Batch Normalization: Accelerating Deep Network Training by
@@ -171,7 +163,18 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
     with the official inceptionv3 example).
     """
     shape = x.get_shape().as_list()
-    x, beta, gamma, moving_mean, moving_var = get_bn_variables(x, use_scale, use_bias)
+    assert len(shape) in [2, 4]
+    if len(shape) == 2:
+        data_format = 'NCHW'
+    if data_format == 'NCHW':
+        n_out = shape[1]
+    else:
+        n_out = shape[-1]  # channel
+    if len(shape) == 2:
+        x = tf.reshape(x, [-1, n_out, 1, 1])
+    assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"
+    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, use_scale, use_bias)

     ctx = get_current_tower_context()
     if use_local_stat is None:
@@ -182,22 +185,25 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         logger.warn("[BatchNorm] use_local_stat != is_training")

     if use_local_stat:
-        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta,
-                                                           epsilon=epsilon, is_training=True)
+        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
+            x, gamma, beta, epsilon=epsilon,
+            is_training=True, data_format=data_format)
     else:
         assert not ctx.is_training, "In training, local statistics has to be used!"
-        # fused seems slower in inference
-        # xn, _, _ = tf.nn.fused_batch_norm(x, gamma, beta,
-        #     moving_mean, moving_var,
-        #     epsilon=epsilon, is_training=False, name='output')
-        xn = tf.nn.batch_normalization(
-            x, moving_mean, moving_var, beta, gamma, epsilon)
+        if data_format == 'NCHW':
+            # fused is slower in inference, but support NCHW
+            xn, _, _ = tf.nn.fused_batch_norm(x, gamma, beta,
+                                              moving_mean, moving_var,
+                                              epsilon=epsilon, is_training=False, data_format=data_format)
+        else:
+            xn = tf.nn.batch_normalization(
+                x, moving_mean, moving_var, beta, gamma, epsilon)

     if len(shape) == 2:
-        xn = tf.squeeze(xn, [1, 2])
+        axis = [2, 3] if data_format == 'NCHW' else [1, 2]
+        xn = tf.squeeze(xn, axis)

     # maintain EMA only on one GPU.
-    # TODO the first GPU already has too many work, might be faster to update it on a different GPU
     if ctx.is_main_training_tower:
         return update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
     else:
@@ -231,7 +237,11 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     """
     shape = x.get_shape().as_list()
-    x, beta, gamma, moving_mean, moving_var = get_bn_variables(x, use_scale, use_bias)
+    assert len(shape) in [2, 4]
+    n_out = shape[-1]
+    if len(shape) == 2:
+        x = tf.reshape(x, [-1, 1, 1, n_out])
+    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, use_scale, use_bias)

     ctx = get_current_tower_context()
     use_local_stat = ctx.is_training
...
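For reference, `tf.nn.fused_batch_norm` (TF 1.x) takes 1-D per-channel `scale`/`offset` vectors regardless of layout, so the layer only has to pick the channel axis and forward `data_format`. A simplified restatement of that bookkeeping (my own helper, not the tensorpack code):

```python
import tensorflow as tf

def pick_channel_dim(x, data_format='NHWC'):
    # Simplified version of the shape handling in the new BatchNorm:
    # 2D inputs are reshaped to NCHW with 1x1 spatial size.
    shape = x.get_shape().as_list()
    assert len(shape) in [2, 4]
    if len(shape) == 2:
        data_format = 'NCHW'
        x = tf.reshape(x, [-1, shape[-1], 1, 1])
    n_out = shape[1] if data_format == 'NCHW' else shape[-1]
    return x, n_out, data_format

x = tf.placeholder(tf.float32, [None, 64, 14, 14])          # NCHW feature map
x4d, n_out, fmt = pick_channel_dim(x, data_format='NCHW')   # n_out == 64

gamma = tf.ones([n_out])    # scale
beta = tf.zeros([n_out])    # offset
xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
    x4d, gamma, beta, epsilon=1e-5, is_training=True, data_format=fmt)
print(xn.get_shape().as_list())                              # [None, 64, 14, 14]
```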
@@ -14,12 +14,13 @@ __all__ = ['Conv2D', 'Deconv2D']
 def Conv2D(x, out_channel, kernel_shape,
            padding='SAME', stride=1,
            W_init=None, b_init=None,
-           nl=tf.identity, split=1, use_bias=True):
+           nl=tf.identity, split=1, use_bias=True,
+           data_format='NHWC'):
     """
     2D convolution on 4D inputs.

     Args:
-        x (tf.Tensor): a tensor of shape NHWC.
+        x (tf.Tensor): a 4D tensor.
             Must have known number of channels, but can have other unknown dimensions.
         out_channel (int): number of output channel.
         kernel_shape: (h, w) tuple or an int.
@@ -32,7 +33,7 @@ def Conv2D(x, out_channel, kernel_shape,
         use_bias (bool): whether to use bias.

     Returns:
-        tf.Tensor: a NHWC tensor named ``output``.
+        tf.Tensor named ``output``.

     Variable Names:
@@ -40,7 +41,8 @@ def Conv2D(x, out_channel, kernel_shape,
     * ``b``: bias
     """
     in_shape = x.get_shape().as_list()
-    in_channel = in_shape[-1]
+    channel_axis = 3 if data_format == 'NHWC' else 1
+    in_channel = in_shape[channel_axis]
     assert in_channel is not None, "[Conv2D] Input cannot have unknown channel!"
     assert in_channel % split == 0
     assert out_channel % split == 0
@@ -48,7 +50,7 @@ def Conv2D(x, out_channel, kernel_shape,
     kernel_shape = shape2d(kernel_shape)
     padding = padding.upper()
     filter_shape = kernel_shape + [in_channel / split, out_channel]
-    stride = shape4d(stride)
+    stride = shape4d(stride, data_format=data_format)

     if W_init is None:
         W_init = tf.contrib.layers.variance_scaling_initializer()
@@ -60,14 +62,14 @@ def Conv2D(x, out_channel, kernel_shape,
         b = tf.get_variable('b', [out_channel], initializer=b_init)

     if split == 1:
-        conv = tf.nn.conv2d(x, W, stride, padding)
+        conv = tf.nn.conv2d(x, W, stride, padding, data_format=data_format)
     else:
-        inputs = tf.split(x, split, 3)
+        inputs = tf.split(x, split, channel_axis)
         kernels = tf.split(W, split, 3)
-        outputs = [tf.nn.conv2d(i, k, stride, padding)
+        outputs = [tf.nn.conv2d(i, k, stride, padding, data_format=data_format)
                    for i, k in zip(inputs, kernels)]
-        conv = tf.concat(outputs, 3)
-    return nl(tf.nn.bias_add(conv, b) if use_bias else conv, name='output')
+        conv = tf.concat(outputs, channel_axis)
+    return nl(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')


 class StaticDynamicShape(object):
...
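Two details worth noting from the Conv2D change: the filter layout stays `[h, w, in, out]` in both formats, while the strides vector and the bias add become layout-dependent. A raw `tf.nn.conv2d` sketch (not the tensorpack layer; NCHW convolutions generally need a GPU at run time):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 16, 28, 28])        # NCHW: (batch, channel, h, w)
W = tf.get_variable('W', [3, 3, 16, 32])                  # filter stays [h, w, in, out]
stride = [1, 1, 2, 2]                                      # NCHW strides: [1, 1, sh, sw]

conv = tf.nn.conv2d(x, W, stride, 'SAME', data_format='NCHW')
b = tf.get_variable('b', [32], initializer=tf.zeros_initializer())
out = tf.nn.bias_add(conv, b, data_format='NCHW')          # bias_add needs the layout too
print(out.get_shape().as_list())                           # [None, 32, 14, 14]
```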
@@ -14,59 +14,58 @@ __all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
            'BilinearUpSample']


+def _Pooling(func, x, shape, stride, padding, data_format):
+    padding = padding.upper()
+    shape = shape4d(shape, data_format=data_format)
+    if stride is None:
+        stride = shape
+    else:
+        stride = shape4d(stride, data_format=data_format)
+
+    return func(x, ksize=shape,
+                strides=stride, padding=padding,
+                data_format=data_format,
+                name='output')
+
+
 @layer_register()
-def MaxPooling(x, shape, stride=None, padding='VALID'):
+def MaxPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
     """
     Max Pooling on 4D tensors.

     Args:
-        x (tf.Tensor): a NHWC tensor.
+        x (tf.Tensor): a 4D tensor.
         shape: int or (h, w) tuple
         stride: int or (h, w) tuple. Defaults to be the same as shape.
         padding (str): 'valid' or 'same'.

     Returns:
-        tf.Tensor: a NHWC tensor named ``output``.
+        tf.Tensor named ``output``.
     """
-    padding = padding.upper()
-    shape = shape4d(shape)
-    if stride is None:
-        stride = shape
-    else:
-        stride = shape4d(stride)
-    return tf.nn.max_pool(x, ksize=shape,
-                          strides=stride, padding=padding,
-                          name='output')
+    return _Pooling(tf.nn.max_pool, x, shape, stride, padding,
+                    data_format=data_format)


 @layer_register()
-def AvgPooling(x, shape, stride=None, padding='VALID'):
+def AvgPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
     """
     Average Pooling on 4D tensors.

     Args:
-        x (tf.Tensor): a NHWC tensor.
+        x (tf.Tensor): a 4D tensor.
         shape: int or (h, w) tuple
         stride: int or (h, w) tuple. Defaults to be the same as shape.
         padding (str): 'valid' or 'same'.

     Returns:
-        tf.Tensor: a NHWC tensor named ``output``.
+        tf.Tensor named ``output``.
     """
-    padding = padding.upper()
-    shape = shape4d(shape)
-    if stride is None:
-        stride = shape
-    else:
-        stride = shape4d(stride)
-    return tf.nn.avg_pool(x, ksize=shape,
-                          strides=stride, padding=padding, name='output')
+    return _Pooling(tf.nn.avg_pool, x, shape, stride, padding,
+                    data_format=data_format)


 @layer_register()
-def GlobalAvgPooling(x):
+def GlobalAvgPooling(x, data_format='NHWC'):
     """
     Global average pooling as in the paper `Network In Network
     <http://arxiv.org/abs/1312.4400>`_.
@@ -77,7 +76,9 @@ def GlobalAvgPooling(x):
         tf.Tensor: a NC tensor named ``output``.
     """
     assert x.get_shape().ndims == 4
-    return tf.reduce_mean(x, [1, 2], name='output')
+    assert data_format in ['NHWC', 'NCHW']
+    axis = [1, 2] if data_format == 'NHWC' else [2, 3]
+    return tf.reduce_mean(x, axis, name='output')


 def UnPooling2x2ZeroFilled(x):
...
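The pooling refactor funnels MaxPooling and AvgPooling through one `_Pooling` helper whose `ksize`/`strides` come from `shape4d(..., data_format=...)`; for NCHW that means `[1, 1, k, k]`, and global average pooling reduces axes `[2, 3]`. A quick raw-TF check of both, under those assumptions:

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 64, 32, 32])         # NCHW
# Window and stride keep batch/channel at 1; for NCHW that is [1, 1, k, k].
pooled = tf.nn.max_pool(x, ksize=[1, 1, 2, 2], strides=[1, 1, 2, 2],
                        padding='VALID', data_format='NCHW')
print(pooled.get_shape().as_list())                         # [None, 64, 16, 16]

# Global average pooling: the spatial axes are [2, 3] in NCHW, [1, 2] in NHWC.
gap = tf.reduce_mean(x, [2, 3])
print(gap.get_shape().as_list())                            # [None, 64]
```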
@@ -86,17 +86,22 @@ def shape2d(a):
     raise RuntimeError("Illegal shape: {}".format(a))


-def shape4d(a):
+def shape4d(a, data_format='NHWC'):
     """
-    Ensure a 4D shape, to use with NHWC functions.
+    Ensure a 4D shape, to use with 4D symbolic functions.

     Args:
         a: an int or tuple/list of length 2

     Returns:
-        list: of length 4. if ``a`` is an int, return ``[1, a, a, 1]``.
+        list: of length 4. if ``a`` is an int, return ``[1, a, a, 1]`` or
+        ``[1, 1, a, a]`` depending on data_format.
     """
-    return [1] + shape2d(a) + [1]
+    s2d = shape2d(a)
+    if data_format == 'NHWC':
+        return [1] + s2d + [1]
+    else:
+        return [1, 1] + s2d


 @memoized
...
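The `shape4d` change is the small utility that makes the above work: it now knows where the spatial dimensions live. A standalone restatement with the two expected outputs (`shape2d` is simplified here):

```python
def shape2d(a):
    """int -> [a, a]; (h, w) tuple/list -> [h, w]. Simplified stand-in."""
    if isinstance(a, int):
        return [a, a]
    if isinstance(a, (list, tuple)):
        assert len(a) == 2
        return list(a)
    raise RuntimeError("Illegal shape: {}".format(a))


def shape4d(a, data_format='NHWC'):
    """Pad a 2D window/stride spec to the 4D form tf.nn.* ops expect."""
    s2d = shape2d(a)
    if data_format == 'NHWC':
        return [1] + s2d + [1]
    return [1, 1] + s2d


print(shape4d(3))               # [1, 3, 3, 1]
print(shape4d(3, 'NCHW'))       # [1, 1, 3, 3]
print(shape4d((2, 3), 'NCHW'))  # [1, 1, 2, 3]
```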