Commit 53903072 authored by Yuxin Wu

Add groupnorm-vgg experiment

parent 58d68447
@@ -13,7 +13,7 @@ It's Yet Another TF wrapper, but different in:
 1. Focus on __training speed__.
    + Speed comes for free with tensorpack -- it uses TensorFlow in the __efficient way__ with no extra overhead.
-     On different CNNs, it runs [1.2~5x faster](https://github.com/tensorpack/benchmarks/tree/master/other-wrappers) than the equivalent Keras code.
+     On different CNNs, it runs training [1.2~5x faster](https://github.com/tensorpack/benchmarks/tree/master/other-wrappers) than the equivalent Keras code.
    + Data-parallel multi-GPU training is off-the-shelf to use. It scales as well as Google's [official benchmark](https://www.tensorflow.org/performance/benchmarks).
@@ -32,7 +32,7 @@ See [tutorials](http://tensorpack.readthedocs.io/en/latest/tutorial/index.html)
 ## [Examples](examples):
-Instead of showing you 10 random networks with random accuracy,
+Instead of showing you 10 random networks trained on toy datasets,
 [tensorpack examples](examples) faithfully replicate papers and care about performance.
 And everything runs on multiple GPUs. Some highlights:
...
@@ -28,7 +28,8 @@ Evaluate the [pretrained model](http://models.tensorpack.com/ShuffleNet/):
 This Inception-BN script reaches 27% single-crop error after 300k steps with 6 GPUs.
-This VGG16 script reaches 29~30% single-crop error after 100 epochs (30h with 8 P100s), and 28% if BN is enabled.
+This VGG16 script, when trained with a 32x8 batch size (32 per GPU on 8 GPUs), reaches 29~30% single-crop error after 100 epochs (30h with 8 P100s),
+28% with BN, and 27.6% with GN.
 ### ResNet, DoReFa-Net
...
@@ -16,10 +16,43 @@ from imagenet_utils import (
     ImageNetModel, get_imagenet_dataflow, fbresnet_augmentor)
+
+def GroupNorm(x, group):
+    """
+    Group Normalization (https://arxiv.org/abs/1803.08494) for NCHW tensors.
+    """
+    shape = x.get_shape().as_list()
+    ndims = len(shape)
+    assert ndims == 4, shape
+    chan = shape[1]
+    assert chan % group == 0, chan
+    group_size = chan // group
+
+    orig_shape = tf.shape(x)
+    h, w = orig_shape[2], orig_shape[3]
+
+    # Split the channel axis into (group, group_size) and compute mean/variance
+    # over each group together with the spatial dimensions.
+    x = tf.reshape(x, tf.stack([-1, group, group_size, h, w]))
+    mean, var = tf.nn.moments(x, [2, 3, 4], keep_dims=True)
+
+    # Per-channel affine parameters, reshaped to broadcast over the grouped layout.
+    new_shape = [1, group, group_size, 1, 1]
+    beta = tf.get_variable('beta', [chan], initializer=tf.constant_initializer())
+    beta = tf.reshape(beta, new_shape)
+    gamma = tf.get_variable('gamma', [chan], initializer=tf.constant_initializer(1.0))
+    gamma = tf.reshape(gamma, new_shape)
+
+    out = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-5)
+    return tf.reshape(out, orig_shape, name='output')
+
 def convnormrelu(x, name, chan):
     x = Conv2D(name, x, chan, 3)
     if args.norm == 'bn':
         x = BatchNorm(name + '_bn', x)
+    elif args.norm == 'gn':
+        with tf.variable_scope(name + '_gn'):
+            x = GroupNorm(x, 32)
     x = tf.nn.relu(x, name=name + '_relu')
     return x
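
As an aside (not part of the commit): a minimal smoke test for the `GroupNorm` above, showing that the grouped reshape preserves the input shape. It assumes TF 1.x graph mode and an NCHW placeholder:

```python
# Hypothetical sanity check for GroupNorm (illustrative only); assumes TF 1.x.
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 64, 32, 32])   # NCHW, 64 channels
with tf.variable_scope('gn_demo'):
    y = GroupNorm(x, 32)                             # 32 groups of 2 channels each

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(y, {x: np.random.rand(4, 64, 32, 32).astype('float32')})
print(out.shape)  # (4, 64, 32, 32) -- shape unchanged; stats computed per group
```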
@@ -78,7 +111,7 @@ def get_data(name, batch):
 def get_config():
     nr_tower = max(get_nr_gpu(), 1)
-    batch = 64
+    batch = args.batch
     total_batch = batch * nr_tower
     BASE_LR = 0.01 * (total_batch / 256.)
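
The `BASE_LR` line is the usual linear learning-rate scaling rule. As a worked example (illustrative only, not part of the commit): with the new default of 32 images per GPU on 8 GPUs, the total batch is 256 and the rule reproduces the 0.01 reference rate:

```python
# Illustrative arithmetic for the linear LR scaling rule above.
batch, nr_tower = 32, 8                  # new default: 32 per GPU, 8 GPUs
total_batch = batch * nr_tower           # 256
BASE_LR = 0.01 * (total_batch / 256.)
print(BASE_LR)                           # 0.01 -- the reference LR for batch 256
```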
@@ -117,7 +150,8 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
     parser.add_argument('--data', help='ILSVRC dataset dir')
-    parser.add_argument('--norm', choices=['none', 'bn'], default='none')
+    parser.add_argument('--batch', type=int, default=32, help='batch per GPU')
+    parser.add_argument('--norm', choices=['none', 'bn', 'gn'], default='none')
     parser.add_argument('--load', help='load model')
     args = parser.parse_args()
...
@@ -4,9 +4,14 @@
 Training examples with __reproducible performance__.
 
 __The word "reproduce" should always mean reproducing performance__.
-With the magic of SGD, wrong code often appears to still work, unless you check its performance number.
+With the magic of SGD, wrong deep learning code often appears to still work,
+especially if you try it on toy datasets.
 See [Unawareness of Deep Learning Mistakes](https://medium.com/@ppwwyyxx/unawareness-of-deep-learning-mistakes-d5b5774da0ba).
+
+Instead of showing you 10 arbitrary networks trained on toy datasets with random final performance,
+tensorpack examples try to faithfully replicate the experiments and performance reported in each paper,
+so you can be confident that they are correct.
 
 ## Getting Started:
 These examples don't have meaningful performance numbers. They are supposed to be just demos.
...