Commit 2d96baca authored by Yukun Chen's avatar Yukun Chen Committed by Yuxin Wu

add cifar100-convnet example and its dataflow abstraction. (#7)

* add cifar100-convnet example and its dataflow abstraction.

* cifar has both cifar10 and cifar100.

* fix a typo. rename Cifar to CifarBase.
parent 73ac38c7
......@@ -2,7 +2,7 @@
# -*- coding: UTF-8 -*-
# File: cifar10-convnet.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy
import tensorflow as tf
import argparse
import numpy as np
......@@ -18,11 +18,15 @@ from tensorpack.tfutils.summary import *
from tensorpack.dataflow import *
"""
A small cifar10 convnet model.
A small convnet model for cifar 10 or cifar100 dataset.
90% validation accuracy after 40k step.
"""
class Model(ModelDesc):
def __init__(self, cifar_classnum):
super(Model, self).__init__()
self.cifar_classnum = cifar_classnum
def _get_input_vars(self):
return [InputVar(tf.float32, [None, 30, 30, 3], 'input'),
InputVar(tf.int32, [None], 'label')
......@@ -53,7 +57,7 @@ class Model(ModelDesc):
l = FullyConnected('fc1', l, 512,
b_init=tf.constant_initializer(0.1))
# fc will have activation summary by default. disable for the output layer
logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity)
logits = FullyConnected('linear', l, out_dim=self.cifar_classnum, nl=tf.identity)
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
......@@ -75,9 +79,12 @@ class Model(ModelDesc):
add_param_summary([('.*/W', ['histogram'])]) # monitor W
self.cost = tf.add_n([cost, wd_cost], name='cost')
def get_data(train_or_test):
def get_data(train_or_test, cifar_classnum):
isTrain = train_or_test == 'train'
if cifar_classnum == 10:
ds = dataset.Cifar10(train_or_test)
else:
ds = dataset.Cifar100(train_or_test)
if isTrain:
augmentors = [
imgaug.RandomCrop((30, 30)),
......@@ -100,11 +107,11 @@ def get_data(train_or_test):
ds = PrefetchData(ds, 10, 5)
return ds
def get_config():
def get_config(cifar_classnum):
# prepare dataset
dataset_train = get_data('train')
dataset_train = get_data('train', cifar_classnum)
step_per_epoch = dataset_train.size()
dataset_test = get_data('test')
dataset_test = get_data('test', cifar_classnum)
sess_config = get_default_sess_config(0.5)
......@@ -125,7 +132,7 @@ def get_config():
InferenceRunner(dataset_test, ClassificationError())
]),
session_config=sess_config,
model=Model(),
model=Model(cifar_classnum),
step_per_epoch=step_per_epoch,
max_epoch=250,
)
......@@ -134,6 +141,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
parser.add_argument('--load', help='load model')
parser.add_argument('--classnum', help='specify cifar10 or cifar100, input 10 for cifar10 or 100 for cifar100')
args = parser.parse_args()
basename = os.path.basename(__file__)
......@@ -145,8 +153,13 @@ if __name__ == '__main__':
else:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
if args.classnum:
cifar_classnum = int(args.classnum)
else:
cifar_classnum = 10
with tf.Graph().as_default():
config = get_config()
config = get_config(cifar_classnum)
if args.load:
config.session_init = SaverRestore(args.load)
if args.gpu:
......
......@@ -15,25 +15,33 @@ from ...utils import logger, get_rng
from ...utils.fs import download
from ..base import DataFlow
__all__ = ['Cifar10']
__all__ = ['Cifar10', 'Cifar100']
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
DATA_URL_CIFAR_10 = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
DATA_URL_CIFAR_100 = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
def maybe_download_and_extract(dest_directory):
def maybe_download_and_extract(dest_directory, cifar_classnum):
"""Download and extract the tarball from Alex's website.
copied from tensorflow example """
if os.path.isdir(os.path.join(dest_directory, 'cifar-10-batches-py')):
logger.info("Found cifar10 data in {}.".format(dest_directory))
assert cifar_classnum == 10 or cifar_classnum == 100
if cifar_classnum == 10:
cifar_foldername = 'cifar-10-batches-py'
else:
cifar_foldername = 'cifar-100-python'
if os.path.isdir(os.path.join(dest_directory, cifar_foldername)):
logger.info("Found cifar{} data in {}.".format(cifar_classnum, dest_directory))
return
else:
DATA_URL = DATA_URL_CIFAR_10 if cifar_classnum == 10 else DATA_URL_CIFAR_100
download(DATA_URL, dest_directory)
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(dest_directory, filename)
import tarfile
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
def read_cifar10(filenames):
def read_cifar(filenames, cifar_classnum):
assert cifar_classnum == 10 or cifar_classnum == 100
ret = []
for fname in filenames:
fo = open(fname, 'rb')
......@@ -42,48 +50,65 @@ def read_cifar10(filenames):
else:
dic = pickle.load(fo)
data = dic[b'data']
if cifar_classnum == 10:
label = dic[b'labels']
IMG_NUM = 10000
elif cifar_classnum == 100:
label = dic[b'fine_labels']
IMG_NUM = 50000 if 'train' in fname else 10000
fo.close()
for k in range(10000):
for k in range(IMG_NUM):
img = data[k].reshape(3, 32, 32)
img = np.transpose(img, [1, 2, 0])
ret.append([img, label[k]])
return ret
def get_filenames(dir):
def get_filenames(dir, cifar_classnum):
assert cifar_classnum == 10 or cifar_classnum == 100
if cifar_classnum == 10:
filenames = [os.path.join(
dir, 'cifar-10-batches-py', 'data_batch_%d' % i) for i in range(1, 6)]
filenames.append(os.path.join(
dir, 'cifar-10-batches-py', 'test_batch'))
elif cifar_classnum == 100:
filenames = [os.path.join(
dir, 'cifar-100-python', 'train')]
filenames.append(os.path.join(
dir, 'cifar-100-python', 'test'))
return filenames
class Cifar10(DataFlow):
class CifarBase(DataFlow):
"""
Return [image, label],
image is 32x32x3 in the range [0,255]
"""
def __init__(self, train_or_test, shuffle=True, dir=None):
def __init__(self, train_or_test, shuffle=True, dir=None, cifar_classnum=10):
"""
Args:
train_or_test: string either 'train' or 'test'
shuffle: default to True
"""
assert train_or_test in ['train', 'test']
assert cifar_classnum == 10 or cifar_classnum == 100
self.cifar_classnum = cifar_classnum
if dir is None:
dir = os.path.join(os.path.dirname(__file__), 'cifar10_data')
maybe_download_and_extract(dir)
fnames = get_filenames(dir)
dir = os.path.join(os.path.dirname(__file__), 'cifar-10-batches-py'
if cifar_classnum==10 else 'cifar100_data')
maybe_download_and_extract(dir, self.cifar_classnum)
if self.cifar_classnum == 10:
fnames = get_filenames(dir, 10)
else:
fnames = get_filenames(dir, 100)
if train_or_test == 'train':
self.fs = fnames[:5]
self.fs = fnames[:5] if cifar_classnum==10 else fnames[:1]
else:
self.fs = [fnames[-1]]
for f in self.fs:
if not os.path.isfile(f):
raise ValueError('Failed to find file: ' + f)
self.train_or_test = train_or_test
self.data = read_cifar(self.fs, cifar_classnum)
self.dir = dir
self.data = read_cifar10(self.fs)
self.shuffle = shuffle
self.rng = get_rng(self)
......@@ -104,8 +129,8 @@ class Cifar10(DataFlow):
"""
return a mean image of all (train and test) images of size 32x32x3
"""
fnames = get_filenames(self.dir)
all_imgs = [x[0] for x in read_cifar10(fnames)]
fnames = get_filenames(self.dir, self.cifar_classnum)
all_imgs = [x[0] for x in read_cifar(fnames, self.cifar_classnum)]
arr = np.array(all_imgs, dtype='float32')
mean = np.mean(arr, axis=0)
return mean
......@@ -117,6 +142,14 @@ class Cifar10(DataFlow):
mean = self.get_per_pixel_mean()
return np.mean(mean, axis=(0,1))
class Cifar10(CifarBase):
def __init__(self, train_or_test, shuffle=True, dir=None):
super(Cifar10, self).__init__(train_or_test, shuffle, dir, 10)
class Cifar100(CifarBase):
def __init__(self, train_or_test, shuffle=True, dir=None):
super(Cifar100, self).__init__(train_or_test, shuffle, dir, 100)
if __name__ == '__main__':
ds = Cifar10('train')
from tensorpack.dataflow.dftools import dump_dataset_images
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment