Commit 62332e0f authored by Yuxin Wu's avatar Yuxin Wu

update readme

parent 76e216eb
......@@ -5,6 +5,7 @@ In development but usable. API might change a bit.
See some interesting [examples](https://github.com/ppwwyyxx/tensorpack/tree/master/examples) to learn about the framework:
+ [DoReFa-Net: low bitwidth CNN](https://github.com/ppwwyyxx/tensorpack/tree/master/examples/DoReFa-Net)
+ [Double-DQN for playing Atari games](https://github.com/ppwwyyxx/tensorpack/tree/master/examples/Atari2600)
+ [ResNet for Cifar10 classification](https://github.com/ppwwyyxx/tensorpack/tree/master/examples/ResNet)
+ [char-rnn language model](https://github.com/ppwwyyxx/tensorpack/tree/master/examples/char-rnn)
......
......@@ -2,9 +2,7 @@ This is the official script to load and run pretrained model for the paper:
[DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients](http://arxiv.org/abs/1606.06160), by Zhou et al.
(Work in Progress. More instructions to come soon)
This is an AlexNet model with 1 bit weights, 2 bit activations, trained with 4 bit gradients.
The provided model is an AlexNet with 1 bit weights, 2 bit activations, trained with 4 bit gradients.
## Preparation:
......@@ -12,25 +10,26 @@ To use the script. You'll need:
+ [TensorFlow](tensorflow.org) >= 0.8
+ [tensorpack](https://github.com/ppwwyyxx/tensorpack) and pyzmq:
+ [tensorpack](https://github.com/ppwwyyxx/tensorpack):
```
git clone https://github.com/ppwwyyxx/tensorpack
pip install --user -r tensorpack/requirements.txt
pip install --user pyzmq
export PYTHONPATH=$PYTHONPATH:`readlink -f tensorpack`
```
+ Download the model at [google drive](https://drive.google.com/drive/u/2/folders/0B308TeQzmFDLa0xOeVQwcXg1ZjQ)
## Load and run the model
We provide two format for the model:
We publish the model in two file formats:
1. alexnet.npy. It's simply a numpy dict of {param name: value}. To load:
1. `alexnet.npy`. It's simply a numpy dict of {param name: value}. Use it with:
```
./alexnet.py --load alexnet.npy [--input img.jpg] [--data path/to/data]
```
2. alexnet.meta + alexnet.tfmodel. A TensorFlow MetaGraph proto and a saved checkpoint.
2. `alexnet.meta` + `alexnet.tfmodel`. A TensorFlow MetaGraph proto and a saved checkpoint.
```
./alexnet.py --graph alexnet.meta --load alexnet.tfmodel [--input path/to/img.jpg] [--data path/to/ILSVRC12]
......@@ -38,3 +37,8 @@ We provide two format for the model:
One of `--data` or `--input` must be present, to either run classification on some input images, or run evaluation on ILSVRC12 validation set.
To eval on ILSVRC12, `path/to/ILSVRC12` must have a subdirectory named 'val' containing all the validation images.
## Support
Please use [github issues](https://github.com/ppwwyyxx/tensorpack/issues) for any issues related to the code.
Send email to the authors for other questions related to the paper.
......@@ -8,15 +8,6 @@ import argparse
import numpy as np
import os
"""
Run the pretrained model of paper:
DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
http://arxiv.org/abs/1606.06160
Model can be downloaded at:
https://drive.google.com/drive/u/2/folders/0B308TeQzmFDLa0xOeVQwcXg1ZjQ
"""
from tensorpack import *
from tensorpack.utils.stat import RatioCounter
from tensorpack.tfutils.symbolic_functions import prediction_incorrect
......@@ -76,7 +67,7 @@ def eval_on_ILSVRC12(model, sess_init, data_dir):
]
ds = AugmentImageComponent(ds, transformers)
ds = BatchData(ds, 128, remainder=True)
ds = PrefetchDataZMQ(ds, 10) # TODO use PrefetchData as fallback
ds = PrefetchData(ds, 10, nr_proc=1)
cfg = PredictConfig(
model=model,
......@@ -155,6 +146,7 @@ op/variable names")
raise RuntimeError("Unsupported model type!")
if args.data:
assert os.path.isdir(os.path.join(args.data, 'val'))
eval_on_ILSVRC12(M, sess_init, args.data)
elif args.input:
run_test(M, sess_init, args.input)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment