Commit dc782068 authored by Yuxin Wu's avatar Yuxin Wu

update docs

parent 67786cbb
......@@ -56,12 +56,9 @@ To eval on ILSVRC12, `path/to/ILSVRC12` must have a subdirectory named 'val' con
Please use [github issues](https://github.com/ppwwyyxx/tensorpack/issues) for any issues related to the code.
Send email to the authors for other questions related to the paper.
Note that although the it uses low bitwidth weights, activations and gradients, these values
here are still represented in `tf.float32`, since TensorFlow doesn't natively support low bitwidth computation.
## Citation
If you use our models in your research, please cite:
If you use our code or models in your research, please cite:
```
@article{zhou2016dorefa,
author = {Shuchang Zhou and Zekun Ni and Xinyu Zhou and He Wen and Yuxin Wu and Yuheng Zou},
......
......@@ -101,7 +101,9 @@ class ILSVRCMeta(object):
class ILSVRC12(DataFlow):
def __init__(self, dir, name, meta_dir=None, shuffle=True):
"""
name: 'train' or 'val' or 'test'
:param name: 'train' or 'val' or 'test'
:param dir: A directory containing a subdir named `name`, inside which the
original ILSVRC12_`name`.tar gets decompressed.
"""
assert name in ['train', 'test', 'val']
self.full_dir = os.path.join(dir, name)
......@@ -136,7 +138,9 @@ class ILSVRC12(DataFlow):
im = np.expand_dims(im, 2).repeat(3,2)
yield [im, tp[1]]
# TODO more generally, just CaffeLMDB
class ILSVRC12CaffeLMDB(DataFlow):
""" Read a Caffe LMDB file where each value contains a caffe.Datum protobuf """
def __init__(self, lmdb_dir, shuffle=True):
"""
:param shuffle: about 3 times slower
......@@ -150,7 +154,7 @@ class ILSVRC12CaffeLMDB(DataFlow):
self._txn = self._lmdb.begin()
self._size = self._txn.stat()['entries']
if shuffle:
with timed_operation("Loading LMDB keys ..."):
with timed_operation("Loading LMDB keys ...", log_start=True):
self.keys = [k for k, _ in self._txn.cursor()]
def reset_state(self):
......
......@@ -25,7 +25,9 @@ def download(url, dir):
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' %
(fname, float(count * block_size) / float(total_size) * 100.0))
(fname,
min(float(count * block_size)/ total_size,
1.0) * 100.0))
sys.stdout.flush()
try:
fpath, _ = urllib.request.urlretrieve(url, fpath, reporthook=_progress)
......
......@@ -18,7 +18,7 @@ __all__ = ['total_timer', 'timed_operation', 'print_total_timer']
@contextmanager
def timed_operation(msg, log_start=False):
if log_start:
logger.info('start {} ...'.format(msg))
logger.info('Start {} ...'.format(msg))
start = time.time()
yield
logger.info('{} finished, time={:.2f}sec.'.format(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment