Commit 7e2be137 authored by Yuxin Wu

add Keras example (#160)

parent db30c255
...
@@ -71,7 +71,7 @@ class Model(ModelDesc):
         logits = (LinearWrap(image)  # the starting brace is only for line-breaking
                   .Conv2D('conv0')
                   .MaxPooling('pool0', 2)
-                  .Conv2D('conv1', padding='SAME')
+                  .Conv2D('conv1')
                   .Conv2D('conv2')
                   .MaxPooling('pool1', 2)
                   .Conv2D('conv3')
...
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: mnist-keras.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import os
import sys
import argparse
import keras.layers as KL
import keras.backend as KB
from keras.models import Sequential
from keras import regularizers
"""
This is an mnist example demonstrating how to use Keras models inside tensorpack.
This way you can define models in Keras-style, and benefit from the more efficeint trainers in tensorpack.
"""
from tensorpack import *
from tensorpack.utils.argtools import memoized
IMAGE_SIZE = 28


class Model(ModelDesc):
    def _get_inputs(self):
        return [InputDesc(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
                InputDesc(tf.int32, (None,), 'label'),
                ]

    @memoized        # this is necessary for Keras to work under tensorpack
    def _build_keras_model(self):
        M = Sequential()
        M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
        M.add(KL.MaxPooling2D())
        M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
        M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
        M.add(KL.MaxPooling2D())
        M.add(KL.Conv2D(32, 3, padding='same', activation='relu'))
        M.add(KL.Flatten())
        M.add(KL.Dense(512, activation='relu', kernel_regularizer=regularizers.l2(1e-5)))
        M.add(KL.Dropout(0.5))
        M.add(KL.Dense(10, activation=None, kernel_regularizer=regularizers.l2(1e-5)))
        return M

    def _build_graph(self, inputs):
        image, label = inputs
        image = tf.expand_dims(image, 3)
        image = image * 2 - 1   # center the pixel values at zero

        with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, out_channel=32):
            M = self._build_keras_model()
            logits = M(image)
        prob = tf.nn.softmax(logits, name='prob')   # a Bx10 tensor of class probabilities

        # a vector of length B with the loss of each sample
        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
        cost = tf.reduce_mean(cost, name='cross_entropy_loss')  # the average cross-entropy loss

        wrong = symbolic_functions.prediction_incorrect(logits, label, name='incorrect')
        train_error = tf.reduce_mean(wrong, name='train_error')
        summary.add_moving_summary(train_error)

        wd_cost = tf.add_n(M.losses, name='regularize_loss')  # this is how Keras manages regularizers
        self.cost = tf.add_n([wd_cost, cost], name='total_cost')
        summary.add_moving_summary(cost, wd_cost, self.cost)

        # 'conv2d.*/kernel' follows the Keras naming convention
        summary.add_param_summary(('conv2d.*/kernel', ['histogram', 'rms']))

    def _get_optimizer(self):
        # 468 steps is roughly one pass over MNIST at batch size 128 (60000 / 128),
        # so the learning rate is multiplied by 0.3 about every 10 epochs
        lr = tf.train.exponential_decay(
            learning_rate=1e-3,
            global_step=get_global_step_var(),
            decay_steps=468 * 10,
            decay_rate=0.3, staircase=True, name='learning_rate')
        tf.summary.scalar('lr', lr)
        return tf.train.AdamOptimizer(lr)


class KerasCallback(Callback):
    # Feeds Keras' learning_phase placeholder (1 = train, 0 = test) into every
    # hooked sess.run(), so layers such as Dropout switch mode correctly.
    def __init__(self, isTrain):
        self._isTrain = isTrain
        self._learning_phase = KB.learning_phase()

    def _before_run(self, ctx):
        return tf.train.SessionRunArgs(
            fetches=[], feed_dict={self._learning_phase: int(self._isTrain)})


def get_data():
    train = BatchData(dataset.Mnist('train'), 128)
    test = BatchData(dataset.Mnist('test'), 256, remainder=True)
    return train, test


def get_config():
    logger.auto_set_dir()
    dataset_train, dataset_test = get_data()

    return TrainConfig(
        model=Model(),
        dataflow=dataset_train,
        callbacks=[
            KerasCallback(1),   # for Keras training
            ModelSaver(),
            InferenceRunner(
                dataset_test,
                [ScalarStats('cross_entropy_loss'), ClassificationError('incorrect')],
                extra_hooks=[CallbackToHook(KerasCallback(0))]),  # for keras inference
        ],
        max_epoch=100,
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--load', help='load model')
    args = parser.parse_args()
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    config = get_config()
    if args.load:
        config.session_init = SaverRestore(args.load)
    if args.gpu:
        config.nr_tower = get_nr_gpu()

    if config.nr_tower > 1:
        SyncMultiGPUTrainer(config).train()
    else:
        QueueInputTrainer(config).train()
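
Side note (not part of the commit): KerasCallback works because, with the TensorFlow backend, keras.backend.learning_phase() is an ordinary placeholder, so feeding it 1 or 0 toggles the train/test behaviour of layers such as Dropout. A minimal standalone sketch of that mechanism, assuming Keras 2 on the TF backend:

import tensorflow as tf
import keras.backend as KB
import keras.layers as KL

x = tf.placeholder(tf.float32, [None, 10])
y = KL.Dropout(0.5)(x)   # output depends on the learning phase

with tf.Session() as sess:
    data = [[1.0] * 10]
    # learning_phase = 1: dropout is active, some inputs are zeroed and the rest rescaled
    print(sess.run(y, feed_dict={x: data, KB.learning_phase(): 1}))
    # learning_phase = 0: dropout is a no-op, the input passes through unchanged
    print(sess.run(y, feed_dict={x: data, KB.learning_phase(): 0}))

This is exactly what KerasCallback automates above: it feeds the placeholder on every hooked run, with 1 during training and 0 inside the InferenceRunner.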
...
@@ -56,7 +56,7 @@ def summary_inferencer(trainer, infs):
 @six.add_metaclass(ABCMeta)
 class InferenceRunnerBase(Callback):
     """ Base methods for inference runner"""
-    def __init__(self, input, infs, input_names=None, prefix=''):
+    def __init__(self, input, infs, input_names=None, prefix='', extra_hooks=None):
         """
         Args:
             input (InputData): the input to use. Must have ``size()``.
@@ -64,6 +64,7 @@ class InferenceRunnerBase(Callback):
             input_names (list): must be a subset of the names in InputDesc.
             prefix(str): an prefix used to build the tower. Must be set
                 differently if more than one :class:`InferenceRunner` are used.
+            extra_hooks (list): extra ``SessionRunHook`` to run with the evaluation.
         """
         self._input_data = input
         if not isinstance(infs, list):
@@ -82,6 +83,10 @@ class InferenceRunnerBase(Callback):
             raise ValueError("Input used in InferenceRunner must have a size!")
         self._prefix = prefix

+        if extra_hooks is None:
+            extra_hooks = []
+        self._extra_hooks = extra_hooks
+
     def _setup_input_names(self):
         # just use all the placeholders, if input_name is None
         if self.input_names is None:
@@ -111,6 +116,7 @@ class InferenceRunnerBase(Callback):
         self._hooks = [self._build_hook(inf) for inf in self.infs]

     def _before_train(self):
+        self._hooks.extend(self._extra_hooks)
         self._hooked_sess = HookedSession(self.trainer.sess, self._hooks)

     def _get_tensors_maybe_in_tower(self, names):
@@ -148,7 +154,7 @@ class InferenceRunner(InferenceRunnerBase):
     :class:`DataFlow`.
     """

-    def __init__(self, ds, infs, input_names=None):
+    def __init__(self, ds, infs, input_names=None, extra_hooks=None):
         """
         Args:
             ds (DataFlow): the DataFlow to run inferencer on.
@@ -158,7 +164,8 @@ class InferenceRunner(InferenceRunnerBase):
         """
         assert isinstance(ds, DataFlow), ds
         input = FeedInput(ds)
-        super(InferenceRunner, self).__init__(input, infs, input_names, prefix='')
+        super(InferenceRunner, self).__init__(
+            input, infs, input_names, prefix='', extra_hooks=extra_hooks)

     def _find_input_tensors(self):
         return self.trainer.model.get_reused_placehdrs()
@@ -178,7 +185,7 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
     pipeline.
     """

-    def __init__(self, input, infs, input_names=None, prefix=''):
+    def __init__(self, input, infs, input_names=None, prefix='', extra_hooks=None):
         """
         Args:
             input (TensorInput): the input to use. Must have ``size()``.
@@ -188,7 +195,8 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
                 differently if more than one :class:`FeedfreeInferenceRunner` are used.
         """
         assert isinstance(input, TensorInput), input
-        super(FeedfreeInferenceRunner, self).__init__(input, infs, input_names, prefix)
+        super(FeedfreeInferenceRunner, self).__init__(
+            input, infs, input_names, prefix=prefix, extra_hooks=extra_hooks)

     def _setup_input_names(self):
         super(FeedfreeInferenceRunner, self)._setup_input_names()
...
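
For context, a hypothetical usage sketch of the new extra_hooks argument (CountRunsHook and the surrounding setup are made up for illustration, not part of tensorpack): any tf.train.SessionRunHook, or a tensorpack Callback wrapped with CallbackToHook as in the example above, now also runs during the evaluation passes.

import tensorflow as tf
from tensorpack import *


class CountRunsHook(tf.train.SessionRunHook):
    """Toy hook that counts how many hooked run() calls the evaluation makes."""
    def begin(self):
        self.count = 0

    def before_run(self, run_context):
        self.count += 1


dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
runner = InferenceRunner(
    dataset_test,
    [ScalarStats('cross_entropy_loss')],
    extra_hooks=[CountRunsHook()])   # merged into the runner's own hooks in _before_train

The runner is then registered like any other callback in TrainConfig(callbacks=[...]).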
...
@@ -121,5 +121,6 @@ def QueueInputTrainer(config, input_queue=None, predict_tower=None):
         log_deprecated("Argument `predict_tower` in trainer", "Use TrainConfig(predict_tower=...) instead!")
         config.predict_tower = predict_tower
     assert len(config.tower) == 1, \
-        "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
+        "Got nr_tower={}, but QueueInputTrainer doesn't support multigpu!" \
+        " Use Sync/AsyncMultiGPUTrainer instead.".format(len(config.tower))
     return SimpleFeedfreeTrainer(config)