Commit 60e52b94 authored by Yuxin Wu's avatar Yuxin Wu

download / ilsvrcmeta

parent d979ab78
......@@ -18,6 +18,7 @@ from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.callbacks import *
from tensorpack.dataflow import *
from tensorpack.dataflow.dataset import ILSVRCMeta
BATCH_SIZE = 10
MIN_AFTER_DEQUEUE = 500
......@@ -132,12 +133,15 @@ def run_test(path, input):
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im = cv2.resize(im, (227, 227))
im = np.reshape(im, (1, 227, 227, 3)).astype('float32')
im = im - 110
outputs = predict_func([im])[0]
prob = outputs[0]
print prob.shape
ret = prob.argsort()[-10:][::-1]
print ret
meta = ILSVRCMeta().get_synset_words_1000()
print [meta[k] for k in ret]
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
......
......@@ -18,6 +18,7 @@ from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.callbacks import *
from tensorpack.dataflow import *
from tensorpack.dataflow.dataset import ILSVRCMeta
class Model(ModelDesc):
def _get_input_vars(self):
......@@ -104,12 +105,15 @@ def run_test(path, input):
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im = cv2.resize(im, (224, 224))
im = np.reshape(im, (1, 224, 224, 3)).astype('float32')
im = im - 110
outputs = predict_func([im])[0]
prob = outputs[0]
print prob.shape
ret = prob.argsort()[-10:][::-1]
print ret
meta = ILSVRCMeta().get_synset_words_1000()
print [meta[k] for k in ret]
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', default='0',
......
mnist_data
cifar10_data
svhn_data
ilsvrc_metadata
......@@ -13,6 +13,7 @@ import tarfile
import logging
from ...utils import logger, get_rng
from ...utils.fs import download
from ..base import DataFlow
__all__ = ['Cifar10']
......@@ -23,22 +24,13 @@ DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
def maybe_download_and_extract(dest_directory):
"""Download and extract the tarball from Alex's website.
copied from tensorflow example """
if not os.path.exists(dest_directory):
os.makedirs(dest_directory)
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(dest_directory, filename)
if os.path.isdir(os.path.join(dest_directory, 'cifar-10-batches-py')):
logger.info("Found cifar10 data in {}.".format(dest_directory))
return
else:
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % (filepath,
float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, reporthook=_progress)
print()
statinfo = os.stat(filepath)
print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
download(URL, dest_directory)
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(dest_directory, filename)
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
def read_cifar10(filenames):
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: ilsvrc.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
import tarfile
from ...utils.fs import mkdir_p, download
__all__ = ['ILSVRCMeta']
CAFFE_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
class ILSVRCMeta(object):
def __init__(self, dir=None):
if dir is None:
dir = os.path.join(os.path.dirname(__file__), 'ilsvrc_metadata')
self.dir = dir
mkdir_p(self.dir)
def get_synset_words_1000(self):
fname = os.path.join(self.dir, 'synset_words.txt')
if not os.path.isfile(fname):
self.download_caffe_meta()
assert os.path.isfile(fname)
lines = [x.strip() for x in open(fname).readlines()]
return dict(enumerate(lines))
def download_caffe_meta(self):
fpath = download(CAFFE_URL, self.dir)
tarfile.open(fpath, 'r:gz').extractall(self.dir)
if __name__ == '__main__':
meta = ILSVRCMeta()
print meta.get_synset_words_1000()
......@@ -10,6 +10,7 @@ import numpy
from six.moves import urllib, range
from ...utils import logger
from ...utils.fs import download
from ..base import DataFlow
__all__ = ['Mnist']
......@@ -20,14 +21,10 @@ SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def maybe_download(filename, work_directory):
"""Download the data from Yann's website, unless it's already here."""
if not os.path.exists(work_directory):
os.mkdir(work_directory)
filepath = os.path.join(work_directory, filename)
if not os.path.exists(filepath):
logger.info("Downloading mnist data to {}...".format(filepath))
filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
statinfo = os.stat(filepath)
logger.info('Successfully downloaded to ' + filename)
download(SOURCE_URL + filename, work_directory)
return filepath
def _read32(bytestream):
......
......@@ -31,7 +31,7 @@ class SVHNDigit(DataFlow):
assert name in ['train', 'test', 'extra'], name
filename = os.path.join(data_dir, name + '_32x32.mat')
assert os.path.isfile(filename), \
"File {} not found! Download it from \
"File {} not found! Please download it from \
http://ufldl.stanford.edu/housenumbers/".format(filename)
logger.info("Loading {} ...".format(filename))
data = scipy.io.loadmat(filename)
......
......@@ -3,7 +3,8 @@
# File: fs.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
import os, sys
from six.moves import urllib
def mkdir_p(dirname):
assert dirname is not None
......@@ -15,3 +16,21 @@ def mkdir_p(dirname):
if e.errno != 17:
raise e
def download(url, dir):
mkdir_p(dir)
fname = url.split('/')[-1]
fpath = os.path.join(dir, fname)
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' %
(fname, float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
fpath, _ = urllib.request.urlretrieve(url, fpath, reporthook=_progress)
statinfo = os.stat(fpath)
sys.stdout.write('\n')
print('Succesfully downloaded ' + fname + " " + str(statinfo.st_size) + ' bytes.')
return fpath
if __name__ == '__main__':
download('http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz', '.')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment