Commit 62d54f68 authored by Yuxin Wu's avatar Yuxin Wu

update keras example

parent ccd67e86
...@@ -10,6 +10,7 @@ import os ...@@ -10,6 +10,7 @@ import os
import sys import sys
import argparse import argparse
import keras
import keras.layers as KL import keras.layers as KL
import keras.backend as KB import keras.backend as KB
from keras.models import Sequential from keras.models import Sequential
...@@ -27,52 +28,45 @@ from tensorpack.utils.argtools import memoized ...@@ -27,52 +28,45 @@ from tensorpack.utils.argtools import memoized
IMAGE_SIZE = 28 IMAGE_SIZE = 28
@memoized # this is necessary for sonnet/Keras to work under tensorpack
def get_keras_model():
M = Sequential()
M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
M.add(KL.MaxPooling2D())
M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
M.add(KL.MaxPooling2D())
M.add(KL.Conv2D(32, 3, padding='same', activation='relu'))
M.add(KL.Flatten())
M.add(KL.Dense(512, activation='relu', kernel_regularizer=regularizers.l2(1e-5)))
M.add(KL.Dropout(0.5))
M.add(KL.Dense(10, activation=None, kernel_regularizer=regularizers.l2(1e-5)))
return M
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def _get_inputs(self):
return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'), return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
InputDesc(tf.int32, (None,), 'label'), InputDesc(tf.int32, (None,), 'label')]
]
@memoized # this is necessary for sonnet/Keras to work under tensorpack
def _build_keras_model(self):
M = Sequential()
M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
M.add(KL.MaxPooling2D())
M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
M.add(KL.MaxPooling2D())
M.add(KL.Conv2D(32, 3, padding='same', activation='relu'))
M.add(KL.Flatten())
M.add(KL.Dense(512, activation='relu', kernel_regularizer=regularizers.l2(1e-5)))
M.add(KL.Dropout(0.5))
M.add(KL.Dense(10, activation=None, kernel_regularizer=regularizers.l2(1e-5)))
return M
def _build_graph(self, inputs): def _build_graph(self, inputs):
image, label = inputs image, label = inputs
image = tf.expand_dims(image, 3) image = tf.expand_dims(image, 3) * 2 - 1
image = image * 2 - 1 # center the pixels values at zero
with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, out_channel=32): M = get_keras_model()
M = self._build_keras_model() logits = M(image)
logits = M(image) # build cost function by tensorflow
prob = tf.nn.softmax(logits, name='prob') # a Bx10 with probabilities
# a vector of length B with loss of each sample
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label) cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') # the average cross-entropy loss cost = tf.reduce_mean(cost, name='cross_entropy_loss') # the average cross-entropy loss
# for tensorpack validation
wrong = symbolic_functions.prediction_incorrect(logits, label, name='incorrect') wrong = symbolic_functions.prediction_incorrect(logits, label, name='incorrect')
train_error = tf.reduce_mean(wrong, name='train_error') train_error = tf.reduce_mean(wrong, name='train_error')
summary.add_moving_summary(train_error) summary.add_moving_summary(train_error)
wd_cost = tf.add_n(M.losses, name='regularize_loss') # this is how Keras manage regularizers wd_cost = tf.add_n(M.losses, name='regularize_loss') # this is how Keras manage regularizers
self.cost = tf.add_n([wd_cost, cost], name='total_cost') self.cost = tf.add_n([wd_cost, cost], name='total_cost')
summary.add_moving_summary(cost, wd_cost, self.cost) summary.add_moving_summary(self.cost)
# this is the keras naming
summary.add_param_summary(('conv2d.*/kernel', ['histogram', 'rms']))
def _get_optimizer(self): def _get_optimizer(self):
lr = tf.train.exponential_decay( lr = tf.train.exponential_decay(
...@@ -84,7 +78,7 @@ class Model(ModelDesc): ...@@ -84,7 +78,7 @@ class Model(ModelDesc):
return tf.train.AdamOptimizer(lr) return tf.train.AdamOptimizer(lr)
# Keras needs an extra input # Keras needs an extra input if learning_phase is needed
class KerasCallback(Callback): class KerasCallback(Callback):
def __init__(self, isTrain): def __init__(self, isTrain):
self._isTrain = isTrain self._isTrain = isTrain
...@@ -106,31 +100,23 @@ def get_config(): ...@@ -106,31 +100,23 @@ def get_config():
dataset_train, dataset_test = get_data() dataset_train, dataset_test = get_data()
return TrainConfig( return TrainConfig(
model=Model(), model=KerasModel(get_keras_model()),
dataflow=dataset_train, dataflow=dataset_train,
callbacks=[ callbacks=[
KerasCallback(1), # for Keras training KerasCallback(True), # for Keras training
ModelSaver(), ModelSaver(),
InferenceRunner( InferenceRunner(
dataset_test, dataset_test,
[ScalarStats('cross_entropy_loss'), ClassificationError('incorrect')], [ScalarStats('cross_entropy_loss'), ClassificationError('incorrect')],
extra_hooks=[CallbackToHook(KerasCallback(0))]), # for keras inference extra_hooks=[CallbackToHook(KerasCallback(False))]), # for keras inference
], ],
max_epoch=100, max_epoch=100,
) )
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
config = get_config() config = get_config()
if args.gpu: QueueInputTrainer(config).train()
config.nr_tower = len(args.gpu.split(',')) # for multigpu training:
if config.nr_tower > 1: # config.nr_tower = 2
SyncMultiGPUTrainer(config).train() # SyncMultiGPUTrainer(config).train()
else:
QueueInputTrainer(config).train()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment