Commit 8902af90 authored by Yuxin Wu

more consistent with official cifar code

parent daf368dc
@@ -19,7 +19,7 @@ from tensorpack.dataflow import *
 from tensorpack.dataflow import imgaug
 
 BATCH_SIZE = 128
-MIN_AFTER_DEQUEUE = 500
+MIN_AFTER_DEQUEUE = 20000  # a large number, as in the official example
 CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE
 
 def get_model(inputs, is_training):
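With the new value, the queue is sized the same way as the official CIFAR-10 example: CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE = 20000 + 3 * 128 = 20384, so a few hundred extra elements can sit in the queue on top of the 20000 kept around for shuffling.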
@@ -28,7 +28,10 @@ def get_model(inputs, is_training):
     image, label = inputs
 
-    #if is_training: # slow
+    if is_training:  # slow?
+        image, label = tf.train.shuffle_batch(
+            [image, label], BATCH_SIZE, CAPACITY, MIN_AFTER_DEQUEUE,
+            num_threads=6, enqueue_many=False)
         ## augmentations
         #image, label = tf.train.slice_input_producer(
             #[image, label], name='slice_queue')
@@ -80,20 +83,23 @@ def get_config():
     log_dir = os.path.join('train_log', os.path.basename(__file__)[:-3])
     logger.set_logger_dir(log_dir)
 
-    import cv2
     dataset_train = dataset.Cifar10('train')
-    augmentor = imgaug.AugmentorList([
+    augmentors = [
         RandomCrop((24, 24)),
         Flip(horiz=True),
         BrightnessAdd(0.25),
         Contrast((0.2,1.8)),
         PerImageWhitening()
-    ])
-    dataset_train = MapData(dataset_train, lambda img:
-        augmentor.augment(imgaug.Image(img)).arr)
+    ]
+    dataset_train = AugmentImageComponent(dataset_train, augmentors)
     dataset_train = BatchData(dataset_train, 128)
+
+    augmentors = [
+        CenterCrop((24, 24)),
+        PerImageWhitening()
+    ]
     dataset_test = dataset.Cifar10('test')
-    dataset_test = MapData(dataset_test, lambda img: cv2.resize(img, (24, 24)))
+    dataset_test = AugmentImageComponent(dataset_test, augmentors)
     dataset_test = BatchData(dataset_test, 128)
 
     step_per_epoch = dataset_train.size()
     #step_per_epoch = 20
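For reference, a minimal sanity-check sketch (not part of the commit, assuming the example's own imports) of what the training pipeline built above yields; the shapes follow from RandomCrop((24, 24)) and BatchData(..., 128):

```python
# Hypothetical check: pull one datapoint from the augmented, batched DataFlow.
dp = next(dataset_train.get_data())      # a datapoint is [images, labels]
images, labels = dp
print(images.shape, labels.shape)        # expected: (128, 24, 24, 3) (128,)
```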
@@ -109,19 +115,19 @@ def get_config():
         tf.placeholder(
             tf.int32, shape=(None,), name='label')
     ]
-    input_queue = tf.RandomShuffleQueue(
-        100, 50, [x.dtype for x in input_vars], name='queue')
+    input_queue = tf.FIFOQueue(
+        50, [x.dtype for x in input_vars], name='queue')
 
     lr = tf.train.exponential_decay(
-        learning_rate=1e-4,
+        learning_rate=1e-1,
         global_step=get_global_step_var(),
-        decay_steps=dataset_train.size() * 50,
+        decay_steps=dataset_train.size() * 200,
         decay_rate=0.1, staircase=True, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
 
     return TrainConfig(
         dataset=dataset_train,
-        optimizer=tf.train.AdamOptimizer(lr),
+        optimizer=tf.train.GradientDescentOptimizer(lr),
         callbacks=Callbacks([
             SummaryWriter(),
             PeriodicSaver(),
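As a rough check of the new schedule: plain SGD starting at 0.1 matches the official CIFAR-10 example, and with CIFAR-10's 50,000 training images batched at 128, dataset_train.size() is about 390 steps per epoch, so decay_steps = 390 * 200 ≈ 78,000, i.e. the learning rate drops by 10x roughly every 200 epochs.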
@@ -131,6 +137,7 @@ def get_config():
         inputs=input_vars,
         input_queue=input_queue,
         get_model_func=get_model,
+        batched_model_input=False,
         step_per_epoch=step_per_epoch,
         max_epoch=100,
     )
...
@@ -42,13 +42,12 @@ def get_model(inputs, is_training):
     image, label = inputs
     image = tf.expand_dims(image, 3)    # add a single channel
 
-    #if is_training: # slow
+    #if is_training:
         ## augmentations
         #image, label = tf.train.slice_input_producer(
-            #[image, label], name='slice_queue')
-        #image = tf.image.random_brightness(image, 0.1)
-        #image, label = tf.train.shuffle_batch(
-            #[image, label], BATCH_SIZE, CAPACITY, MIN_AFTER_DEQUEUE,
+            #[image, label], shuffle=False, name='slice_queue')
+        #image, label = tf.train.batch(
+            #[image, label], BATCH_SIZE, capacity=CAPACITY,
             #num_threads=2, enqueue_many=False)
 
     l = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
@@ -98,7 +97,7 @@ def get_config():
     dataset_train = BatchData(dataset.Mnist('train'), 128)
     dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
     step_per_epoch = dataset_train.size()
-    step_per_epoch = 2
+    #step_per_epoch = 30
     #dataset_test = FixedSizeData(dataset_test, 20)
 
     sess_config = get_default_sess_config()
...
@@ -5,8 +5,10 @@
 import numpy as np
 from .base import DataFlow
+from imgaug import AugmentorList, Image
 
-__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData']
+__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
+           'AugmentImageComponent']
 
 class BatchData(DataFlow):
     def __init__(self, ds, batch_size, remainder=False):
@@ -101,3 +103,17 @@ class MapData(DataFlow):
             d = list(dp)
             dp[self.index] = self.func(dp[self.index])
             yield dp
+
+def AugmentImageComponent(ds, augmentors, index=0):
+    """
+    Augment the image in each data point.
+    Args:
+        ds: a DataFlow dataset instance
+        augmentors: a list of ImageAugmentor instances
+        index: the index of the image in each data point. defaults to 0
+    """
+    aug = AugmentorList(augmentors)
+    return MapData(
+        ds,
+        lambda img: aug.augment(Image(img)).arr,
+        index)
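A minimal usage sketch of the new helper, assuming the same star-import the CIFAR-10 example uses (so these names are in scope):

```python
from tensorpack.dataflow import *

# build a training pipeline like the CIFAR-10 example above
ds = dataset.Cifar10('train')
ds = AugmentImageComponent(ds, [
    RandomCrop((24, 24)),
    Flip(horiz=True),
    PerImageWhitening(),
])                       # augments component 0 (the image) of every data point
ds = BatchData(ds, 128)  # then batch the augmented data points
```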
@@ -60,6 +60,7 @@ class AugmentorList(ImageAugmentor):
         self.augs = augmentors
 
     def _augment(self, img):
+        assert img.arr.ndim in [2, 3]
         img.arr = img.arr.astype('float32') / 255.0
         for aug in self.augs:
             aug.augment(img)
@@ -5,12 +5,12 @@
 from .base import ImageAugmentor
 
-__all__ = ['RandomCrop']
+__all__ = ['RandomCrop', 'CenterCrop', 'Resize']
 
 class RandomCrop(ImageAugmentor):
+    """ Randomly crop the image into a smaller one """
     def __init__(self, crop_shape):
         """
-        Randomly crop the image into a smaller one
         Args:
             crop_shape: shape in (h, w)
         """
@@ -24,4 +24,29 @@
         if img.coords:
             raise NotImplementedError()
 
+class CenterCrop(ImageAugmentor):
+    """ Crop the image in the center """
+    def __init__(self, crop_shape):
+        self._init(locals())
+
+    def _augment(self, img):
+        orig_shape = img.arr.shape
+        h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5)
+        w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5)
+        img.arr = img.arr[h0:h0+self.crop_shape[0], w0:w0+self.crop_shape[1]]
+        if img.coords:
+            raise NotImplementedError()
+
+class Resize(ImageAugmentor):
+    """ Resize the image to a target size """
+    def __init__(self, shape):
+        """
+        Args:
+            shape: (w, h)
+        """
+        self._init(locals())
+
+    def _augment(self, img):
+        img.arr = cv2.resize(
+            img.arr, self.shape,
+            interpolation=cv2.INTER_CUBIC)
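Worked example of the CenterCrop offsets: for a 32x32 CIFAR-10 image and crop_shape=(24, 24), h0 = w0 = (32 - 24) * 0.5 = 4, so the crop keeps rows and columns 4 through 27. Resize, by contrast, rescales the whole image to the target size with cubic interpolation instead of discarding borders.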
@@ -33,12 +33,13 @@ class Contrast(ImageAugmentor):
         r = self._rand_range(*self.factor_range)
         mean = np.mean(arr, axis=(0,1), keepdims=True)
         img.arr = (arr - mean) * r + mean
+        img.arr = np.clip(img.arr, 0, 1)
 
 class PerImageWhitening(ImageAugmentor):
     """
     Linearly scales image to have zero mean and unit norm.
     x = (x - mean) / adjusted_stddev
-    where adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels))
+    where adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels))
     """
     def __init__(self):
         pass
@@ -46,5 +47,5 @@ class PerImageWhitening(ImageAugmentor):
     def _augment(self, img):
         mean = np.mean(img.arr, axis=(0,1), keepdims=True)
         std = np.std(img.arr, axis=(0,1), keepdims=True)
-        std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.arr.shape[:2])))
+        std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.arr.shape)))
         img.arr = (img.arr - mean) / std
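A minimal NumPy sketch of what the whitening does to a single image (hypothetical array, independent of the Image wrapper used above):

```python
import numpy as np

img = np.random.rand(24, 24, 3).astype('float32')   # hypothetical image in [0, 1]
mean = np.mean(img, axis=(0, 1), keepdims=True)
std = np.std(img, axis=(0, 1), keepdims=True)
# the new lower bound uses all num_pixels * channels entries, as in the docstring
std = np.maximum(std, 1.0 / np.sqrt(np.prod(img.shape)))
out = (img - mean) / std
print(out.mean(axis=(0, 1)))   # roughly zero per channel
```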
@@ -36,6 +36,11 @@ class TrainConfig(object):
             with capacity 5
         get_model_func: a function taking `inputs` and `is_training` and
             return a tuple of output list as well as the cost to minimize
+        batched_model_input: boolean. If True, `get_model_func` expects batched
+            input in training. Otherwise, it expects a single data point in
+            training, so that you may do pre-processing and batch them
+            later with batch ops. It's suggested that you do all
+            preprocessing in dataset as that is usually faster.
         step_per_epoch: the number of steps (parameter updates) to perform
             in each epoch. default to dataset.size()
         max_epoch: maximum number of epoch to run training. default to 100
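Rough summary of how the two modes differ (see the start_train change below): with batched_model_input=True the queue carries whole batches via enqueue()/dequeue() and the dequeued tensors keep the placeholders' batched shapes; with False each batch from the dataflow is split with enqueue_many(), dequeue() returns a single example whose shape drops the leading batch dimension, and get_model_func is expected to batch it back itself, e.g. with tf.train.shuffle_batch as the CIFAR-10 example now does.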
@@ -59,6 +64,7 @@ class TrainConfig(object):
         assert_type(self.input_queue, tf.QueueBase)
         assert self.input_queue.dtypes == [x.dtype for x in self.inputs]
         self.get_model_func = kwargs.pop('get_model_func')
+        self.batched_model_input = kwargs.pop('batched_model_input', True)
         self.step_per_epoch = int(kwargs.pop('step_per_epoch', self.dataset.size()))
         self.max_epoch = int(kwargs.pop('max_epoch', 100))
         assert self.step_per_epoch > 0 and self.max_epoch > 0
@@ -89,11 +95,16 @@ def start_train(config):
     input_queue = config.input_queue
     callbacks = config.callbacks
 
-    enqueue_op = input_queue.enqueue(tuple(input_vars))
-    model_inputs = input_queue.dequeue()
-    # set dequeue shape
-    for qv, v in zip(model_inputs, input_vars):
-        qv.set_shape(v.get_shape())
+    if config.batched_model_input:
+        enqueue_op = input_queue.enqueue(input_vars)
+        model_inputs = input_queue.dequeue()
+        for qv, v in zip(model_inputs, input_vars):
+            qv.set_shape(v.get_shape())
+    else:
+        enqueue_op = input_queue.enqueue_many(input_vars)
+        model_inputs = input_queue.dequeue()
+        for qv, v in zip(model_inputs, input_vars):
+            qv.set_shape(v.get_shape().as_list()[1:])
     output_vars, cost_var = config.get_model_func(model_inputs, is_training=True)
 
     # build graph
@@ -125,6 +136,8 @@ def start_train(config):
         for step in xrange(config.step_per_epoch):
             if coord.should_stop():
                 return
+            # TODO if no one uses trigger_step, train_op can be
+            # faster, see: https://github.com/soumith/convnet-benchmarks/pull/67/files
             fetches = [train_op, cost_var] + output_vars + model_inputs
             results = sess.run(fetches)
             cost = results[1]
...
@@ -111,9 +111,9 @@ class CallbackTimeLogger(object):
         msgs = []
         for name, t in self.times:
             if t / self.tot > 0.3 and t > 1:
-                msgs.append("{}:{}".format(name, t))
+                msgs.append("{}:{}sec".format(name, t))
         logger.info(
-            "Callbacks took {} sec. {}".format(
+            "Callbacks took {} sec in total. {}".format(
                 self.tot, ' '.join(msgs)))
...