Commit 60e52b94 authored by Yuxin Wu

download / ilsvrcmeta

parent d979ab78
...@@ -18,6 +18,7 @@ from tensorpack.tfutils.symbolic_functions import * ...@@ -18,6 +18,7 @@ from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.callbacks import * from tensorpack.callbacks import *
from tensorpack.dataflow import * from tensorpack.dataflow import *
from tensorpack.dataflow.dataset import ILSVRCMeta
BATCH_SIZE = 10 BATCH_SIZE = 10
MIN_AFTER_DEQUEUE = 500 MIN_AFTER_DEQUEUE = 500
...@@ -132,12 +133,15 @@ def run_test(path, input): ...@@ -132,12 +133,15 @@ def run_test(path, input):
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im = cv2.resize(im, (227, 227)) im = cv2.resize(im, (227, 227))
im = np.reshape(im, (1, 227, 227, 3)).astype('float32') im = np.reshape(im, (1, 227, 227, 3)).astype('float32')
im = im - 110
outputs = predict_func([im])[0] outputs = predict_func([im])[0]
prob = outputs[0] prob = outputs[0]
print prob.shape
ret = prob.argsort()[-10:][::-1] ret = prob.argsort()[-10:][::-1]
print ret print ret
meta = ILSVRCMeta().get_synset_words_1000()
print [meta[k] for k in ret]
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
......
...@@ -18,6 +18,7 @@ from tensorpack.tfutils.symbolic_functions import * ...@@ -18,6 +18,7 @@ from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.callbacks import * from tensorpack.callbacks import *
from tensorpack.dataflow import * from tensorpack.dataflow import *
from tensorpack.dataflow.dataset import ILSVRCMeta
class Model(ModelDesc): class Model(ModelDesc):
def _get_input_vars(self): def _get_input_vars(self):
...@@ -104,12 +105,15 @@ def run_test(path, input): ...@@ -104,12 +105,15 @@ def run_test(path, input):
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im = cv2.resize(im, (224, 224)) im = cv2.resize(im, (224, 224))
im = np.reshape(im, (1, 224, 224, 3)).astype('float32') im = np.reshape(im, (1, 224, 224, 3)).astype('float32')
im = im - 110
outputs = predict_func([im])[0] outputs = predict_func([im])[0]
prob = outputs[0] prob = outputs[0]
print prob.shape
ret = prob.argsort()[-10:][::-1] ret = prob.argsort()[-10:][::-1]
print ret print ret
meta = ILSVRCMeta().get_synset_words_1000()
print [meta[k] for k in ret]
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', default='0', parser.add_argument('--gpu', default='0',
......
mnist_data mnist_data
cifar10_data cifar10_data
svhn_data svhn_data
ilsvrc_metadata
...@@ -13,6 +13,7 @@ import tarfile ...@@ -13,6 +13,7 @@ import tarfile
import logging import logging
from ...utils import logger, get_rng from ...utils import logger, get_rng
from ...utils.fs import download
from ..base import DataFlow from ..base import DataFlow
__all__ = ['Cifar10'] __all__ = ['Cifar10']
...@@ -23,22 +24,13 @@ DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' ...@@ -23,22 +24,13 @@ DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
def maybe_download_and_extract(dest_directory):
    """Download and extract the CIFAR-10 tarball unless it is already extracted.

    Adapted from the TensorFlow example. If `dest_directory` already contains
    a 'cifar-10-batches-py' directory, nothing is done. Otherwise the archive
    at DATA_URL is downloaded into `dest_directory` and unpacked there.

    Args:
        dest_directory: directory that holds (or will hold) the dataset.
    """
    if os.path.isdir(os.path.join(dest_directory, 'cifar-10-batches-py')):
        logger.info("Found cifar10 data in {}.".format(dest_directory))
        return

    # The URL constant of this module is DATA_URL (it is also used for the
    # file name below); the bare name `URL` would raise NameError.
    download(DATA_URL, dest_directory)
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    # Close the archive handle once extraction has finished.
    with tarfile.open(filepath, 'r:gz') as tar:
        tar.extractall(dest_directory)
