Commit 8902af90 authored by Yuxin Wu

more consistent with official cifar code

parent daf368dc
@@ -19,7 +19,7 @@ from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug
BATCH_SIZE = 128
MIN_AFTER_DEQUEUE = 500
MIN_AFTER_DEQUEUE = 20000 # a large number, as in the official example
CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE
def get_model(inputs, is_training):
@@ -28,7 +28,10 @@ def get_model(inputs, is_training):
image, label = inputs
#if is_training: # slow
if is_training: # slow?
image, label = tf.train.shuffle_batch(
[image, label], BATCH_SIZE, CAPACITY, MIN_AFTER_DEQUEUE,
num_threads=6, enqueue_many=False)
## augmentations
#image, label = tf.train.slice_input_producer(
#[image, label], name='slice_queue')
@@ -80,20 +83,23 @@ def get_config():
log_dir = os.path.join('train_log', os.path.basename(__file__)[:-3])
logger.set_logger_dir(log_dir)
import cv2
dataset_train = dataset.Cifar10('train')
augmentor = imgaug.AugmentorList([
augmentors = [
RandomCrop((24, 24)),
Flip(horiz=True),
BrightnessAdd(0.25),
Contrast((0.2,1.8)),
PerImageWhitening()
])
dataset_train = MapData(dataset_train, lambda img:
augmentor.augment(imgaug.Image(img)).arr)
]
dataset_train = AugmentImageComponent(dataset_train, augmentors)
dataset_train = BatchData(dataset_train, 128)
augmentors = [
CenterCrop((24, 24)),
PerImageWhitening()
]
dataset_test = dataset.Cifar10('test')
dataset_test = MapData(dataset_test, lambda img: cv2.resize(img, (24, 24)))
dataset_test = AugmentImageComponent(dataset_test, augmentors)
dataset_test = BatchData(dataset_test, 128)
step_per_epoch = dataset_train.size()
#step_per_epoch = 20
@@ -109,19 +115,19 @@ def get_config():
tf.placeholder(
tf.int32, shape=(None,), name='label')
]
input_queue = tf.RandomShuffleQueue(
100, 50, [x.dtype for x in input_vars], name='queue')
input_queue = tf.FIFOQueue(
50, [x.dtype for x in input_vars], name='queue')
lr = tf.train.exponential_decay(
learning_rate=1e-4,
learning_rate=1e-1,
global_step=get_global_step_var(),
decay_steps=dataset_train.size() * 50,
decay_steps=dataset_train.size() * 200,
decay_rate=0.1, staircase=True, name='learning_rate')
tf.scalar_summary('learning_rate', lr)
return TrainConfig(
dataset=dataset_train,
optimizer=tf.train.AdamOptimizer(lr),
optimizer=tf.train.GradientDescentOptimizer(lr),
callbacks=Callbacks([
SummaryWriter(),
PeriodicSaver(),
@@ -131,6 +137,7 @@ def get_config():
inputs=input_vars,
input_queue=input_queue,
get_model_func=get_model,
batched_model_input=False,
step_per_epoch=step_per_epoch,
max_epoch=100,
)
@@ -42,13 +42,12 @@ def get_model(inputs, is_training):
image, label = inputs
image = tf.expand_dims(image, 3) # add a single channel
#if is_training: # slow
#if is_training:
## augmentations
#image, label = tf.train.slice_input_producer(
#[image, label], name='slice_queue')
#image = tf.image.random_brightness(image, 0.1)
#image, label = tf.train.shuffle_batch(
#[image, label], BATCH_SIZE, CAPACITY, MIN_AFTER_DEQUEUE,
#[image, label], shuffle=False, name='slice_queue')
#image, label = tf.train.batch(
#[image, label], BATCH_SIZE, capacity=CAPACITY,
#num_threads=2, enqueue_many=False)
l = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
@@ -98,7 +97,7 @@ def get_config():
dataset_train = BatchData(dataset.Mnist('train'), 128)
dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size()
step_per_epoch = 2
#step_per_epoch = 30
#dataset_test = FixedSizeData(dataset_test, 20)
sess_config = get_default_sess_config()
@@ -5,8 +5,10 @@
import numpy as np
from .base import DataFlow
from .imgaug import AugmentorList, Image
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData']
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
'AugmentImageComponent']
class BatchData(DataFlow):
def __init__(self, ds, batch_size, remainder=False):
@@ -101,3 +103,17 @@ class MapData(DataFlow):
dp = list(dp)  # copy the data point so the source dataflow is not mutated
dp[self.index] = self.func(dp[self.index])
yield dp
def AugmentImageComponent(ds, augmentors, index=0):
"""
Augment the image component of each data point
Args:
ds: a DataFlow instance
augmentors: a list of ImageAugmentor instances
index: the index of the image component in each data point. Defaults to 0
"""
aug = AugmentorList(augmentors)
return MapData(
ds,
lambda img: aug.augment(Image(img)).arr,
index)
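For orientation, here is a minimal usage sketch of the new helper (an illustration, not part of the commit). It mirrors the CIFAR-10 example earlier in this commit, so the same star import is assumed to expose `dataset`, the augmentors, `AugmentImageComponent`, and `BatchData`:

```python
# Sketch only: follows the imports and augmentor list of the CIFAR-10 example
# in this commit; adjust the import paths if your tensorpack layout differs.
from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug

ds = dataset.Cifar10('train')
# one call replaces the old MapData + AugmentorList wiring
ds = AugmentImageComponent(ds, [
    RandomCrop((24, 24)),
    Flip(horiz=True),
    PerImageWhitening(),
])
ds = BatchData(ds, 128)
```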
@@ -60,6 +60,7 @@ class AugmentorList(ImageAugmentor):
self.augs = augmentors
def _augment(self, img):
assert img.arr.ndim in [2, 3]
img.arr = img.arr.astype('float32') / 255.0
for aug in self.augs:
aug.augment(img)
@@ -5,12 +5,12 @@
from .base import ImageAugmentor
__all__ = ['RandomCrop']
__all__ = ['RandomCrop', 'CenterCrop', 'Resize']
class RandomCrop(ImageAugmentor):
""" Randomly crop the image into a smaller one """
def __init__(self, crop_shape):
"""
Randomly crop the image into a smaller one
Args:
crop_shape: shape in (h, w)
"""
@@ -24,4 +24,29 @@ class RandomCrop(ImageAugmentor):
if img.coords:
raise NotImplementedError()
class CenterCrop(ImageAugmentor):
""" Crop the image in the center"""
def __init__(self, crop_shape):
"""
Args:
crop_shape: shape in (h, w)
"""
self._init(locals())
def _augment(self, img):
orig_shape = img.arr.shape
h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5)
w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5)
img.arr = img.arr[h0:h0 + self.crop_shape[0], w0:w0 + self.crop_shape[1]]
if img.coords:
raise NotImplementedError()
class Resize(ImageAugmentor):
"""Resize image to a target size"""
def __init__(self, shape):
"""
Args:
shape: (w, h)
"""
self._init(locals())
def _augment(self, img):
img.arr = cv2.resize(
img.arr, self.shape,
interpolation=cv2.INTER_CUBIC)
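As a quick, standalone illustration (not from the commit) of the center-crop arithmetic above, plain numpy is enough to check that integer offsets select the central 24x24 window of a 32x32 CIFAR image:

```python
# Sketch of the CenterCrop arithmetic with plain numpy; no tensorpack needed.
import numpy as np

arr = np.random.rand(32, 32, 3).astype('float32')   # fake 32x32 RGB image
crop_h, crop_w = 24, 24
h0 = int((arr.shape[0] - crop_h) * 0.5)              # 4
w0 = int((arr.shape[1] - crop_w) * 0.5)              # 4
out = arr[h0:h0 + crop_h, w0:w0 + crop_w]
assert out.shape == (24, 24, 3)
```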
@@ -33,12 +33,13 @@ class Contrast(ImageAugmentor):
r = self._rand_range(*self.factor_range)
mean = np.mean(arr, axis=(0,1), keepdims=True)
img.arr = (arr - mean) * r + mean
img.arr = np.clip(img.arr, 0, 1)
class PerImageWhitening(ImageAugmentor):
"""
Linearly scales image to have zero mean and unit norm.
x = (x - mean) / adjusted_stddev
where adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels))
where adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels))
"""
def __init__(self):
pass
@@ -46,5 +47,5 @@ class PerImageWhitening(ImageAugmentor):
def _augment(self, img):
mean = np.mean(img.arr, axis=(0,1), keepdims=True)
std = np.std(img.arr, axis=(0,1), keepdims=True)
std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.arr.shape[:2])))
std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.arr.shape)))
img.arr = (img.arr - mean) / std
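A standalone numpy check (a sketch, not part of the commit) of the whitening above: with the adjusted-stddev floor of 1/sqrt(num_pixels * channels), a non-degenerate image comes out with roughly zero mean and unit stddev per channel:

```python
# Sketch: per-image whitening as implemented above, verified with numpy.
import numpy as np

img = np.random.rand(24, 24, 3).astype('float32')
mean = np.mean(img, axis=(0, 1), keepdims=True)
std = np.std(img, axis=(0, 1), keepdims=True)
std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.shape)))  # floor: 1/sqrt(24*24*3)
white = (img - mean) / std
assert np.allclose(white.mean(axis=(0, 1)), 0, atol=1e-4)
assert np.allclose(white.std(axis=(0, 1)), 1, atol=1e-4)
```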
@@ -36,6 +36,11 @@ class TrainConfig(object):
with capacity 5
get_model_func: a function taking `inputs` and `is_training`, returning
a tuple of the output list and the cost to minimize
batched_model_input: boolean. If True, `get_model_func` expects batched
input during training. Otherwise, it expects single data points, so
that you may do preprocessing and batch them later with TF batch ops.
It's suggested to do all preprocessing in the dataset, as that is
usually faster.
step_per_epoch: the number of steps (parameter updates) to perform
in each epoch. Defaults to dataset.size()
max_epoch: maximum number of epochs to run training. Defaults to 100
@@ -59,6 +64,7 @@ class TrainConfig(object):
assert_type(self.input_queue, tf.QueueBase)
assert self.input_queue.dtypes == [x.dtype for x in self.inputs]
self.get_model_func = kwargs.pop('get_model_func')
self.batched_model_input = kwargs.pop('batched_model_input', True)
self.step_per_epoch = int(kwargs.pop('step_per_epoch', self.dataset.size()))
self.max_epoch = int(kwargs.pop('max_epoch', 100))
assert self.step_per_epoch > 0 and self.max_epoch > 0
@@ -89,11 +95,16 @@ def start_train(config):
input_queue = config.input_queue
callbacks = config.callbacks
enqueue_op = input_queue.enqueue(tuple(input_vars))
if config.batched_model_input:
enqueue_op = input_queue.enqueue(input_vars)
model_inputs = input_queue.dequeue()
# set dequeue shape
for qv, v in zip(model_inputs, input_vars):
qv.set_shape(v.get_shape())
else:
enqueue_op = input_queue.enqueue_many(input_vars)
model_inputs = input_queue.dequeue()
for qv, v in zip(model_inputs, input_vars):
qv.set_shape(v.get_shape().as_list()[1:])
output_vars, cost_var = config.get_model_func(model_inputs, is_training=True)
# build graph
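The branch above is the crux of `batched_model_input`: when it is True the queue holds whole batches pushed with `enqueue()`, otherwise `enqueue_many()` splits the fed batch into single data points that the model re-batches itself (e.g. with `tf.train.shuffle_batch`, as in the CIFAR-10 example). A self-contained sketch of the unbatched path, assuming the TF 1.x graph-mode queue API:

```python
# Sketch (assumption: TF 1.x graph mode) of the batched_model_input=False path:
# a feed of N examples is split into single data points by enqueue_many().
import numpy as np
import tensorflow as tf

image = tf.placeholder(tf.float32, shape=(None, 24, 24, 3), name='image')
label = tf.placeholder(tf.int32, shape=(None,), name='label')

q = tf.FIFOQueue(50, [image.dtype, label.dtype], name='queue')
enqueue_op = q.enqueue_many([image, label])      # one queue element per example
img_one, lbl_one = q.dequeue()
img_one.set_shape([24, 24, 3])                   # leading batch dim dropped
lbl_one.set_shape([])

with tf.Session() as sess:
    feed = {image: np.zeros((4, 24, 24, 3), 'float32'),
            label: np.arange(4, dtype='int32')}
    sess.run(enqueue_op, feed_dict=feed)         # pushes 4 single data points
    print(sess.run([tf.shape(img_one), lbl_one]))  # -> [24 24 3], 0
```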
@@ -125,6 +136,8 @@ def start_train(config):
for step in xrange(config.step_per_epoch):
if coord.should_stop():
return
# TODO if no one uses trigger_step, train_op can be
# faster, see: https://github.com/soumith/convnet-benchmarks/pull/67/files
fetches = [train_op, cost_var] + output_vars + model_inputs
results = sess.run(fetches)
cost = results[1]
@@ -111,9 +111,9 @@ class CallbackTimeLogger(object):
msgs = []
for name, t in self.times:
if t / self.tot > 0.3 and t > 1:
msgs.append("{}:{}".format(name, t))
msgs.append("{}:{}sec".format(name, t))
logger.info(
"Callbacks took {} sec. {}".format(
"Callbacks took {} sec in total. {}".format(
self.tot, ' '.join(msgs)))