Commit 72385a85 authored by Yuxin Wu's avatar Yuxin Wu

add `KerasModel` wrapper (#160)

parent 42f10617
......@@ -24,6 +24,7 @@ Optionally, you can implement the following two methods:
A typical situation is when your DataFlow uses a random number generator (RNG). In that case you need to reset the RNG here.
Otherwise, child processes will have the same random seed. The `RNGDataFlow` base class does this for you.
You can subclass `RNGDataFlow` to access `self.rng` whose seed has been taken care of.
With a "low-level" DataFlow defined like above, you can then compose it with existing modules (e.g. batching, prefetching, ...).
......
......@@ -16,7 +16,7 @@ from tensorpack.input_source import QueueInput
from tensorpack.callbacks import ModelSaver, InferenceRunner, ScalarStats
from tensorpack.dataflow import dataset, BatchData, MapData
from tensorpack.utils import logger
from tensorpack.contrib.keras import setup_keras_trainer
from tensorpack.contrib.keras import KerasModel
IMAGE_SIZE = 28
......@@ -35,8 +35,6 @@ def get_data():
if __name__ == '__main__':
logger.auto_set_dir()
dataset_train, dataset_test = get_data()
M = Sequential()
M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
M.add(KL.MaxPooling2D())
......@@ -50,17 +48,15 @@ if __name__ == '__main__':
M.add(KL.Dense(10, activation=None, kernel_regularizer=regularizers.l2(1e-5)))
M.add(KL.Activation('softmax'))
trainer = SimpleTrainer()
dataset_train, dataset_test = get_data()
setup_keras_trainer(
trainer,
model=M,
input=QueueInput(dataset_train),
M = KerasModel(M, QueueInput(dataset_train))
M.compile(
optimizer=tf.train.AdamOptimizer(1e-3),
loss='categorical_crossentropy',
metrics=['accuracy']
)
trainer.train_with_defaults(
M.fit(
callbacks=[
ModelSaver(),
InferenceRunner(
......
......@@ -11,6 +11,11 @@ from ..tfutils.tower import get_current_tower_context
from ..tfutils.collection import freeze_collection
from ..callbacks import Callback, InferenceRunner, CallbackToHook
from ..tfutils.summary import add_moving_summary
from ..utils.gpu import get_nr_gpu
from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
__all__ = ['KerasPhaseCallback', 'setup_keras_trainer', 'KerasModel']
# Keras needs an extra input if learning_phase is used by the model
......@@ -95,3 +100,38 @@ def setup_keras_trainer(
lambda: optimizer)
if model.uses_learning_phase:
trainer.register_callback(KerasPhaseCallback(True))
class KerasModel(object):
    """Keras-style ``compile`` / ``fit`` wrapper around a tensorpack trainer."""

    def __init__(self, model, input, trainer=None):
        """
        Args:
            model (keras.models.Model): the Keras model to train.
            input: the input source forwarded to :func:`setup_keras_trainer`
                as its ``input`` argument.
            trainer (Trainer): the trainer to drive. If None, one is chosen
                from ``get_nr_gpu()``: a :class:`SimpleTrainer` when it
                reports at most one GPU, otherwise a
                :class:`SyncMultiGPUTrainerParameterServer` over that many.
        """
        self.model = model
        if trainer is None:
            nr_gpu = get_nr_gpu()
            if nr_gpu <= 1:
                trainer = SimpleTrainer()
            else:
                trainer = SyncMultiGPUTrainerParameterServer(nr_gpu)
        assert isinstance(trainer, Trainer), trainer
        self.trainer = trainer
        self.input = input

    def compile(self, optimizer, loss, metrics):
        """
        Set up the trainer for the wrapped model.

        All arguments are forwarded unchanged to :func:`setup_keras_trainer`,
        together with the model and input given at construction time.

        Args:
            optimizer: forwarded as ``optimizer``.
            loss: forwarded as ``loss``.
            metrics: forwarded as ``metrics``.
        """
        setup_keras_trainer(
            self.trainer,
            model=self.model,
            input=self.input,
            optimizer=optimizer,
            loss=loss,
            metrics=metrics)

    def fit(self, **kwargs):
        """
        Start training by calling ``trainer.train_with_defaults``.

        Args:
            kwargs: forwarded to ``train_with_defaults``. The optional
                ``callbacks`` entry is merged with
                :meth:`get_default_callbacks` before being passed on.
        """
        # Copy the list so the caller's object is never mutated; treat a
        # missing or None value as "no user callbacks".
        callbacks = list(kwargs.pop('callbacks', None) or [])
        callbacks.extend(self.get_default_callbacks())
        # The merged list has to be put back into kwargs: it was popped above,
        # so without this the trainer would never see any callbacks.
        kwargs['callbacks'] = callbacks
        self.trainer.train_with_defaults(**kwargs)

    def get_default_callbacks(self):
        """
        Returns:
            list: callbacks appended to the user-supplied ones in :meth:`fit`.
            Empty by default.
        """
        return []
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment