Commit 2d96baca authored by Yukun Chen's avatar Yukun Chen Committed by Yuxin Wu

add cifar100-convnet example and its dataflow abstraction. (#7)

* add cifar100-convnet example and its dataflow abstraction.

* cifar has both cifar10 and cifar100.

* fix a typo. rename Cifar to CifarBase.
parent 73ac38c7
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: cifar10-convnet.py # File: cifar10-convnet.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy
import tensorflow as tf import tensorflow as tf
import argparse import argparse
import numpy as np import numpy as np
...@@ -18,11 +18,15 @@ from tensorpack.tfutils.summary import * ...@@ -18,11 +18,15 @@ from tensorpack.tfutils.summary import *
from tensorpack.dataflow import * from tensorpack.dataflow import *
""" """
A small cifar10 convnet model. A small convnet model for cifar 10 or cifar100 dataset.
90% validation accuracy after 40k step. 90% validation accuracy after 40k step.
""" """
class Model(ModelDesc): class Model(ModelDesc):
def __init__(self, cifar_classnum):
super(Model, self).__init__()
self.cifar_classnum = cifar_classnum
def _get_input_vars(self): def _get_input_vars(self):
return [InputVar(tf.float32, [None, 30, 30, 3], 'input'), return [InputVar(tf.float32, [None, 30, 30, 3], 'input'),
InputVar(tf.int32, [None], 'label') InputVar(tf.int32, [None], 'label')
...@@ -53,7 +57,7 @@ class Model(ModelDesc): ...@@ -53,7 +57,7 @@ class Model(ModelDesc):
l = FullyConnected('fc1', l, 512, l = FullyConnected('fc1', l, 512,
b_init=tf.constant_initializer(0.1)) b_init=tf.constant_initializer(0.1))
# fc will have activation summary by default. disable for the output layer # fc will have activation summary by default. disable for the output layer
logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity) logits = FullyConnected('linear', l, out_dim=self.cifar_classnum, nl=tf.identity)
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label) cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') cost = tf.reduce_mean(cost, name='cross_entropy_loss')
...@@ -75,9 +79,12 @@ class Model(ModelDesc): ...@@ -75,9 +79,12 @@ class Model(ModelDesc):
add_param_summary([('.*/W', ['histogram'])]) # monitor W add_param_summary([('.*/W', ['histogram'])]) # monitor W
self.cost = tf.add_n([cost, wd_cost], name='cost') self.cost = tf.add_n([cost, wd_cost], name='cost')
def get_data(train_or_test): def get_data(train_or_test, cifar_classnum):
isTrain = train_or_test == 'train' isTrain = train_or_test == 'train'
ds = dataset.Cifar10(train_or_test) if cifar_classnum == 10:
ds = dataset.Cifar10(train_or_test)
else:
ds = dataset.Cifar100(train_or_test)
if isTrain: if isTrain:
augmentors = [ augmentors = [
imgaug.RandomCrop((30, 30)), imgaug.RandomCrop((30, 30)),
...@@ -100,11 +107,11 @@ def get_data(train_or_test): ...@@ -100,11 +107,11 @@ def get_data(train_or_test):
ds = PrefetchData(ds, 10, 5) ds = PrefetchData(ds, 10, 5)
return ds return ds
def get_config(): def get_config(cifar_classnum):
# prepare dataset # prepare dataset
dataset_train = get_data('train') dataset_train = get_data('train', cifar_classnum)
step_per_epoch = dataset_train.size() step_per_epoch = dataset_train.size()
dataset_test = get_data('test') dataset_test = get_data('test', cifar_classnum)
sess_config = get_default_sess_config(0.5) sess_config = get_default_sess_config(0.5)
...@@ -125,7 +132,7 @@ def get_config(): ...@@ -125,7 +132,7 @@ def get_config():
InferenceRunner(dataset_test, ClassificationError()) InferenceRunner(dataset_test, ClassificationError())
]), ]),
session_config=sess_config, session_config=sess_config,
model=Model(), model=Model(cifar_classnum),
step_per_epoch=step_per_epoch, step_per_epoch=step_per_epoch,
max_epoch=250, max_epoch=250,
) )
...@@ -134,6 +141,7 @@ if __name__ == '__main__': ...@@ -134,6 +141,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
parser.add_argument('--classnum', help='specify cifar10 or cifar100, input 10 for cifar10 or 100 for cifar100')
args = parser.parse_args() args = parser.parse_args()
basename = os.path.basename(__file__) basename = os.path.basename(__file__)
...@@ -145,8 +153,13 @@ if __name__ == '__main__': ...@@ -145,8 +153,13 @@ if __name__ == '__main__':
else: else:
os.environ['CUDA_VISIBLE_DEVICES'] = '0' os.environ['CUDA_VISIBLE_DEVICES'] = '0'
if args.classnum:
cifar_classnum = int(args.classnum)
else:
cifar_classnum = 10
with tf.Graph().as_default(): with tf.Graph().as_default():
config = get_config() config = get_config(cifar_classnum)
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
if args.gpu: if args.gpu:
......
...@@ -15,25 +15,33 @@ from ...utils import logger, get_rng ...@@ -15,25 +15,33 @@ from ...utils import logger, get_rng
from ...utils.fs import download from ...utils.fs import download
from ..base import DataFlow from ..base import DataFlow
__all__ = ['Cifar10'] __all__ = ['Cifar10', 'Cifar100']
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' DATA_URL_CIFAR_10 = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
DATA_URL_CIFAR_100 = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
def maybe_download_and_extract(dest_directory): def maybe_download_and_extract(dest_directory, cifar_classnum):
"""Download and extract the tarball from Alex's website. """Download and extract the tarball from Alex's website.
copied from tensorflow example """ copied from tensorflow example """
if os.path.isdir(os.path.join(dest_directory, 'cifar-10-batches-py')): assert cifar_classnum == 10 or cifar_classnum == 100
logger.info("Found cifar10 data in {}.".format(dest_directory)) if cifar_classnum == 10:
cifar_foldername = 'cifar-10-batches-py'
else:
cifar_foldername = 'cifar-100-python'
if os.path.isdir(os.path.join(dest_directory, cifar_foldername)):
logger.info("Found cifar{} data in {}.".format(cifar_classnum, dest_directory))
return return
else: else:
DATA_URL = DATA_URL_CIFAR_10 if cifar_classnum == 10 else DATA_URL_CIFAR_100
download(DATA_URL, dest_directory) download(DATA_URL, dest_directory)
filename = DATA_URL.split('/')[-1] filename = DATA_URL.split('/')[-1]
filepath = os.path.join(dest_directory, filename) filepath = os.path.join(dest_directory, filename)
import tarfile import tarfile
tarfile.open(filepath, 'r:gz').extractall(dest_directory) tarfile.open(filepath, 'r:gz').extractall(dest_directory)
def read_cifar10(filenames): def read_cifar(filenames, cifar_classnum):
assert cifar_classnum == 10 or cifar_classnum == 100
ret = [] ret = []
for fname in filenames: for fname in filenames:
fo = open(fname, 'rb') fo = open(fname, 'rb')
...@@ -42,48 +50,65 @@ def read_cifar10(filenames): ...@@ -42,48 +50,65 @@ def read_cifar10(filenames):
else: else:
dic = pickle.load(fo) dic = pickle.load(fo)
data = dic[b'data'] data = dic[b'data']
label = dic[b'labels'] if cifar_classnum == 10:
label = dic[b'labels']
IMG_NUM = 10000
elif cifar_classnum == 100:
label = dic[b'fine_labels']
IMG_NUM = 50000 if 'train' in fname else 10000
fo.close() fo.close()
for k in range(10000): for k in range(IMG_NUM):
img = data[k].reshape(3, 32, 32) img = data[k].reshape(3, 32, 32)
img = np.transpose(img, [1, 2, 0]) img = np.transpose(img, [1, 2, 0])
ret.append([img, label[k]]) ret.append([img, label[k]])
return ret return ret
def get_filenames(dir): def get_filenames(dir, cifar_classnum):
filenames = [os.path.join( assert cifar_classnum == 10 or cifar_classnum == 100
if cifar_classnum == 10:
filenames = [os.path.join(
dir, 'cifar-10-batches-py', 'data_batch_%d' % i) for i in range(1, 6)] dir, 'cifar-10-batches-py', 'data_batch_%d' % i) for i in range(1, 6)]
filenames.append(os.path.join( filenames.append(os.path.join(
dir, 'cifar-10-batches-py', 'test_batch')) dir, 'cifar-10-batches-py', 'test_batch'))
elif cifar_classnum == 100:
filenames = [os.path.join(
dir, 'cifar-100-python', 'train')]
filenames.append(os.path.join(
dir, 'cifar-100-python', 'test'))
return filenames return filenames
class Cifar10(DataFlow): class CifarBase(DataFlow):
""" """
Return [image, label], Return [image, label],
image is 32x32x3 in the range [0,255] image is 32x32x3 in the range [0,255]
""" """
def __init__(self, train_or_test, shuffle=True, dir=None): def __init__(self, train_or_test, shuffle=True, dir=None, cifar_classnum=10):
""" """
Args: Args:
train_or_test: string either 'train' or 'test' train_or_test: string either 'train' or 'test'
shuffle: default to True shuffle: default to True
""" """
assert train_or_test in ['train', 'test'] assert train_or_test in ['train', 'test']
assert cifar_classnum == 10 or cifar_classnum == 100
self.cifar_classnum = cifar_classnum
if dir is None: if dir is None:
dir = os.path.join(os.path.dirname(__file__), 'cifar10_data') dir = os.path.join(os.path.dirname(__file__), 'cifar-10-batches-py'
maybe_download_and_extract(dir) if cifar_classnum==10 else 'cifar100_data')
maybe_download_and_extract(dir, self.cifar_classnum)
fnames = get_filenames(dir) if self.cifar_classnum == 10:
fnames = get_filenames(dir, 10)
else:
fnames = get_filenames(dir, 100)
if train_or_test == 'train': if train_or_test == 'train':
self.fs = fnames[:5] self.fs = fnames[:5] if cifar_classnum==10 else fnames[:1]
else: else:
self.fs = [fnames[-1]] self.fs = [fnames[-1]]
for f in self.fs: for f in self.fs:
if not os.path.isfile(f): if not os.path.isfile(f):
raise ValueError('Failed to find file: ' + f) raise ValueError('Failed to find file: ' + f)
self.train_or_test = train_or_test self.train_or_test = train_or_test
self.data = read_cifar(self.fs, cifar_classnum)
self.dir = dir self.dir = dir
self.data = read_cifar10(self.fs)
self.shuffle = shuffle self.shuffle = shuffle
self.rng = get_rng(self) self.rng = get_rng(self)
...@@ -104,8 +129,8 @@ class Cifar10(DataFlow): ...@@ -104,8 +129,8 @@ class Cifar10(DataFlow):
""" """
return a mean image of all (train and test) images of size 32x32x3 return a mean image of all (train and test) images of size 32x32x3
""" """
fnames = get_filenames(self.dir) fnames = get_filenames(self.dir, self.cifar_classnum)
all_imgs = [x[0] for x in read_cifar10(fnames)] all_imgs = [x[0] for x in read_cifar(fnames, self.cifar_classnum)]
arr = np.array(all_imgs, dtype='float32') arr = np.array(all_imgs, dtype='float32')
mean = np.mean(arr, axis=0) mean = np.mean(arr, axis=0)
return mean return mean
...@@ -117,6 +142,14 @@ class Cifar10(DataFlow): ...@@ -117,6 +142,14 @@ class Cifar10(DataFlow):
mean = self.get_per_pixel_mean() mean = self.get_per_pixel_mean()
return np.mean(mean, axis=(0,1)) return np.mean(mean, axis=(0,1))
class Cifar10(CifarBase):
def __init__(self, train_or_test, shuffle=True, dir=None):
super(Cifar10, self).__init__(train_or_test, shuffle, dir, 10)
class Cifar100(CifarBase):
def __init__(self, train_or_test, shuffle=True, dir=None):
super(Cifar100, self).__init__(train_or_test, shuffle, dir, 100)
if __name__ == '__main__': if __name__ == '__main__':
ds = Cifar10('train') ds = Cifar10('train')
from tensorpack.dataflow.dftools import dump_dataset_images from tensorpack.dataflow.dftools import dump_dataset_images
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment