Commit 88ed2c24 authored by Yuxin Wu

update dorefa to use floor instead of round

parent 70e14a6b
@@ -3,7 +3,8 @@ Code and model for the paper:
[DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients](http://arxiv.org/abs/1606.06160), by Zhou et al.
We hosted a demo at CVPR16 on behalf of Megvii, Inc, running a real-time 1/4-VGG size DoReFa-Net on ARM and half-VGG size DoReFa-Net on FPGA.
-We're not planning to release those runtime bit-op libraries for now. In this repo, bit operations are run in float32.
+We're not planning to release our C++ runtime for bit-operations.
+In this repo, bit operations are performed through `tf.float32`.
Pretrained model for 1-2-6-AlexNet is available at
[google drive](https://drive.google.com/a/megvii.com/folderview?id=0B308TeQzmFDLa0xOeVQwcXg1ZjQ).
@@ -13,7 +14,7 @@ It's provided in the format of numpy dictionary, so it should be very easy to po
To use the script, you'll need:
-+ TensorFlow 0.10, 0.11rc1, 0.11rc2. 0.11 is not supported due to [TF bug](https://github.com/tensorflow/tensorflow/issues/5888)
++ TensorFlow >= 0.10
 + OpenCV bindings for Python
@@ -17,7 +17,7 @@ def get_dorefa(bitW, bitA, bitG):
def quantize(x, k):
    n = float(2**k-1)
    with G.gradient_override_map({"Floor": "Identity"}):
-        return tf.round(x * n) / n
+        return tf.floor(x * n + 0.5) / n

def fw(x):
    if bitW == 32:
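This one-line change is the substance of the commit: `gradient_override_map` matches ops by type name, so the `{"Floor": "Identity"}` entry presumably never applied to the old `tf.round` (which emits a "Round" op), while `tf.floor(x * n + 0.5)` computes the same round-to-nearest value through a "Floor" op that does receive the straight-through identity gradient. A minimal sketch of the resulting behavior, written against the TF 1.x graph API rather than the repo's exact code:

```python
import tensorflow as tf

def quantize(x, k):
    """Uniformly quantize x in [0, 1] to k bits."""
    n = float(2 ** k - 1)
    G = tf.get_default_graph()
    # The override is keyed on the op type: tf.floor emits a "Floor" op,
    # so its flat gradient is swapped for Identity (the straight-through
    # estimator). A "Round" op would slip past this map untouched.
    with G.gradient_override_map({"Floor": "Identity"}):
        return tf.floor(x * n + 0.5) / n

x = tf.constant([0.1, 0.4, 0.9])
y = quantize(x, 2)             # snaps values onto {0, 1/3, 2/3, 1}
dy_dx = tf.gradients(y, x)[0]  # [1. 1. 1.]: gradients pass straight through

with tf.Session() as sess:
    print(sess.run([y, dy_dx]))
```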
@@ -30,7 +30,7 @@ Accuracy:
With (W,A,G)=(32,32,32), error is about 2.9%.
Speed:
-About 18 iteration/s on 1 Tesla M40. (4721 iterations / epoch)
+30~35 iteration/s on 1 TitanX Pascal. (4721 iterations / epoch)
To Run:
./svhn-digit-dorefa.py --dorefa 1,2,4
@@ -45,8 +45,9 @@ class Model(ModelDesc):
        return [InputVar(tf.float32, [None, 40, 40, 3], 'input'),
                InputVar(tf.int32, [None], 'label') ]

-    def _build_graph(self, input_vars, is_training):
+    def _build_graph(self, input_vars):
        image, label = input_vars
+        is_training = get_current_tower_context().is_training
        fw, fa, fg = get_dorefa(BITW, BITA, BITG)
        # monkey-patch tf.get_variable to apply fw
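The hunk above is the companion API migration: `is_training` is no longer threaded through `_build_graph` as a parameter but read from the tower context the graph is built under. A hedged sketch of the pattern (method body only; `get_current_tower_context` is the real tensorpack helper, the dropout line is illustrative and not from this file):

```python
import tensorflow as tf
from tensorpack.tfutils.tower import get_current_tower_context

def _build_graph(self, input_vars):   # note: no is_training parameter
    image, label = input_vars
    # Each replica is built inside a TowerContext set up by the trainer
    # or predictor, so the training flag is ambient rather than threaded
    # through every method signature.
    is_training = get_current_tower_context().is_training
    if is_training:
        image = tf.nn.dropout(image, keep_prob=0.5)  # illustrative only
```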
@@ -15,3 +15,6 @@ from tensorpack.predict import *
if int(numpy.__version__.split('.')[1]) < 9:
    logger.warn("Numpy < 1.9 could be extremely slow on some tasks.")

+if get_tf_version() < 10:
+    logger.error("tensorpack requires TensorFlow >= 0.10")
@@ -8,6 +8,7 @@ import multiprocessing as mp
import six
from six.moves import range, map
+from .base import DataFlow
from ..utils import get_tqdm, logger
from ..utils.concurrency import DIE
from ..utils.serialize import dumps
@@ -43,6 +44,7 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
        cv2.imwrite(os.path.join(dirname, "{}.jpg".format(i)), img)

def dump_dataflow_to_lmdb(ds, lmdb_path):
+    assert isinstance(ds, DataFlow), type(ds)
    isdir = os.path.isdir(lmdb_path)
    if isdir:
        assert not os.path.isfile(os.path.join(lmdb_path, 'data.mdb')), "LMDB file exists!"
@@ -52,16 +54,20 @@ def dump_dataflow_to_lmdb(ds, lmdb_path):
    db = lmdb.open(lmdb_path, subdir=isdir,
                   map_size=1099511627776 * 2, readonly=False,
                   meminit=False, map_async=True)    # need sync() at the end
-    with get_tqdm(total=ds.size()) as pbar:
+    try:
+        sz = ds.size()
+    except NotImplementedError:
+        sz = 0
+    with get_tqdm(total=sz) as pbar:
        with db.begin(write=True) as txn:
            for idx, dp in enumerate(ds.get_data()):
                txn.put(six.binary_type(idx), dumps(dp))
                pbar.update()
            keys = list(map(six.binary_type, range(idx + 1)))
            txn.put('__keys__', dumps(keys))
-            logger.info("Flushing database ...")
-            db.sync()
+    logger.info("Flushing database ...")
+    db.sync()
    db.close()

def dataflow_to_process_queue(ds, size, nr_consumer):
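Two details worth noting in this hunk: the `try/except NotImplementedError` lets dataflows that cannot report a size still be dumped (tqdm simply gets a zero total), and the key list is written under the sentinel `__keys__` record so a reader can enumerate entries without a full scan. A usage sketch; the dataset and output path are illustrative, and the module path may differ across tensorpack versions:

```python
from tensorpack.dataflow import dataset
from tensorpack.dataflow.dftools import dump_dataflow_to_lmdb  # path varies by version

# Serialize each datapoint once, keyed by its integer index; the index
# list itself is stored under the sentinel key '__keys__'.
ds = dataset.Mnist('train')
dump_dataflow_to_lmdb(ds, '/tmp/mnist-train.lmdb')
```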
@@ -9,6 +9,7 @@ from tensorflow.python.training import moving_averages
from copy import copy
import re
+from ..tfutils.common import get_tf_version
from ..tfutils.tower import get_current_tower_context
from ..utils import logger
from ._common import layer_register
@@ -177,4 +178,8 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
    else:
        return tf.identity(xn, name='output')

-BatchNorm = BatchNormV2
+if get_tf_version() >= 11:
+    BatchNorm = BatchNormV2
+else:
+    logger.warn("BatchNorm might be faster if you update TensorFlow")
+    BatchNorm = BatchNormV1
@@ -19,7 +19,8 @@ __all__ = ['get_default_sess_config',
    'backup_collection',
    'restore_collection',
    'clear_collection',
-    'freeze_collection']
+    'freeze_collection',
+    'get_tf_version']
def get_default_sess_config(mem_fraction=0.99):
"""
@@ -104,3 +105,6 @@ def freeze_collection(keys):
    backup = backup_collection(keys)
    yield
    restore_collection(backup)

+def get_tf_version():
+    return int(tf.__version__.split('.')[1])
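`get_tf_version` keeps only the minor version component, which is all the `>= 10` / `>= 11` checks above need while TensorFlow stays at major version 0. Worked out on typical version strings:

```python
# int(tf.__version__.split('.')[1]) on 0.x-era strings:
#   "0.10.0"    -> ['0', '10', '0'][1]    -> 10
#   "0.11.0rc1" -> ['0', '11', '0rc1'][1] -> 11
#   "0.12.1"    -> 12
# Caveat: a future "1.0.0" would yield 0 and fail the >= 10 checks,
# so this parse only holds while the major version is 0.
```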
@@ -8,7 +8,7 @@ from six.moves import urllib
import errno
from . import logger
-__all__ = ['mkdir_p', 'download']
+__all__ = ['mkdir_p', 'download', 'recursive_walk']
def mkdir_p(dirname):
""" make a dir recursively, but do nothing if the dir exists"""
@@ -21,7 +21,6 @@ def mkdir_p(dirname):
        if e.errno != errno.EEXIST:
            raise e

def download(url, dir):
-    mkdir_p(dir)
    fname = url.split('/')[-1]
@@ -46,5 +45,10 @@ def download(url, dir):
    print('Successfully downloaded ' + fname + " " + str(size) + ' bytes.')
    return fpath

+def recursive_walk(rootdir):
+    for r, dirs, files in os.walk(rootdir):
+        for f in files:
+            yield os.path.join(r, f)
if __name__ == '__main__':
    download('http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz', '.')
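A quick sketch of the new `recursive_walk` helper in use (the directory and extension are illustrative):

```python
from tensorpack.utils.fs import recursive_walk

# Lazily yield every file path under a tree, then filter by extension.
jpgs = (p for p in recursive_walk('/data/images') if p.endswith('.jpg'))
for path in jpgs:
    print(path)
```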