Commit 72385a85 authored by Yuxin Wu's avatar Yuxin Wu

add `KerasModel` wrapper (#160)

parent 42f10617
...@@ -24,6 +24,7 @@ Optionally, you can implement the following two methods: ...@@ -24,6 +24,7 @@ Optionally, you can implement the following two methods:
A typical situation is when your DataFlow uses a random number generator (RNG). Then you would need to reset the RNG here. A typical situation is when your DataFlow uses a random number generator (RNG). Then you would need to reset the RNG here.
Otherwise, child processes will have the same random seed. The `RNGDataFlow` base class does this for you. Otherwise, child processes will have the same random seed. The `RNGDataFlow` base class does this for you.
You can subclass `RNGDataFlow` to access `self.rng` whose seed has been taken care of.
With a "low-level" DataFlow defined like above, you can then compose it with existing modules (e.g. batching, prefetching, ...). With a "low-level" DataFlow defined like above, you can then compose it with existing modules (e.g. batching, prefetching, ...).
......
...@@ -16,7 +16,7 @@ from tensorpack.input_source import QueueInput ...@@ -16,7 +16,7 @@ from tensorpack.input_source import QueueInput
from tensorpack.callbacks import ModelSaver, InferenceRunner, ScalarStats from tensorpack.callbacks import ModelSaver, InferenceRunner, ScalarStats
from tensorpack.dataflow import dataset, BatchData, MapData from tensorpack.dataflow import dataset, BatchData, MapData
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.contrib.keras import setup_keras_trainer from tensorpack.contrib.keras import KerasModel
IMAGE_SIZE = 28 IMAGE_SIZE = 28
...@@ -35,8 +35,6 @@ def get_data(): ...@@ -35,8 +35,6 @@ def get_data():
if __name__ == '__main__': if __name__ == '__main__':
logger.auto_set_dir() logger.auto_set_dir()
dataset_train, dataset_test = get_data()
M = Sequential() M = Sequential()
M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same')) M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
M.add(KL.MaxPooling2D()) M.add(KL.MaxPooling2D())
...@@ -50,17 +48,15 @@ if __name__ == '__main__': ...@@ -50,17 +48,15 @@ if __name__ == '__main__':
M.add(KL.Dense(10, activation=None, kernel_regularizer=regularizers.l2(1e-5))) M.add(KL.Dense(10, activation=None, kernel_regularizer=regularizers.l2(1e-5)))
M.add(KL.Activation('softmax')) M.add(KL.Activation('softmax'))
trainer = SimpleTrainer() dataset_train, dataset_test = get_data()
setup_keras_trainer( M = KerasModel(M, QueueInput(dataset_train))
trainer, M.compile(
model=M,
input=QueueInput(dataset_train),
optimizer=tf.train.AdamOptimizer(1e-3), optimizer=tf.train.AdamOptimizer(1e-3),
loss='categorical_crossentropy', loss='categorical_crossentropy',
metrics=['accuracy'] metrics=['accuracy']
) )
trainer.train_with_defaults( M.fit(
callbacks=[ callbacks=[
ModelSaver(), ModelSaver(),
InferenceRunner( InferenceRunner(
......
...@@ -11,6 +11,11 @@ from ..tfutils.tower import get_current_tower_context ...@@ -11,6 +11,11 @@ from ..tfutils.tower import get_current_tower_context
from ..tfutils.collection import freeze_collection from ..tfutils.collection import freeze_collection
from ..callbacks import Callback, InferenceRunner, CallbackToHook from ..callbacks import Callback, InferenceRunner, CallbackToHook
from ..tfutils.summary import add_moving_summary from ..tfutils.summary import add_moving_summary
from ..utils.gpu import get_nr_gpu
from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
__all__ = ['KerasPhaseCallback', 'setup_keras_trainer', 'KerasModel']
# Keras needs an extra input if learning_phase is used by the model # Keras needs an extra input if learning_phase is used by the model
...@@ -95,3 +100,38 @@ def setup_keras_trainer( ...@@ -95,3 +100,38 @@ def setup_keras_trainer(
lambda: optimizer) lambda: optimizer)
if model.uses_learning_phase: if model.uses_learning_phase:
trainer.register_callback(KerasPhaseCallback(True)) trainer.register_callback(KerasPhaseCallback(True))
class KerasModel(object):
    """Keras-like ``compile`` / ``fit`` facade over a tensorpack trainer.

    The wrapper stores a Keras model, an input source and a trainer.
    ``compile`` wires them together through ``setup_keras_trainer`` and
    ``fit`` starts training through ``trainer.train_with_defaults``.
    """

    def __init__(self, model, input, trainer=None):
        """
        Args:
            model (keras.models.Model): the Keras model to train.
            input: the input source handed to ``setup_keras_trainer``
                as its ``input`` argument.
            trainer (Trainer): the trainer to use. When None, a
                ``SimpleTrainer`` is created if ``get_nr_gpu()`` reports at
                most one GPU, otherwise a
                ``SyncMultiGPUTrainerParameterServer`` over all reported GPUs.
        """
        self.model = model
        if trainer is None:
            nr_gpu = get_nr_gpu()
            if nr_gpu <= 1:
                trainer = SimpleTrainer()
            else:
                trainer = SyncMultiGPUTrainerParameterServer(nr_gpu)
        assert isinstance(trainer, Trainer), trainer
        self.trainer = trainer
        self.input = input

    def compile(self, optimizer, loss, metrics):
        """Set up the trainer with the given optimizer, loss and metrics.

        All three arguments are forwarded unchanged to
        ``setup_keras_trainer`` together with the stored trainer, model
        and input.
        """
        setup_keras_trainer(
            self.trainer, model=self.model,
            input=self.input,
            optimizer=optimizer,
            loss=loss,
            metrics=metrics)

    def fit(self, **kwargs):
        """Start training.

        Args:
            **kwargs: forwarded to ``trainer.train_with_defaults``. The
                optional ``callbacks`` entry is combined with
                :meth:`get_default_callbacks` before being forwarded.
        """
        # Build a fresh list so the caller's list is never mutated.
        callbacks = list(kwargs.pop('callbacks', None) or [])
        callbacks.extend(self.get_default_callbacks())
        # The combined callbacks must be passed on explicitly: they were
        # popped from kwargs above and would otherwise be silently dropped.
        self.trainer.train_with_defaults(callbacks=callbacks, **kwargs)

    def get_default_callbacks(self):
        """Return the callbacks that ``fit`` appends to the user's callbacks.

        Returns an empty list here.
        """
        return []
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment