Commit 08cecd44 authored by Yuxin Wu

support NCHW in deconv. change some more examples. (#150)

parent 880b767d
...@@ -40,9 +40,8 @@ Describe your training task with three components: ...@@ -40,9 +40,8 @@ Describe your training task with three components:
2. __DataFlow__. tensorpack allows and encourages complex data processing. 2. __DataFlow__. tensorpack allows and encourages complex data processing.
+ All data producers have a unified interface, allowing them to be composed to perform complex preprocessing. + All data producers have a unified interface, so they can be composed and reused to perform complex preprocessing.
+ Use Python to easily handle any data format, yet still keep good performance thanks to multiprocess prefetch & TF Queue prefetch. + Allows you to process data from Python without blocking the training, thanks to multiprocess prefetch & TF Queue prefetch.
For example, InceptionV3 can run at the same speed as the official code, which reads data with TF operators.
3. __Callbacks__, including everything you want to do apart from the training iterations, such as: 3. __Callbacks__, including everything you want to do apart from the training iterations, such as:
+ Change hyperparameters during training + Change hyperparameters during training
......
...@@ -41,7 +41,9 @@ Accuracy: ...@@ -41,7 +41,9 @@ Accuracy:
With (W,A,G)=(1,2,4), 63% error. With (W,A,G)=(1,2,4), 63% error.
Speed: Speed:
About 2.8 iteration/s on 1 TitanX. (Each epoch is set to 10000 iterations) About 2.2 iteration/s on 1 TitanX. (Each epoch is set to 10000 iterations)
Note that this code was written early on and does not use NCHW format. You
should expect about a 30% speedup after switching to NCHW format.
To Train, for example: To Train, for example:
./alexnet-dorefa.py --dorefa 1,2,6 --data PATH --gpu 0,1 ./alexnet-dorefa.py --dorefa 1,2,6 --data PATH --gpu 0,1
...@@ -176,7 +178,6 @@ def get_data(dataset_name): ...@@ -176,7 +178,6 @@ def get_data(dataset_name):
if isTrain: if isTrain:
class Resize(imgaug.ImageAugmentor): class Resize(imgaug.ImageAugmentor):
def __init__(self): def __init__(self):
self._init(locals()) self._init(locals())
......
...@@ -71,10 +71,9 @@ class Model(ModelDesc): ...@@ -71,10 +71,9 @@ class Model(ModelDesc):
l = c2 + l l = c2 + l
return l return l
with argscope([Conv2D, AvgPooling, BatchNorm, GlobalAvgPooling], with argscope([Conv2D, AvgPooling, BatchNorm, GlobalAvgPooling], data_format='NCHW'), \
data_format='NCHW'), \ argscope(Conv2D, nl=tf.identity, use_bias=False, kernel_shape=3,
argscope(Conv2D, nl=tf.identity, use_bias=False, kernel_shape=3, W_init=variance_scaling_initializer(mode='FAN_OUT')):
W_init=variance_scaling_initializer(mode='FAN_OUT')):
l = Conv2D('conv0', image, 16, nl=BNReLU) l = Conv2D('conv0', image, 16, nl=BNReLU)
l = residual('res1.0', l, first=True) l = residual('res1.0', l, first=True)
for k in range(1, self.n): for k in range(1, self.n):
......
...@@ -34,6 +34,7 @@ class Model(ModelDesc): ...@@ -34,6 +34,7 @@ class Model(ModelDesc):
image_mean = tf.constant([0.485, 0.456, 0.406], dtype=tf.float32) image_mean = tf.constant([0.485, 0.456, 0.406], dtype=tf.float32)
image_std = tf.constant([0.229, 0.224, 0.225], dtype=tf.float32) image_std = tf.constant([0.229, 0.224, 0.225], dtype=tf.float32)
image = (image - image_mean) / image_std image = (image - image_mean) / image_std
image = tf.transpose(image, [0, 3, 1, 2])
def shortcut(l, n_in, n_out, stride): def shortcut(l, n_in, n_out, stride):
if n_in != n_out: if n_in != n_out:
...@@ -42,7 +43,7 @@ class Model(ModelDesc): ...@@ -42,7 +43,7 @@ class Model(ModelDesc):
return l return l
def basicblock(l, ch_out, stride, preact): def basicblock(l, ch_out, stride, preact):
ch_in = l.get_shape().as_list()[-1] ch_in = l.get_shape().as_list()[1]
if preact == 'both_preact': if preact == 'both_preact':
l = BNReLU('preact', l) l = BNReLU('preact', l)
input = l input = l
...@@ -56,7 +57,7 @@ class Model(ModelDesc): ...@@ -56,7 +57,7 @@ class Model(ModelDesc):
return l + shortcut(input, ch_in, ch_out, stride) return l + shortcut(input, ch_in, ch_out, stride)
def bottleneck(l, ch_out, stride, preact): def bottleneck(l, ch_out, stride, preact):
ch_in = l.get_shape().as_list()[-1] ch_in = l.get_shape().as_list()[1]
if preact == 'both_preact': if preact == 'both_preact':
l = BNReLU('preact', l) l = BNReLU('preact', l)
input = l input = l
...@@ -89,7 +90,8 @@ class Model(ModelDesc): ...@@ -89,7 +90,8 @@ class Model(ModelDesc):
defs, block_func = cfg[DEPTH] defs, block_func = cfg[DEPTH]
with argscope(Conv2D, nl=tf.identity, use_bias=False, with argscope(Conv2D, nl=tf.identity, use_bias=False,
W_init=variance_scaling_initializer(mode='FAN_OUT')): W_init=variance_scaling_initializer(mode='FAN_OUT')), \
argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format='NCHW'):
logits = (LinearWrap(image) logits = (LinearWrap(image)
.Conv2D('conv0', 64, 7, stride=2, nl=BNReLU) .Conv2D('conv0', 64, 7, stride=2, nl=BNReLU)
.MaxPooling('pool0', shape=3, stride=2, padding='SAME') .MaxPooling('pool0', shape=3, stride=2, padding='SAME')
...@@ -120,6 +122,7 @@ class Model(ModelDesc): ...@@ -120,6 +122,7 @@ class Model(ModelDesc):
def get_data(train_or_test): def get_data(train_or_test):
# return FakeData([[64, 224,224,3],[64]], 1000, random=False, dtype='uint8')
isTrain = train_or_test == 'train' isTrain = train_or_test == 'train'
datadir = args.data datadir = args.data
......
...@@ -17,7 +17,7 @@ A small convnet model for Cifar10 or Cifar100 dataset. ...@@ -17,7 +17,7 @@ A small convnet model for Cifar10 or Cifar100 dataset.
Cifar10: Cifar10:
91% accuracy after 50k steps. 91% accuracy after 50k steps.
30 step/s on TitanX 41 step/s on TitanX
Not a good model for Cifar100, just for demonstration. Not a good model for Cifar100, just for demonstration.
""" """
...@@ -40,9 +40,11 @@ class Model(ModelDesc): ...@@ -40,9 +40,11 @@ class Model(ModelDesc):
if is_training: if is_training:
tf.summary.image("train_image", image, 10) tf.summary.image("train_image", image, 10)
image = tf.transpose(image, [0, 3, 1, 2])
image = image / 4.0 # just to make range smaller image = image / 4.0 # just to make range smaller
with argscope(Conv2D, nl=BNReLU, use_bias=False, kernel_shape=3): with argscope(Conv2D, nl=BNReLU, use_bias=False, kernel_shape=3), \
argscope([Conv2D, MaxPooling, BatchNorm], data_format='NCHW'):
logits = LinearWrap(image) \ logits = LinearWrap(image) \
.Conv2D('conv1.1', out_channel=64) \ .Conv2D('conv1.1', out_channel=64) \
.Conv2D('conv1.2', out_channel=64) \ .Conv2D('conv1.2', out_channel=64) \
...@@ -101,7 +103,7 @@ def get_data(train_or_test, cifar_classnum): ...@@ -101,7 +103,7 @@ def get_data(train_or_test, cifar_classnum):
ds = AugmentImageComponent(ds, augmentors) ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, 128, remainder=not isTrain) ds = BatchData(ds, 128, remainder=not isTrain)
if isTrain: if isTrain:
ds = PrefetchDataZMQ(ds, 3) ds = PrefetchDataZMQ(ds, 5)
return ds return ds
......
...@@ -192,7 +192,8 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5, ...@@ -192,7 +192,8 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
assert not ctx.is_training, "In training, local statistics has to be used!" assert not ctx.is_training, "In training, local statistics has to be used!"
if data_format == 'NCHW': if data_format == 'NCHW':
# fused is slower in inference, but supports NCHW # fused is slower in inference, but supports NCHW
xn, _, _ = tf.nn.fused_batch_norm(x, gamma, beta, xn, _, _ = tf.nn.fused_batch_norm(
x, gamma, beta,
moving_mean, moving_var, moving_mean, moving_var,
epsilon=epsilon, is_training=False, data_format=data_format) epsilon=epsilon, is_training=False, data_format=data_format)
else: else:
......
...@@ -89,7 +89,8 @@ class StaticDynamicShape(object): ...@@ -89,7 +89,8 @@ class StaticDynamicShape(object):
def Deconv2D(x, out_shape, kernel_shape, def Deconv2D(x, out_shape, kernel_shape,
stride, padding='SAME', stride, padding='SAME',
W_init=None, b_init=None, W_init=None, b_init=None,
nl=tf.identity, use_bias=True): nl=tf.identity, use_bias=True,
data_format='NHWC'):
""" """
2D deconvolution on 4D inputs. 2D deconvolution on 4D inputs.
...@@ -114,25 +115,33 @@ def Deconv2D(x, out_shape, kernel_shape, ...@@ -114,25 +115,33 @@ def Deconv2D(x, out_shape, kernel_shape,
* ``W``: weights * ``W``: weights
* ``b``: bias * ``b``: bias
""" """
in_shape = x.get_shape().as_list()[1:] in_shape = x.get_shape().as_list()
in_channel = in_shape[-1] channel_axis = 3 if data_format == 'NHWC' else 1
in_channel = in_shape[channel_axis]
assert in_channel is not None, "[Deconv2D] Input cannot have unknown channel!" assert in_channel is not None, "[Deconv2D] Input cannot have unknown channel!"
kernel_shape = shape2d(kernel_shape) kernel_shape = shape2d(kernel_shape)
stride2d = shape2d(stride) stride2d = shape2d(stride)
stride4d = shape4d(stride) stride4d = shape4d(stride, data_format=data_format)
padding = padding.upper() padding = padding.upper()
in_shape_dyn = tf.shape(x)
if isinstance(out_shape, int): if isinstance(out_shape, int):
out_channel = out_shape out_channel = out_shape
shp3_0 = StaticDynamicShape(in_shape[0], tf.shape(x)[1]).apply(lambda x: stride2d[0] * x) if data_format == 'NHWC':
shp3_1 = StaticDynamicShape(in_shape[1], tf.shape(x)[2]).apply(lambda x: stride2d[1] * x) shp3_0 = StaticDynamicShape(in_shape[1], in_shape_dyn[1]).apply(lambda x: stride2d[0] * x)
shp3_dyn = [shp3_0.dynamic, shp3_1.dynamic, out_channel] shp3_1 = StaticDynamicShape(in_shape[2], in_shape_dyn[2]).apply(lambda x: stride2d[1] * x)
shp3_static = [shp3_0.static, shp3_1.static, out_channel] shp3_dyn = [shp3_0.dynamic, shp3_1.dynamic, out_channel]
shp3_static = [shp3_0.static, shp3_1.static, out_channel]
else:
shp3_0 = StaticDynamicShape(in_shape[2], in_shape_dyn[2]).apply(lambda x: stride2d[0] * x)
shp3_1 = StaticDynamicShape(in_shape[3], in_shape_dyn[3]).apply(lambda x: stride2d[1] * x)
shp3_dyn = [out_channel, shp3_0.dynamic, shp3_1.dynamic]
shp3_static = [out_channel, shp3_0.static, shp3_1.static]
else: else:
for k in out_shape: for k in out_shape:
if not isinstance(k, int): if not isinstance(k, int):
raise ValueError("[Deconv2D] out_shape is invalid!") raise ValueError("[Deconv2D] out_shape {} is invalid!".format(k))
out_channel = out_shape[-1] out_channel = out_shape[channel_axis]
shp3_static = shp3_dyn = out_shape shp3_static = shp3_dyn = out_shape
filter_shape = kernel_shape + [out_channel, in_channel] filter_shape = kernel_shape + [out_channel, in_channel]
...@@ -145,6 +154,7 @@ def Deconv2D(x, out_shape, kernel_shape, ...@@ -145,6 +154,7 @@ def Deconv2D(x, out_shape, kernel_shape,
b = tf.get_variable('b', [out_channel], initializer=b_init) b = tf.get_variable('b', [out_channel], initializer=b_init)
out_shape_dyn = tf.stack([tf.shape(x)[0]] + shp3_dyn) out_shape_dyn = tf.stack([tf.shape(x)[0]] + shp3_dyn)
conv = tf.nn.conv2d_transpose(x, W, out_shape_dyn, stride4d, padding=padding) conv = tf.nn.conv2d_transpose(
x, W, out_shape_dyn, stride4d, padding=padding, data_format=data_format)
conv.set_shape(tf.TensorShape([None] + shp3_static)) conv.set_shape(tf.TensorShape([None] + shp3_static))
return nl(tf.nn.bias_add(conv, b) if use_bias else conv, name='output') return nl(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment