Commit d158a073 authored by Yuxin Wu

update docs

parent d1a90991
@@ -7,33 +7,36 @@ It also contains an implementation of the following papers:

+ [Trained Ternary Quantization](https://arxiv.org/abs/1612.01064), with (W,A,G)=(t,32,32).
+ [Binarized Neural Networks](https://arxiv.org/abs/1602.02830), with (W,A,G)=(1,1,32).
Alternative link to this page: [http://dorefa.net](http://dorefa.net)
## Results:
This is a good set of baselines for research in model quantization.
These quantization techniques, when applied to AlexNet, achieve the following ImageNet performance in this implementation:
| Model                              | Bit Width <br/> (weights, activations, gradients) | Top 1 Validation Error <sup>[1](#ft1)</sup> |
|:----------------------------------:|:--------------------------------------------------:|:-------------------------------------------:|
| Full Precision<sup>[2](#ft2)</sup> | 32,32,32                                            | 40.3%                                        |
| TTQ                                | t,32,32                                             | 42.0%                                        |
| BWN                                | 1,32,32                                             | 44.6%                                        |
| BNN                                | 1,1,32                                              | 51.9%                                        |
| DoReFa                             | 1,2,32                                              | 46.6%                                        |
| DoReFa                             | 1,2,6                                               | 46.8%                                        |
| DoReFa                             | 1,2,4                                               | 54.0%                                        |
<a id="ft1">1</a>: These numbers were obtained by training on 8 GPUs with a total batch size of 256.
The DoReFa-Net models reach slightly better performance than in our paper, due to more sophisticated augmentations.

<a id="ft2">2</a>: Not directly comparable with the original AlexNet. Check out
[../ImageNetModels](../ImageNetModels) for a more faithful implementation of the original AlexNet.
## Speed:

__DoReFa-Net works on mobile and FPGA!__
We hosted a demo at CVPR16 on behalf of Megvii, Inc., running a 1/4-VGG size DoReFa-Net on a phone and a half-VGG size DoReFa-Net on an FPGA, in real time.
DoReFa-Net and its variants have been deployed widely in Megvii's embedded products.

This code release is meant for research purposes. We're not planning to release our C++ runtime for bit-operations.
In this implementation, quantized operations are all performed through `tf.float32`.
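Conceptually, each quantizer follows the scheme in the paper: round to a k-bit uniform grid in the forward pass, and let gradients pass through the rounding unchanged (the straight-through estimator). A minimal sketch of that idea, assuming inputs normalized to [0, 1] (illustrative only, not the repo's exact `dorefa.py`):

```python
import tensorflow as tf

def quantize(x, k):
    # Round x (assumed in [0, 1]) to k bits. tf.stop_gradient hides the
    # rounding from backprop, so the gradient of this op w.r.t. x is 1
    # (the straight-through estimator).
    n = float(2 ** k - 1)
    return x + tf.stop_gradient(tf.round(x * n) / n - x)

def quantize_weights(w, k):
    # DoReFa-style weight quantization: squash with tanh, normalize
    # into [0, 1], quantize, then map back to [-1, 1].
    w = tf.tanh(w)
    w = w / tf.reduce_max(tf.abs(w)) * 0.5 + 0.5
    return 2 * quantize(w, k) - 1
```

Note that the outputs are still `tf.float32` tensors; quantization only restricts their values to a small discrete set, which is why no special runtime is needed for training.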
## Use

@@ -43,6 +46,12 @@ Alternative link to this page: [http://dorefa.net](http://dorefa.net)

+ Look at the docstring in `*-dorefa.py` to see detailed usage and performance.
Pretrained models for (1,4,32)-ResNet18 and (1,2,6)-AlexNet are available at
[tensorpack model zoo](http://models.tensorpack.com/DoReFa-Net/).
They are provided in the format of a numpy dictionary.
The __binary-weight 4-bit-activation ResNet-18__ model has 59.2% top-1 validation accuracy.
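Because the checkpoints are plain numpy dictionaries, they can be inspected without constructing the graph. A quick, hypothetical look at the AlexNet checkpoint (assuming `alexnet-126.npz` has been downloaded from the model zoo):

```python
import numpy as np

# An .npz archive behaves like a read-only dict of arrays.
params = np.load("alexnet-126.npz")
for name in sorted(params.files):
    print(name, params[name].shape, params[name].dtype)
```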
## Support

Please use [github issues](https://github.com/tensorpack/tensorpack/issues) for any issues related to the code itself.
@@ -46,6 +46,7 @@ To Train:

Fast disk random access (Not necessarily SSD. I used a RAID of HDD, but not sure if plain HDD is enough)
More than 20 CPU cores (for data processing)
More than 10G of free memory
On 8 P100s with `--dorefa 1,2,6`, training should take about 30 minutes per epoch.
To run the pretrained model:

./alexnet-dorefa.py --load alexnet-126.npz --run a.jpg --dorefa 1,2,6
@@ -4,12 +4,12 @@

# Author: Yuxin Wu

import argparse
from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary, add_param_summary
from tensorpack.dataflow import dataset
from tensorpack.tfutils.varreplace import remap_variables
import tensorflow as tf

from dorefa import get_dorefa
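These imports suggest the wiring: `get_dorefa(bitW, bitA, bitG)` returns the three quantizers, and `remap_variables` intercepts variable reads during graph construction so weights can be quantized transparently. A rough sketch of the pattern (the exclusion rule and `build_network` are illustrative, not the script's exact code):

```python
BITW, BITA, BITG = 1, 2, 6
fw, fa, fg = get_dorefa(BITW, BITA, BITG)

def quantized_getter(v):
    # Quantize conv/FC weight matrices only; keep biases, BatchNorm
    # parameters, and the first layer at full precision.
    name = v.op.name
    if not name.endswith('W') or 'conv0' in name:
        return v
    return fw(v)

# During graph construction, every variable read goes through the getter:
with remap_variables(quantized_getter):
    logits = build_network(image)  # hypothetical model-building call
```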
@@ -19,7 +19,7 @@ DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients

http://arxiv.org/abs/1606.06160

The original experiments were performed on a proprietary framework.
This is our attempt to reproduce it on tensorpack.

Accuracy:
With (W,A,G)=(1,1,4), can reach 3.1~3.2% error after 150 epochs.
@@ -93,6 +93,7 @@ def crop_and_resize(image, boxes, box_ind, crop_size, pad_border=True):

        n,C,size,size
    """
    assert isinstance(crop_size, int), crop_size
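    # The box coordinates are treated as constants: stop_gradient keeps
    # backprop from flowing into the boxes through crop_and_resize.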
    boxes = tf.stop_gradient(boxes)
    # TF's crop_and_resize produces zeros on border
    if pad_border:
@@ -162,7 +163,6 @@ def roi_align(featuremap, boxes, resolution):

    Returns:
        NxCx res x res
    """
    # sample 4 locations per roi bin
    ret = crop_and_resize(
        featuremap, boxes,
[flake8]
max-line-length = 120
# See https://pep8.readthedocs.io/en/latest/intro.html#error-codes
ignore = F403,F405,E402,E741,E742,E743
exclude = private,
          FasterRCNN/utils
@@ -4,7 +4,7 @@ from os import path

import platform

version = int(setuptools.__version__.split('.')[0])
assert version > 30, "Tensorpack installation requires setuptools > 30"
this_directory = path.abspath(path.dirname(__file__))
@@ -73,7 +73,8 @@ class ModelSaver(Callback):
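# clear_extraneous_savers drops Saver nodes unrelated to this saver from
# the exported meta graph, which keeps the graph-{}.meta file smaller.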
self.saver.export_meta_graph(
    os.path.join(self.checkpoint_dir,
                 'graph-{}.meta'.format(time)),
    collection_list=self.graph.get_all_collection_keys(),
    clear_extraneous_savers=True)
def _trigger(self):
    try:
[flake8]
max-line-length = 120
# See https://pep8.readthedocs.io/en/latest/intro.html#error-codes
ignore = E265,E741,E742,E743
exclude = .git,
          __init__.py,
          setup.py,
          tensorpack/train/eager.py,
          docs,
          examples,
          docs/conf.py,
          snippet,
          examples-old,
          _test.py,