Commit b75ed18c authored by Yuxin Wu's avatar Yuxin Wu

auto-download svhn. update resnet perf numbers.

parent fb43cf03
......@@ -16,8 +16,8 @@ To train, just run:
```bash
./imagenet-resnet.py --data /path/to/original/ILSVRC --gpu 0,1,2,3 -d 18
```
The speed is 1860 samples/s on 4 TitanX Pascal, and 1160 samples/s on 4 old TitanX, if your data is fast
enough. See the [tutorial](http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html) on how to speed up your data.
The speed is 1310 image/s on 4 Tesla M40, if your data is fast enough.
See the [tutorial](http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html) on how to speed up your data.
![imagenet](imagenet-resnet.png)
......
......@@ -7,7 +7,7 @@ import os
import numpy as np
from ...utils import logger
from ...utils.fs import get_dataset_path
from ...utils.fs import get_dataset_path, download
from ..base import RNGDataFlow
__all__ = ['SVHNDigit']
......@@ -38,8 +38,10 @@ class SVHNDigit(RNGDataFlow):
data_dir = get_dataset_path('svhn_data')
assert name in ['train', 'test', 'extra'], name
filename = os.path.join(data_dir, name + '_32x32.mat')
assert os.path.isfile(filename), \
"File {} not found! Please download it from {}.".format(filename, SVHN_URL)
if not os.path.isfile(filename):
url = SVHN_URL + os.path.basename(filename)
logger.info("File {} not found! Downloading from {}.".format(filename, url))
download(url, os.path.dirname(filename))
logger.info("Loading {} ...".format(filename))
data = scipy.io.loadmat(filename)
self.X = data['X'].transpose(3, 0, 1, 2)
......
......@@ -81,6 +81,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
if len(config.tower) > 1:
assert tf.test.is_gpu_available()
# doens't seem to improve on single GPU
if not isinstance(self._input_method, StagingInputWrapper):
devices = ['/gpu:{}'.format(k) for k in config.tower]
self._input_method = StagingInputWrapper(self._input_method, devices)
......
......@@ -58,7 +58,7 @@ def download(url, dir, filename=None):
raise
assert size > 0, "Download an empty file!"
# TODO human-readable size
print('Succesfully downloaded ' + filename + " " + str(size) + ' bytes.')
print('Succesfully downloaded ' + filename + ". " + str(size) + ' bytes.')
return fpath
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment