Commit b5acbf3a authored by Yuxin Wu's avatar Yuxin Wu

dorefa resnet

parent 5adf1f73
...@@ -6,9 +6,12 @@ We hosted a demo at CVPR16 on behalf of Megvii, Inc, running a real-time 1/4-VGG ...@@ -6,9 +6,12 @@ We hosted a demo at CVPR16 on behalf of Megvii, Inc, running a real-time 1/4-VGG
We're not planning to release our C++ runtime for bit-operations. We're not planning to release our C++ runtime for bit-operations.
In this repo, bit operations are performed through `tf.float32`. In this repo, bit operations are performed through `tf.float32`.
Pretrained model for 1-2-6-AlexNet is available at Pretrained model for (1,4,32)-ResNet18 and (1,2,6)-AlexNet are available at
[google drive](https://drive.google.com/a/%20megvii.com/folderview?id=0B308TeQzmFDLa0xOeVQwcXg1ZjQ). [google drive](https://drive.google.com/a/megvii.com/folderview?id=0B308TeQzmFDLa0xOeVQwcXg1ZjQ).
It's provided in the format of numpy dictionary, so it should be very easy to port into other applications. They're provided in the format of numpy dictionary, so it should be very easy to port into other applications.
The binary-weight 4-bit-activation ResNet-18 model has 59.2% top-1 validation error.
Alternative link to this page: [http://dorefa.net](http://dorefa.net)
## Preparation: ## Preparation:
...@@ -27,7 +30,7 @@ pip install --user scipy ...@@ -27,7 +30,7 @@ pip install --user scipy
export PYTHONPATH=$PYTHONPATH:`readlink -f tensorpack` export PYTHONPATH=$PYTHONPATH:`readlink -f tensorpack`
``` ```
+ Look at the docstring in `svhn-digit-dorefa.py` or `alexnet-dorefa.py` to see detailed usage and performance. + Look at the docstring in `*-dorefa.py` to see detailed usage and performance.
## Support ## Support
......
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: resnet-dorefa.py
import tensorflow as tf
import argparse
import numpy as np
import cv2
import os
import sys
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.utils.stats import RatioCounter
from tensorpack.tfutils.varreplace import replace_get_variable
from dorefa import get_dorefa
"""
This script loads the pre-trained ResNet-18 model with (W,A,G) = (1,4,32)
It has 59.2% top-1 and 81.5% top-5 validation error on ILSVRC12 validation set.
To run on images:
./resnet-dorefa.py --load pretrained.npy --run a.jpg b.jpg
To eval on ILSVRC validation set:
./resnet-dorefa.py --load pretrained.npy --eval --data /path/to/ILSVRC
"""
BITW = 1
BITA = 4
BITG = 32
class Model(ModelDesc):
def _get_input_vars(self):
return [InputVar(tf.float32, [None, 224, 224, 3], 'input'),
InputVar(tf.int32, [None], 'label')]
def _build_graph(self, input_vars):
image, label = input_vars
image = image / 256.0
fw, fa, fg = get_dorefa(BITW, BITA, BITG)
old_get_variable = tf.get_variable
def new_get_variable(name, shape=None, **kwargs):
v = old_get_variable(name, shape, **kwargs)
# don't binarize first and last layer
if name != 'W' or 'conv1' in v.op.name or 'fct' in v.op.name:
return v
else:
logger.info("Binarizing weight {}".format(v.op.name))
return fw(v)
def nonlin(x):
return tf.clip_by_value(x, 0.0, 1.0)
def activate(x):
return fa(nonlin(x))
def resblock(x, channel, stride):
def get_stem_full(x):
return (LinearWrap(x)
.Conv2D('c3x3a', channel, 3)
.BatchNorm('stembn')
.apply(activate)
.Conv2D('c3x3b', channel, 3)())
channel_mismatch = channel != x.get_shape().as_list()[3]
if stride != 1 or channel_mismatch or 'pool1' in x.name:
# handling pool1 is to work around an architecture bug in our model
if stride != 1 or 'pool1' in x.name:
x = AvgPooling('pool', x, stride, stride)
x = BatchNorm('bn', x)
x = activate(x)
shortcut = Conv2D('shortcut', x, channel, 1)
stem = get_stem_full(x)
else:
shortcut = x
x = BatchNorm('bn', x)
x = activate(x)
stem = get_stem_full(x)
return shortcut + stem
def group(x, name, channel, nr_block, stride):
with tf.variable_scope(name + 'blk1'):
x = resblock(x, channel, stride)
for i in range(2, nr_block + 1):
with tf.variable_scope(name + 'blk{}'.format(i)):
x = resblock(x, channel, 1)
return x
with replace_get_variable(new_get_variable), \
argscope(BatchNorm, decay=0.9, epsilon=1e-4), \
argscope(Conv2D, use_bias=False, nl=tf.identity):
logits = (LinearWrap(image)
# use explicit padding here, because our training framework has
# different padding mechanisms from TensorFlow
.tf.pad([[0, 0], [3, 2], [3, 2], [0, 0]])
.Conv2D('conv1', 64, 7, stride=2, padding='VALID', use_bias=True)
.tf.pad([[0, 0], [1, 1], [1, 1], [0, 0]], 'SYMMETRIC')
.MaxPooling('pool1', 3, 2, padding='VALID')
.apply(group, 'conv2', 64, 2, 1)
.apply(group, 'conv3', 128, 2, 2)
.apply(group, 'conv4', 256, 2, 2)
.apply(group, 'conv5', 512, 2, 2)
.BatchNorm('lastbn')
.apply(nonlin)
.GlobalAvgPooling('gap')
.tf.mul(49) # this is due to a bug in our model design
.FullyConnected('fct', 1000)())
prob = tf.nn.softmax(logits, name='output')
wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
def get_inference_augmentor():
return imgaug.AugmentorList([
imgaug.ResizeShortestEdge(256),
imgaug.CenterCrop(224),
])
def run_image(model, sess_init, inputs):
pred_config = PredictConfig(
model=model,
session_init=sess_init,
session_config=get_default_sess_config(0.9),
input_names=['input'],
output_names=['output']
)
predict_func = get_predict_func(pred_config)
meta = dataset.ILSVRCMeta()
words = meta.get_synset_words_1000()
transformers = get_inference_augmentor()
for f in inputs:
assert os.path.isfile(f)
img = cv2.imread(f).astype('float32')
assert img is not None
img = transformers.augment(img)[np.newaxis, :, :, :]
o = predict_func([img])
prob = o[0][0]
ret = prob.argsort()[-10:][::-1]
names = [words[i] for i in ret]
print(f + ":")
print(list(zip(names, prob[ret])))
def eval_on_ILSVRC12(model_path, data_dir):
ds = dataset.ILSVRC12(data_dir, 'val', shuffle=False)
ds = AugmentImageComponent(ds, get_inference_augmentor())
ds = BatchData(ds, 192, remainder=True)
pred_config = PredictConfig(
model=Model(),
session_init=get_model_loader(model_path),
input_names=['input', 'label'],
output_names=['wrong-top1', 'wrong-top5']
)
pred = SimpleDatasetPredictor(pred_config, ds)
acc1, acc5 = RatioCounter(), RatioCounter()
for o in pred.get_result():
batch_size = o[0].shape[0]
acc1.feed(o[0].sum(), batch_size)
acc5.feed(o[1].sum(), batch_size)
print("Top1 Error: {}".format(acc1.ratio))
print("Top5 Error: {}".format(acc5.ratio))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='the physical ids of GPUs to use')
parser.add_argument('--load', help='load a npy pretrained model')
parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--dorefa',
help='number of bits for W,A,G, separated by comma. Defaults to \'1,4,32\'',
default='1,4,32')
parser.add_argument(
'--run', help='run on a list of images with the pretrained model', nargs='*')
parser.add_argument('--eval', action='store_true')
args = parser.parse_args()
BITW, BITA, BITG = map(int, args.dorefa.split(','))
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.eval:
eval_on_ILSVRC12(args.load, args.data)
elif args.run:
assert args.load.endswith('.npy')
run_image(Model(), ParamRestore(
np.load(args.load, encoding='latin1').item()), args.run)
...@@ -101,7 +101,7 @@ class ILSVRC12(RNGDataFlow): ...@@ -101,7 +101,7 @@ class ILSVRC12(RNGDataFlow):
Produces ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999], Produces ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999],
and optionally a bounding box of [xmin, ymin, xmax, ymax]. and optionally a bounding box of [xmin, ymin, xmax, ymax].
""" """
def __init__(self, dir, name, meta_dir=None, shuffle=True, def __init__(self, dir, name, meta_dir=None, shuffle=None,
dir_structure='original', include_bb=False): dir_structure='original', include_bb=False):
""" """
Args: Args:
...@@ -109,6 +109,7 @@ class ILSVRC12(RNGDataFlow): ...@@ -109,6 +109,7 @@ class ILSVRC12(RNGDataFlow):
original ``ILSVRC12_img_{name}.tar`` gets decompressed. original ``ILSVRC12_img_{name}.tar`` gets decompressed.
name (str): 'train' or 'val' or 'test'. name (str): 'train' or 'val' or 'test'.
shuffle (bool): shuffle the dataset. shuffle (bool): shuffle the dataset.
Defaults to True if name=='train'.
dir_structure (str): The dir structure of 'val' and 'test' directory. dir_structure (str): The dir structure of 'val' and 'test' directory.
If is 'original', it expects the original decompressed If is 'original', it expects the original decompressed
directory, which only has list of image files (as below). directory, which only has list of image files (as below).
...@@ -149,6 +150,8 @@ class ILSVRC12(RNGDataFlow): ...@@ -149,6 +150,8 @@ class ILSVRC12(RNGDataFlow):
self.full_dir = os.path.join(dir, name) self.full_dir = os.path.join(dir, name)
self.name = name self.name = name
assert os.path.isdir(self.full_dir), self.full_dir assert os.path.isdir(self.full_dir), self.full_dir
if shuffle is None:
shuffle = name == 'train'
self.shuffle = shuffle self.shuffle = shuffle
meta = ILSVRCMeta(meta_dir) meta = ILSVRCMeta(meta_dir)
self.imglist = meta.get_image_list(name) self.imglist = meta.get_image_list(name)
......
...@@ -33,7 +33,9 @@ def MaxPooling(x, shape, stride=None, padding='VALID'): ...@@ -33,7 +33,9 @@ def MaxPooling(x, shape, stride=None, padding='VALID'):
else: else:
stride = shape4d(stride) stride = shape4d(stride)
return tf.nn.max_pool(x, ksize=shape, strides=stride, padding=padding) return tf.nn.max_pool(x, ksize=shape,
strides=stride, padding=padding,
name='output')
@layer_register() @layer_register()
...@@ -54,7 +56,8 @@ def AvgPooling(x, shape, stride=None, padding='VALID'): ...@@ -54,7 +56,8 @@ def AvgPooling(x, shape, stride=None, padding='VALID'):
else: else:
stride = shape4d(stride) stride = shape4d(stride)
return tf.nn.avg_pool(x, ksize=shape, strides=stride, padding=padding) return tf.nn.avg_pool(x, ksize=shape,
strides=stride, padding=padding, name='output')
@layer_register() @layer_register()
...@@ -69,7 +72,7 @@ def GlobalAvgPooling(x): ...@@ -69,7 +72,7 @@ def GlobalAvgPooling(x):
tf.Tensor: a NC tensor. tf.Tensor: a NC tensor.
""" """
assert x.get_shape().ndims == 4 assert x.get_shape().ndims == 4
return tf.reduce_mean(x, [1, 2]) return tf.reduce_mean(x, [1, 2], name='output')
def UnPooling2x2ZeroFilled(x): def UnPooling2x2ZeroFilled(x):
......
...@@ -28,6 +28,12 @@ def freeze_get_variable(): ...@@ -28,6 +28,12 @@ def freeze_get_variable():
""" """
Return a contextmanager, where all variables returned by Return a contextmanager, where all variables returned by
`get_variable` will have no gradients. `get_variable` will have no gradients.
Example:
.. code-block:: python
with varreplace.freeze_get_variable():
x = FullyConnected('fc', x, 1000) # fc/* will not be trained
""" """
old_get_variable = tf.get_variable old_get_variable = tf.get_variable
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment