Commit bc4c6044 authored by Yuxin Wu

Initial attempt at keras-style training

parent 88b99f38
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: mnist-keras-v2.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import os
import sys
import argparse
import keras
from keras.models import Sequential
import keras.layers as KL
from keras import regularizers

from tensorpack.train import SimpleTrainer
from tensorpack.input_source import QueueInput
from tensorpack.callbacks import *
from tensorpack.dataflow import dataset, BatchData, MapData
from tensorpack.utils import logger
from tensorpack.contrib.keras import setup_keras_trainer

IMAGE_SIZE = 28


def get_data():
    def f(dp):
        im = dp[0][:, :, None]
        onehot = np.zeros(10, dtype='int32')
        onehot[dp[1]] = 1
        return [im, onehot]

    train = BatchData(MapData(dataset.Mnist('train'), f), 128)
    test = BatchData(MapData(dataset.Mnist('test'), f), 256)
    return train, test


if __name__ == '__main__':
    logger.auto_set_dir()

    dataset_train, dataset_test = get_data()

    M = Sequential()
    M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
    M.add(KL.MaxPooling2D())
    M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
    M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
    M.add(KL.MaxPooling2D())
    M.add(KL.Conv2D(32, 3, padding='same', activation='relu'))
    M.add(KL.Flatten())
    M.add(KL.Dense(512, activation='relu', kernel_regularizer=regularizers.l2(1e-5)))
    M.add(KL.Dropout(0.5))
    M.add(KL.Dense(10, activation=None, kernel_regularizer=regularizers.l2(1e-5)))
    M.add(KL.Activation('softmax'))

    trainer = SimpleTrainer()
    setup_keras_trainer(
        trainer,
        model=M,
        input=QueueInput(dataset_train),
        optimizer=tf.train.AdamOptimizer(1e-3),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    trainer.train_with_defaults(
        callbacks=[
            ModelSaver(),
            InferenceRunner(
                dataset_test,
                [ScalarStats(['total_loss', 'accuracy'])]),
        ],
        steps_per_epoch=dataset_train.size(),
    )
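The mapper `f` above exists because `categorical_crossentropy` expects one-hot targets and the `Conv2D` stack expects channels-last images, while the raw dataflow yields `[28x28 image, integer label]` datapoints. A minimal standalone sketch of that transformation; the sample datapoint here is made up for illustration:

import numpy as np

# hypothetical datapoint in the format the mapper receives: [28x28 image, integer label]
dp = [np.random.rand(28, 28).astype('float32'), 3]

im = dp[0][:, :, None]                 # append a channel axis: (28, 28) -> (28, 28, 1)
onehot = np.zeros(10, dtype='int32')
onehot[dp[1]] = 1                      # one-hot encode the label: index 3 -> [0 0 0 1 0 0 0 0 0 0]
assert im.shape == (28, 28, 1) and onehot.sum() == 1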
--- mnist-keras.py
+++ mnist-keras-functional.py
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-# File: mnist-keras.py
+# File: mnist-keras-functional.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 import numpy as np
@@ -12,18 +12,18 @@ import argparse
 import keras
 import keras.layers as KL
-import keras.backend as KB
 from keras.models import Sequential
 from keras import regularizers

 """
-This is an mnist example demonstrating how to use Keras models inside tensorpack.
+This is an mnist example demonstrating how to use Keras symbolic function inside tensorpack.
 This way you can define models in Keras-style, and benefit from the more efficient trainers in tensorpack.
 """

 from tensorpack import *
 from tensorpack.dataflow import dataset
 from tensorpack.utils.argtools import memoized
+from tensorpack.contrib.keras import KerasPhaseCallback

 IMAGE_SIZE = 28
@@ -78,18 +78,6 @@ class Model(ModelDesc):
         return tf.train.AdamOptimizer(lr)

-
-# Keras needs an extra input if learning_phase is used by the model
-class KerasCallback(Callback):
-    def __init__(self, isTrain):
-        assert isinstance(isTrain, bool), isTrain
-        self._isTrain = isTrain
-        self._learning_phase = KB.learning_phase()
-
-    def _before_run(self, ctx):
-        return tf.train.SessionRunArgs(
-            fetches=[], feed_dict={self._learning_phase: int(self._isTrain)})
-
 def get_data():
     train = BatchData(dataset.Mnist('train'), 128)
     test = BatchData(dataset.Mnist('test'), 256, remainder=True)
@@ -104,12 +92,11 @@ def get_config():
         model=Model(),
         dataflow=dataset_train,
         callbacks=[
-            KerasCallback(True),  # for Keras training
+            KerasPhaseCallback(True),  # for Keras training
             ModelSaver(),
             InferenceRunner(
                 dataset_test,
-                [ScalarStats('cross_entropy_loss'), ClassificationError('incorrect')],
-                extra_hooks=[CallbackToHook(KerasCallback(False))]),  # for keras inference
+                [ScalarStats('cross_entropy_loss'), ClassificationError('incorrect')]),
         ],
         max_epoch=100,
     )
...
--- tensorpack/callbacks/inference_runner.py
+++ tensorpack/callbacks/inference_runner.py
@@ -61,12 +61,11 @@ class InferenceRunnerBase(Callback):
     Also, InferenceRunner assumes that `trainer.model` exists.
     """
-    def __init__(self, input, infs, extra_hooks=None):
+    def __init__(self, input, infs):
         """
         Args:
             input (InputSource): the input to use. Must have ``size()``.
             infs (list[Inferencer]): list of :class:`Inferencer` to run.
-            extra_hooks (list[SessionRunHook]): extra :class:`SessionRunHook` to run with the evaluation.
         """
         self._input_source = input
         if not isinstance(infs, list):
@@ -82,12 +81,16 @@ class InferenceRunnerBase(Callback):
             raise ValueError("Input used in InferenceRunner must have a size!")
         logger.info("InferenceRunner will eval on an InputSource of size {}".format(self._size))

-        if extra_hooks is None:
-            extra_hooks = []
-        self._extra_hooks = extra_hooks
+        self._hooks = []
+
+    def register_hook(self, hook):
+        """
+        Args:
+            hook (tf.train.SessionRunHook):
+        """
+        self._hooks.append(hook)

     def _before_train(self):
-        self._hooks.extend(self._extra_hooks)
         self._hooked_sess = HookedSession(self.trainer.sess, self._hooks)
         self._input_callbacks.before_train()
@@ -100,7 +103,7 @@ class InferenceRunner(InferenceRunnerBase):
     A callback that runs a list of :class:`Inferencer` on some :class:`InputSource`.
     """
-    def __init__(self, input, infs, tower_name='InferenceTower', device=0, extra_hooks=None):
+    def __init__(self, input, infs, tower_name='InferenceTower', device=0):
         """
         Args:
             input (InputSource or DataFlow): The :class:`InputSource` to run
@@ -115,8 +118,7 @@ class InferenceRunner(InferenceRunnerBase):
         assert isinstance(input, InputSource), input
         self._tower_name = tower_name
         self._device = device
-        super(InferenceRunner, self).__init__(
-            input, infs, extra_hooks=extra_hooks)
+        super(InferenceRunner, self).__init__(input, infs)

     def _build_hook(self, inf):
         out_names = inf.get_fetches()
@@ -138,11 +140,13 @@ class InferenceRunner(InferenceRunnerBase):
             self._input_source, self.trainer.tower_func)
         self._tower_handle = self.trainer.tower_func.towers[-1]

-        self._hooks = [self._build_hook(inf) for inf in self.infs]
+        for h in [self._build_hook(inf) for inf in self.infs]:
+            self.register_hook(h)
         # trigger_{step,epoch}, {before,after}_epoch is ignored.
         # We assume that InputSource callbacks won't use these methods
         self._input_callbacks = Callbacks(input_callbacks)
-        self._hooks.extend(self._input_callbacks.get_hooks())
+        for h in self._input_callbacks.get_hooks():
+            self.register_hook(h)

         for inf in self.infs:
             inf.setup_graph(self.trainer)
@@ -202,7 +206,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         # setup callbacks and hooks
         self._input_callbacks = Callbacks(input_callbacks)
-        # InputSource might have hooks which break us.
+        # TODO InputSource might have hooks which break us.
         # e.g. hooks from StagingInput will force the consumption
         # of nr_tower datapoints in every run.
         input_hooks = self._input_callbacks.get_hooks()
@@ -213,6 +217,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
             inf.setup_graph(self.trainer)
         self._input_callbacks.setup_graph(self.trainer)

+    def register_hook(self, h):
+        raise NotImplementedError("DataParallelInferenceRunner doesn't accept extra hooks!")
+
 class InferencerToHookDataParallel(InferencerToHook):
     def __init__(self, inf, fetches, size):
         """
...
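The removal of `extra_hooks` means callers now attach hooks after construction through the new `register_hook()` method. In this commit the Keras phase hook registers itself via `KerasPhaseCallback._setup_graph` (see keras.py below), but hooks can also be attached manually. A minimal migration sketch using names from this commit:

# before this commit: hooks passed at construction time
#   InferenceRunner(dataset_test, infs,
#                   extra_hooks=[CallbackToHook(KerasCallback(False))])

# after this commit: hooks registered on the constructed runner
runner = InferenceRunner(dataset_test, [ScalarStats('cross_entropy_loss')])
runner.register_hook(CallbackToHook(KerasPhaseCallback(False)))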
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: keras.py

import tensorflow as tf
from six.moves import zip
import keras

from ..graph_builder import InputDesc
from ..tfutils.tower import get_current_tower_context
from ..tfutils.collection import freeze_collection
from ..callbacks import Callback, InferenceRunner, CallbackToHook
from ..tfutils.summary import add_moving_summary


# Keras needs an extra input if learning_phase is used by the model.
# This callback will be used by:
# 1. the trainer, with isTrain=True
# 2. InferenceRunner, with isTrain=False, in the form of a hook
class KerasPhaseCallback(Callback):
    def __init__(self, isTrain):
        assert isinstance(isTrain, bool), isTrain
        self._isTrain = isTrain
        self._learning_phase = keras.backend.learning_phase()

    def _setup_graph(self):
        # HACK
        cbs = self.trainer._callbacks.cbs
        for cb in cbs:
            if isinstance(cb, InferenceRunner):
                h = CallbackToHook(KerasPhaseCallback(False))
                cb.register_hook(h)

    def _before_run(self, ctx):
        return tf.train.SessionRunArgs(
            fetches=[], feed_dict={self._learning_phase: int(self._isTrain)})


def setup_keras_trainer(
        trainer, model, input,
        optimizer, loss, metrics=None):
    """
    Args:
        trainer (SingleCostTrainer):
        model (keras.model.Model):
        input (InputSource):
        optimizer (tf.train.Optimizer):
        loss, metrics: same as in `keras.model.Model.compile()`.
    """
    assert isinstance(optimizer, tf.train.Optimizer), optimizer
    inputs_desc = [InputDesc.from_tensor(t) for t in model.inputs]
    outputs_desc = [InputDesc.from_tensor(t) for t in model.outputs]
    nr_inputs = len(inputs_desc)

    # clear the collection
    del tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)[:]

    def get_cost(*inputs):
        ctx = get_current_tower_context()
        assert ctx.is_main_training_tower or not ctx.has_own_variables
        input_tensors = list(inputs[:nr_inputs])
        target_tensors = list(inputs[nr_inputs:])
        # Keras checks and does weird things if the target is a placeholder.
        # Use tf.identity so it's not a placeholder.
        target_tensors = [tf.identity(t) for t in target_tensors]
        input_keras_tensors = [keras.layers.Input(tensor=t) for t in input_tensors]
        outputs = model(input_keras_tensors)

        M = keras.models.Model(input_tensors, outputs)
        with freeze_collection([tf.GraphKeys.TRAINABLE_VARIABLES]):
            # the Keras optimizer mistakenly creates TRAINABLE_VARIABLES ...
            M.compile(
                optimizer=optimizer, loss=loss,
                target_tensors=target_tensors,
                metrics=metrics)
        add_moving_summary(tf.identity(M.total_loss, name='total_loss'))
        assert len(M.metrics) == len(M.metrics_tensors)
        for name, tensor in zip(M.metrics, M.metrics_tensors):
            add_moving_summary(tf.identity(tensor, name=name))
        # tensorpack requires TRAINABLE_VARIABLES to be created inside the tower
        if ctx.is_main_training_tower:
            for p in M.weights:
                tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, p)
        return M.total_loss

    trainer.setup_graph(
        inputs_desc + outputs_desc,
        input,
        get_cost,
        lambda: optimizer)
    if model.uses_learning_phase:
        trainer.register_callback(KerasPhaseCallback(True))
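For context, `KerasPhaseCallback` is needed because layers such as `Dropout` branch on the symbolic `keras.backend.learning_phase()` tensor, which must be fed on every `session.run`. A minimal standalone sketch of the mechanism (plain TF1-era Keras, independent of tensorpack):

import numpy as np
import tensorflow as tf
import keras
import keras.backend as KB

x = tf.placeholder(tf.float32, [None, 4])
y = keras.layers.Dropout(0.5)(x)   # internally a cond on KB.learning_phase()

data = np.ones([2, 4], dtype='float32')
with tf.Session() as sess:
    # learning_phase=1: dropout active, some units zeroed, survivors scaled up
    print(sess.run(y, feed_dict={x: data, KB.learning_phase(): 1}))
    # learning_phase=0: dropout is a no-op, output equals input
    print(sess.run(y, feed_dict={x: data, KB.learning_phase(): 0}))

This feed is exactly what `_before_run` above supplies through `SessionRunArgs`, with `isTrain` deciding the phase value.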