Commit 62d54f68 authored by Yuxin Wu

update keras example

parent ccd67e86
@@ -10,6 +10,7 @@ import os
import sys
import argparse
import keras
import keras.layers as KL
import keras.backend as KB
from keras.models import Sequential
@@ -27,14 +28,8 @@ from tensorpack.utils.argtools import memoized
IMAGE_SIZE = 28
class Model(ModelDesc):
def _get_inputs(self):
return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
InputDesc(tf.int32, (None,), 'label'),
]
@memoized # this is necessary for Sonnet/Keras to work under tensorpack
def _build_keras_model(self):
@memoized # this is necessary for Sonnet/Keras to work under tensorpack
def get_keras_model():
M = Sequential()
M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
M.add(KL.MaxPooling2D())
@@ -48,31 +43,30 @@ class Model(ModelDesc):
M.add(KL.Dense(10, activation=None, kernel_regularizer=regularizers.l2(1e-5)))
return M
class Model(ModelDesc):
def _get_inputs(self):
return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
InputDesc(tf.int32, (None,), 'label')]
def _build_graph(self, inputs):
image, label = inputs
image = tf.expand_dims(image, 3)
image = image * 2 - 1 # center the pixel values at zero
image = tf.expand_dims(image, 3) * 2 - 1
with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, out_channel=32):
M = self._build_keras_model()
M = get_keras_model()
logits = M(image)
prob = tf.nn.softmax(logits, name='prob') # a Bx10 tensor with probabilities
# a vector of length B with the loss of each sample
# build the cost function with plain TensorFlow
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') # the average cross-entropy loss
# for tensorpack validation
wrong = symbolic_functions.prediction_incorrect(logits, label, name='incorrect')
train_error = tf.reduce_mean(wrong, name='train_error')
summary.add_moving_summary(train_error)
wd_cost = tf.add_n(M.losses, name='regularize_loss') # this is how Keras manages regularizers
self.cost = tf.add_n([wd_cost, cost], name='total_cost')
summary.add_moving_summary(cost, wd_cost, self.cost)
# this regex follows the Keras variable naming
summary.add_param_summary(('conv2d.*/kernel', ['histogram', 'rms']))
summary.add_moving_summary(self.cost)
def _get_optimizer(self):
lr = tf.train.exponential_decay(
@@ -84,7 +78,7 @@ class Model(ModelDesc):
return tf.train.AdamOptimizer(lr)
# Keras needs an extra input
# Keras needs an extra input when learning_phase is used
class KerasCallback(Callback):
def __init__(self, isTrain):
self._isTrain = isTrain
@@ -106,31 +100,23 @@ def get_config():
dataset_train, dataset_test = get_data()
return TrainConfig(
model=Model(),
model=KerasModel(get_keras_model()),
dataflow=dataset_train,
callbacks=[
KerasCallback(1), # for Keras training
KerasCallback(True), # for Keras training
ModelSaver(),
InferenceRunner(
dataset_test,
[ScalarStats('cross_entropy_loss'), ClassificationError('incorrect')],
extra_hooks=[CallbackToHook(KerasCallback(0))]), # for keras inference
extra_hooks=[CallbackToHook(KerasCallback(False))]), # for keras inference
],
max_epoch=100,
)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
config = get_config()
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
if config.nr_tower > 1:
SyncMultiGPUTrainer(config).train()
else:
QueueInputTrainer(config).train()
# for multigpu training:
# config.nr_tower = 2
# SyncMultiGPUTrainer(config).train()
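A note on the @memoized builder in this diff: tensorpack may invoke the model-building function more than once (e.g. once per GPU tower), and memoizing it guarantees that every call returns the same Keras Sequential object, so layers and weights are shared rather than duplicated. Below is a minimal sketch of that behavior, using tensorpack's memoized from the diff; the Dense layer is illustrative only:

import keras.layers as KL
from keras.models import Sequential
from tensorpack.utils.argtools import memoized

@memoized  # cache the result so every later call returns the same object
def get_keras_model():
    M = Sequential()
    M.add(KL.Dense(10, input_shape=(784,)))  # toy layer, for illustration
    return M

assert get_keras_model() is get_keras_model()  # one model, shared weights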
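On the change from KerasCallback(1) to KerasCallback(True): with the TensorFlow backend, Keras layers such as Dropout and BatchNormalization consult a global boolean learning_phase tensor, and when the graph is driven by tensorpack rather than by Keras itself, that tensor has to be fed on every session run, which is what the callback supplies as an extra hook. A minimal sketch of the underlying feed, assuming the TF backend (train_op is a hypothetical stand-in for the training op):

import keras.backend as KB

# learning_phase() returns a scalar placeholder read by phase-dependent layers
feed = {KB.learning_phase(): True}  # True during training, False for inference
# sess.run(train_op, feed_dict=feed)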