def read_cifar10(filenames): def read_cifar10(filenames):
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: ilsvrc.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
import tarfile
from ...utils.fs import mkdir_p, download
__all__ = ['ILSVRCMeta']
CAFFE_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
class ILSVRCMeta(object):
    """Access to the ILSVRC metadata files distributed in the Caffe tarball.

    Files are kept in `self.dir` and fetched from CAFFE_URL when missing.
    """

    def __init__(self, dir=None):
        """
        Args:
            dir: directory holding the metadata files. Defaults to an
                'ilsvrc_metadata' directory next to this module. The
                directory is created if it does not exist. (The parameter
                name shadows the builtin but is kept for compatibility.)
        """
        if dir is None:
            dir = os.path.join(os.path.dirname(__file__), 'ilsvrc_metadata')
        self.dir = dir
        mkdir_p(self.dir)

    def get_synset_words_1000(self):
        """Return a dict mapping line index to the stripped line of
        'synset_words.txt'.

        The metadata is downloaded first when the file is not in `self.dir`.

        Raises:
            AssertionError: if the file is still missing after the download.
        """
        fname = os.path.join(self.dir, 'synset_words.txt')
        if not os.path.isfile(fname):
            self.download_caffe_meta()
        assert os.path.isfile(fname), fname
        # Context manager so the file handle is closed deterministically.
        with open(fname) as f:
            lines = [x.strip() for x in f]
        return dict(enumerate(lines))

    def download_caffe_meta(self):
        """Download the tarball at CAFFE_URL into `self.dir` and extract it there."""
        fpath = download(CAFFE_URL, self.dir)
        # NOTE(review): extractall trusts the member paths inside the archive;
        # acceptable only as long as CAFFE_URL is a trusted source.
        with tarfile.open(fpath, 'r:gz') as tar:
            tar.extractall(self.dir)
if __name__ == '__main__':
    # Manual check: load the synset table (downloading it if needed) and print it.
    meta = ILSVRCMeta()
    # A parenthesized single-argument print behaves the same on Python 2 and 3.
    print(meta.get_synset_words_1000())
...@@ -10,6 +10,7 @@ import numpy ...@@ -10,6 +10,7 @@ import numpy
from six.moves import urllib, range from six.moves import urllib, range
from ...utils import logger from ...utils import logger
from ...utils.fs import download
from ..base import DataFlow from ..base import DataFlow
__all__ = ['Mnist'] __all__ = ['Mnist']
...@@ -20,14 +21,10 @@ SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' ...@@ -20,14 +21,10 @@ SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def maybe_download(filename, work_directory):
    """Return the local path of `filename` inside `work_directory`,
    fetching it from SOURCE_URL first when it is not there yet."""
    target = os.path.join(work_directory, filename)
    if os.path.exists(target):
        # Already present: nothing to fetch.
        return target
    logger.info("Downloading mnist data to {}...".format(target))
    download(SOURCE_URL + filename, work_directory)
    return target
def _read32(bytestream): def _read32(bytestream):
......
...@@ -31,7 +31,7 @@ class SVHNDigit(DataFlow): ...@@ -31,7 +31,7 @@ class SVHNDigit(DataFlow):
assert name in ['train', 'test', 'extra'], name assert name in ['train', 'test', 'extra'], name
filename = os.path.join(data_dir, name + '_32x32.mat') filename = os.path.join(data_dir, name + '_32x32.mat')
assert os.path.isfile(filename), \ assert os.path.isfile(filename), \
"File {} not found! Download it from \ "File {} not found! Please download it from \
http://ufldl.stanford.edu/housenumbers/".format(filename) http://ufldl.stanford.edu/housenumbers/".format(filename)
logger.info("Loading {} ...".format(filename)) logger.info("Loading {} ...".format(filename))
data = scipy.io.loadmat(filename) data = scipy.io.loadmat(filename)
......
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
# File: fs.py # File: fs.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os import os, sys
from six.moves import urllib
def mkdir_p(dirname): def mkdir_p(dirname):
assert dirname is not None assert dirname is not None
...@@ -15,3 +16,21 @@ def mkdir_p(dirname): ...@@ -15,3 +16,21 @@ def mkdir_p(dirname):
if e.errno != 17: if e.errno != 17:
raise e raise e
def download(url, dir):
    """Download `url` into directory `dir` and return the local file path.

    The directory is created with mkdir_p first. The local file name is the
    last '/'-separated component of the URL. Progress is written to stdout
    while the transfer runs.

    Args:
        url: address of the file to fetch.
        dir: destination directory (the name shadows the builtin but is kept
            for backward compatibility).

    Returns:
        The path of the downloaded file, as returned by urlretrieve.
    """
    mkdir_p(dir)
    fname = url.split('/')[-1]
    fpath = os.path.join(dir, fname)

    def _progress(count, block_size, total_size):
        done = count * block_size
        if total_size > 0:
            # The final block usually overshoots the real size; cap at 100%.
            percent = min(float(done) / float(total_size) * 100.0, 100.0)
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (fname, percent))
        else:
            # The reporthook receives a non-positive total size when the
            # server does not announce one; show bytes instead of dividing.
            sys.stdout.write('\r>> Downloading %s %d bytes' % (fname, done))
        sys.stdout.flush()

    fpath, _ = urllib.request.urlretrieve(url, fpath, reporthook=_progress)
    statinfo = os.stat(fpath)
    sys.stdout.write('\n')
    print('Succesfully downloaded ' + fname + " " + str(statinfo.st_size) + ' bytes.')
    return fpath
if __name__ == '__main__':
    # Manual smoke test: fetch the Caffe ILSVRC12 archive into the current directory.
    demo_url = 'http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz'
    download(demo_url, '.')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment