Commit 0ee4d8b0 authored by Yuxin Wu

Add AlexNet script; Memory tracker only at the end of epoch.

parent ef3bceff
Official code and model for the paper:

+ [DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients](http://arxiv.org/abs/1606.06160).

It also contains an implementation of the following papers:

+ [Binary Weight Network](https://arxiv.org/abs/1511.00363), with (W,A,G)=(1,32,32).
+ [Trained Ternary Quantization](https://arxiv.org/abs/1612.01064), with (W,A,G)=(t,32,32).
+ [Binarized Neural Networks](https://arxiv.org/abs/1602.02830), with (W,A,G)=(1,1,32).

This is a good set of baselines for research in model quantization.
These quantization techniques achieve the following ImageNet performance in this implementation:
| Model          | W,A,G    | Top 1 Validation Error |
|:---------------|----------|-----------------------:|
| Full Precision | 32,32,32 |                  40.3% |
| TTQ            | t,32,32  |                  42.0% |
| BWN            | 1,32,32  |                  44.6% |
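
For reference, a minimal sketch of the paper's k-bit weight quantizer in TF1 style (names like `quantize_weights` are illustrative, not this repo's exact API):

```python
import tensorflow as tf

def quantize(x, k):
    # k-bit uniform quantization of x in [0, 1], with a straight-through
    # gradient estimator: round() is given an identity gradient.
    n = float(2 ** k - 1)
    with tf.get_default_graph().gradient_override_map({"Round": "Identity"}):
        return tf.round(x * n) / n

def quantize_weights(w, k):
    # DoReFa-Net k-bit weight quantization: squash weights into [0, 1]
    # via tanh, quantize, then rescale back to [-1, 1].
    x = tf.tanh(w)
    x = x / tf.reduce_max(tf.abs(x)) * 0.5 + 0.5
    return 2 * quantize(x, k) - 1
```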
...@@ -26,16 +26,16 @@ more sophisticated augmentations.

We hosted a demo at CVPR16 on behalf of Megvii, Inc., running a real-time 1/4-VGG size DoReFa-Net on ARM and a half-VGG size DoReFa-Net on FPGA.
We're not planning to release our C++ runtime for bit-operations.
In this repo, quantized operations are all performed through `tf.float32`.

Pretrained models for (1,4,32)-ResNet18 and (1,2,6)-AlexNet are available at
[tensorpack model zoo](http://models.tensorpack.com/DoReFa-Net/).
They're provided in the format of a numpy dictionary.
The __binary-weight 4-bit-activation ResNet-18__ model has 59.2% top-1 validation accuracy.
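
Since the checkpoints are plain numpy dictionaries, loading one is straightforward; a minimal sketch (the filename is illustrative):

```python
import numpy as np
from tensorpack.tfutils.sessinit import DictRestore

# load a model-zoo checkpoint stored as a numpy dictionary of name -> array
weights = dict(np.load('ResNet18-qw1-qa4.npz'))
session_init = DictRestore(weights)  # pass as session_init when building a trainer/predictor
```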
Alternative link to this page: [http://dorefa.net](http://dorefa.net)

## Use

+ Install [tensorpack](https://github.com/ppwwyyxx/tensorpack) and scipy.

...
ImageNet training code of ResNet, ShuffleNet, DoReFa-Net, AlexNet, Inception, VGG with tensorpack.

To train any of the models, just do `./{model}.py --data /path/to/ilsvrc`.
The expected format of the data directory is described in the [docs](http://tensorpack.readthedocs.io/en/latest/modules/dataflow.dataset.html#tensorpack.dataflow.dataset.ILSVRC12).
...@@ -10,8 +10,10 @@ Pretrained models can be downloaded at [tensorpack model zoo](http://models.tens

Reproduce [ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices](https://arxiv.org/abs/1707.01083)
on ImageNet.

This is a 38Mflops ShuffleNet, corresponding to `ShuffleNet 0.5x g=3` in __the
2nd arxiv version__ of the paper.
After 240 epochs (36 hours on 8 P100s) it reaches top-1 error of 42.32%,
matching the paper's number.
To print flops:
```bash
...
```

...@@ -24,19 +26,35 @@

Evaluate the [pretrained model](http://models.tensorpack.com/ShuffleNet/):
```bash
./shufflenet.py --eval --data /path/to/ilsvrc --load /path/to/model
```
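
The core of the architecture is the "channel shuffle" operation; a minimal sketch, assuming NHWC layout (`channel_shuffle` is an illustrative name, not necessarily what shufflenet.py uses):

```python
import tensorflow as tf

def channel_shuffle(x, num_groups):
    # Reshape channels into (groups, channels_per_group), transpose the two
    # axes, and flatten back, so information flows across groups.
    _, h, w, c = x.get_shape().as_list()
    x = tf.reshape(x, [-1, h, w, num_groups, c // num_groups])
    x = tf.transpose(x, [0, 1, 2, 4, 3])
    return tf.reshape(x, [-1, h, w, c])
```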
### AlexNet
This AlexNet script is quite close to the setting in its [original
paper](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks).
Trained with 64x2 batch size, the script reaches 58% single-crop validation
accuracy after 100 epochs. It also generates first-layer filter visualizations
similar to those in the paper, viewable in TensorBoard.
### Inception-BN, VGG16

This Inception-BN script reaches 27% single-crop validation error after 300k steps with 6 GPUs.
The training recipe is very different from the original paper's, because the paper
is vague on these details.
This VGG16 script, when trained with 32x8 batch size, reaches the following
validation error after 100 epochs (30h with 8 P100s). This is the code for the VGG
experiments in the paper [Group Normalization](https://arxiv.org/abs/1803.08494).
| No Normalization                          | Batch Normalization | Group Normalization |
|:------------------------------------------|---------------------|--------------------:|
| 29~30% (large variation with random seed) | 28%                 | 27.6%               |
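
For reference, a minimal sketch of Group Normalization for NCHW inputs, following the pseudocode in the paper (not necessarily the exact code used by this VGG16 script):

```python
import tensorflow as tf

def group_norm(x, groups=32, eps=1e-5):
    # normalize over (channels_in_group, H, W) within each of `groups` groups
    _, c, h, w = x.get_shape().as_list()
    x = tf.reshape(x, [-1, groups, c // groups, h, w])
    mean, var = tf.nn.moments(x, [2, 3, 4], keep_dims=True)
    x = (x - mean) / tf.sqrt(var + eps)
    x = tf.reshape(x, [-1, c, h, w])
    # learned per-channel scale and offset
    gamma = tf.get_variable('gamma', [1, c, 1, 1], initializer=tf.ones_initializer())
    beta = tf.get_variable('beta', [1, c, 1, 1], initializer=tf.zeros_initializer())
    return x * gamma + beta
```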
### ResNet

See [ResNet examples](../ResNet). It includes variants such as pre-activation
ResNet and squeeze-and-excitation networks.

### DoReFa-Net

See [DoReFa-Net examples](../DoReFa-Net).
It also includes other quantization methods such as Binary Weight Network and Trained Ternary Quantization.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: alexnet.py

import argparse
import os

import cv2
import numpy as np
import tensorflow as tf

from tensorpack import *
from tensorpack.dataflow import imgaug
from tensorpack.tfutils import argscope
from tensorpack.utils.gpu import get_num_gpu

from imagenet_utils import ImageNetModel, get_imagenet_dataflow


def visualize_conv1_weights(filters):
    ctx = get_current_tower_context()
    if not ctx.is_main_training_tower:
        return
    with tf.name_scope('visualize_conv1'):
        # tile the 96 filters of shape 11x11x3 into an 8x12 grid for display
        filters = tf.reshape(filters, [11, 11, 3, 8, 12])
        filters = tf.transpose(filters, [3, 0, 4, 1, 2])   # 8,11,12,11,3
        filters = tf.reshape(filters, [1, 88, 132, 3])
        tf.summary.image('visualize_conv1', filters, max_outputs=1, collections=['AAA'])


class Model(ImageNetModel):
    weight_decay = 5e-4
    data_format = 'NHWC'  # LRN only supports NHWC

    def get_logits(self, image):
        gauss_init = tf.random_normal_initializer(stddev=0.01)
        with argscope(Conv2D,
                      kernel_initializer=tf.variance_scaling_initializer(scale=2.)), \
                argscope([Conv2D, FullyConnected], activation=tf.nn.relu), \
                argscope([Conv2D, MaxPooling], data_format='channels_last'):
            # necessary padding to get 55x55 after conv1
            image = tf.pad(image, [[0, 0], [2, 2], [2, 2], [0, 0]])
            l = Conv2D('conv1', image, filters=96, kernel_size=11, strides=4, padding='VALID')
            # size: 55
            visualize_conv1_weights(l.variables.W)
            l = tf.nn.lrn(l, 2, bias=1.0, alpha=2e-5, beta=0.75, name='norm1')
            l = MaxPooling('pool1', l, 3, strides=2, padding='VALID')
            # 27
            l = Conv2D('conv2', l, filters=256, kernel_size=5, split=2)
            l = tf.nn.lrn(l, 2, bias=1.0, alpha=2e-5, beta=0.75, name='norm2')
            l = MaxPooling('pool2', l, 3, strides=2, padding='VALID')
            # 13
            l = Conv2D('conv3', l, filters=384, kernel_size=3)
            l = Conv2D('conv4', l, filters=384, kernel_size=3, split=2)
            l = Conv2D('conv5', l, filters=256, kernel_size=3, split=2)
            l = MaxPooling('pool3', l, 3, strides=2, padding='VALID')

            l = FullyConnected('fc6', l, 4096,
                               kernel_initializer=gauss_init,
                               bias_initializer=tf.ones_initializer())
            l = Dropout(l, rate=0.5)
            l = FullyConnected('fc7', l, 4096, kernel_initializer=gauss_init)
            l = Dropout(l, rate=0.5)
        # fc8 stays outside the argscope so the logits get no ReLU
        logits = FullyConnected('fc8', l, 1000, kernel_initializer=gauss_init)
        return logits


def get_data(name, batch):
    isTrain = name == 'train'
    if isTrain:
        augmentors = [
            imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC),
            imgaug.RandomCrop(224),
            # AlexNet-style PCA lighting noise; eigenvalues are scaled to the
            # 0-255 pixel range and flipped to match cv2's BGR channel order
            imgaug.Lighting(0.1,
                            eigval=np.asarray(
                                [0.2175, 0.0188, 0.0045][::-1]) * 255.0,
                            eigvec=np.array(
                                [[-0.5675, 0.7192, 0.4009],
                                 [-0.5808, -0.0045, -0.8140],
                                 [-0.5836, -0.6948, 0.4203]],
                                dtype='float32')[::-1, ::-1]),
            imgaug.Flip(horiz=True)]
    else:
        augmentors = [
            imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC),
            imgaug.CenterCrop((224, 224))]
    # uses the global `args` parsed in __main__
    return get_imagenet_dataflow(args.data, name, batch, augmentors)


def get_config():
    nr_tower = max(get_num_gpu(), 1)
    batch = args.batch
    total_batch = batch * nr_tower
    if total_batch != 128:
        logger.warn("AlexNet needs to be trained with a total batch size of 128.")
    # linear-scale the base learning rate with the total batch size
    BASE_LR = 0.01 * (total_batch / 128.)
    logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))

    dataset_train = get_data('train', batch)
    dataset_val = get_data('val', batch)

    infs = [ClassificationError('wrong-top1', 'val-error-top1'),
            ClassificationError('wrong-top5', 'val-error-top5')]
    callbacks = [
        ModelSaver(),
        GPUUtilizationTracker(),
        EstimatedTimeLeft(),
        # decay the learning rate 10x at epochs 30, 60 and 80
        ScheduledHyperParamSetter(
            'learning_rate',
            [(30, BASE_LR * 1e-1), (60, BASE_LR * 1e-2), (80, BASE_LR * 1e-3)]),
        DataParallelInferenceRunner(
            dataset_val, infs, list(range(nr_tower))),
    ]
    return TrainConfig(
        model=Model(),
        data=StagingInput(QueueInput(dataset_train)),
        callbacks=callbacks,
        steps_per_epoch=1281167 // total_batch,
        max_epoch=100,
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--data', help='ILSVRC dataset dir')
    parser.add_argument('--batch', type=int, default=32, help='batch per GPU')
    parser.add_argument('--load', help='load model')
    args = parser.parse_args()

    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    logger.set_logger_dir(os.path.join('train_log', 'AlexNet'))

    config = get_config()
    nr_tower = max(get_num_gpu(), 1)
    trainer = SyncMultiGPUTrainerReplicated(nr_tower)
    launch_train_with_config(config, trainer)
...@@ -177,7 +177,7 @@ class GraphProfiler(Callback):

class PeakMemoryTracker(Callback):
    """
    Track peak memory used on each GPU device every epoch, by :mod:`tf.contrib.memory_stats`.
    The peak memory comes from the `MaxBytesInUse` op, which might span
    multiple session.run.
    See https://github.com/tensorflow/tensorflow/pull/13107.

...@@ -203,9 +203,12 @@ class PeakMemoryTracker(Callback):

        self._fetches = tf.train.SessionRunArgs(fetches=ops)

    def _before_run(self, _):
        # only fetch the (potentially expensive) MaxBytesInUse op
        # on the last step of each epoch
        if self.local_step == self.trainer.steps_per_epoch - 1:
            return self._fetches
        return None

    def _after_run(self, _, rv):
        results = rv.results
        if results is not None:  # None on steps where nothing was fetched
            for mem, dev in zip(results, self._devices):
                self.trainer.monitors.put_scalar('PeakMemory(MB) ' + dev, mem / 1e6)
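
For context, a sketch of typical usage (the callback list is illustrative); with this change, the `MaxBytesInUse` fetch only runs once per epoch instead of every step:

```python
from tensorpack.callbacks import PeakMemoryTracker

# e.g. append to the callback list built in get_config() of the AlexNet script above
callbacks.append(PeakMemoryTracker())  # logs 'PeakMemory(MB) <device>' once per epoch
```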