Commit e8674dca authored by Yuxin Wu

[Keras] Use get_model function instead of letting users create the model directly (#160)
parent 365c56d2
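
The change below replaces the old API, in which the user constructed a keras.models.Model and handed the instance to KerasModel, with one in which the user hands over a get_model function that builds the model from the input tensors tensorpack provides. A minimal sketch of the new calling convention, distilled from the example diff below (IMAGE_SIZE, KL, QueueInput and dataset_train as in that example):

    def model_func(input_tensors):
        # Build and return a keras.models.Model wired to the given input tensors.
        M = keras.models.Sequential()
        M.add(KL.InputLayer(input_tensor=input_tensors[0],
                            input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1]))
        M.add(KL.Flatten())
        M.add(KL.Dense(10, activation='softmax'))
        return M

    M = KerasModel(model_func, QueueInput(dataset_train))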
@@ -14,6 +14,7 @@ from tensorpack.input_source import QueueInput
 from tensorpack.dataflow import dataset, BatchData, MapData
 from tensorpack.utils import logger
 from tensorpack.contrib.keras import KerasModel
+from tensorpack.callbacks import ModelSaver

 IMAGE_SIZE = 28
@@ -32,28 +33,41 @@ def get_data():
 if __name__ == '__main__':
     logger.auto_set_dir()

-    M = keras.models.Sequential()
-    M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
-    M.add(KL.MaxPooling2D())
-    M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
-    M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
-    M.add(KL.MaxPooling2D())
-    M.add(KL.Conv2D(32, 3, padding='same', activation='relu'))
-    M.add(KL.Flatten())
-    M.add(KL.Dense(512, activation='relu', kernel_regularizer=keras.regularizers.l2(1e-5)))
-    M.add(KL.Dropout(0.5))
-    M.add(KL.Dense(10, activation=None, kernel_regularizer=keras.regularizers.l2(1e-5)))
-    M.add(KL.Activation('softmax'))
+    def model_func(input_tensors):
+        M = keras.models.Sequential()
+        M.add(KL.InputLayer(
+            input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1],
+            input_tensor=input_tensors[0]))
+        M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
+        M.add(KL.MaxPooling2D())
+        M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
+        M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
+        M.add(KL.MaxPooling2D())
+        M.add(KL.Conv2D(32, 3, padding='same', activation='relu'))
+        M.add(KL.Flatten())
+        M.add(KL.Dense(512, activation='relu', kernel_regularizer=keras.regularizers.l2(1e-5)))
+        M.add(KL.Dropout(0.5))
+        M.add(KL.Dense(10, activation=None, kernel_regularizer=keras.regularizers.l2(1e-5)))
+        M.add(KL.Activation('softmax'))
+        return M

     dataset_train, dataset_test = get_data()

-    M = KerasModel(M, QueueInput(dataset_train))
+    # from tensorpack import *
+    # trainer = SyncMultiGPUTrainerReplicated(2)
+    M = KerasModel(model_func, QueueInput(dataset_train))
     M.compile(
         optimizer=tf.train.AdamOptimizer(1e-3),
         loss='categorical_crossentropy',
-        metrics=['accuracy']
+        metrics=['categorical_accuracy']
     )
     M.fit(
         validation_data=dataset_test,
         steps_per_epoch=dataset_train.size(),
+        callbacks=[
+            ModelSaver()
+        ]
     )
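
The commented-out lines above hint at why the function-based API exists: tensorpack can now call model_func once per GPU tower. A sketch of multi-GPU use under that assumption, passing an explicit trainer through the optional trainer argument of KerasModel (see keras.py below); SyncMultiGPUTrainerReplicated is tensorpack's replicated data-parallel trainer:

    from tensorpack import SyncMultiGPUTrainerReplicated

    trainer = SyncMultiGPUTrainerReplicated(2)  # build two towers, one per GPU
    M = KerasModel(model_func, QueueInput(dataset_train), trainer=trainer)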
@@ -3,15 +3,18 @@
 # File: keras.py

 import tensorflow as tf
-from six.moves import zip
+import six
 from tensorflow import keras
+from tensorflow.python.keras import metrics as metrics_module

+from ..models.regularize import regularize_cost_from_collection
 from ..graph_builder import InputDesc
 from ..tfutils.tower import get_current_tower_context
-from ..tfutils.collection import freeze_collection
+# from ..tfutils.collection import freeze_collection  # TODO freeze UPDATE_OPS in replicated
 from ..callbacks import (
     Callback, InferenceRunner, CallbackToHook,
-    ScalarStats, ModelSaver)
+    ScalarStats)
 from ..tfutils.summary import add_moving_summary
 from ..utils.gpu import get_nr_gpu
 from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
@@ -20,6 +23,30 @@ from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
 __all__ = ['KerasPhaseCallback', 'setup_keras_trainer', 'KerasModel']


+class KerasModelCaller(object):
+    """
+    Keras models don't support variable scope reuse.
+    This is a hack to mimic reuse.
+    """
+    def __init__(self, get_model):
+        self.get_model = get_model
+        self.cached_model = None
+
+    def __call__(self, input_tensors):
+        reuse = tf.get_variable_scope().reuse
+        if self.cached_model is None:
+            assert not reuse
+            self.cached_model = self.get_model(input_tensors)
+            return self.cached_model.outputs
+
+        if reuse:
+            return self.cached_model.call(input_tensors)
+        else:
+            M = self.get_model(input_tensors)
+            return M.outputs
+
+
 # Keras needs an extra input if learning_phase is used by the model
 # This cb will be used by
 # 1. trainer with isTrain=True
@@ -44,81 +71,96 @@ class KerasPhaseCallback(Callback):

 def setup_keras_trainer(
-        trainer, model, input,
+        trainer, get_model, input,
         optimizer, loss, metrics=None):
     """
     Args:
         trainer (SingleCostTrainer):
-        model (keras.model.Model):
+        get_model ( -> keras.model.Model):
         input (InputSource):
         optimizer (tf.train.Optimizer):
-        loss, metrics: same as in `keras.model.Model.compile()`.
+        loss, metrics: list of strings
     """
     assert isinstance(optimizer, tf.train.Optimizer), optimizer
-    inputs_desc = [InputDesc.from_tensor(t) for t in model.inputs]
-    outputs_desc = [InputDesc.from_tensor(t) for t in model.outputs]
-    nr_inputs = len(inputs_desc)

-    # clear the collection
-    del tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)[:]
+    G_tmp = tf.Graph()  # we need the model instance to know metadata about inputs/outputs
+    with G_tmp.as_default():
+        M_tmp = get_model([None])  # TODO use a proxy with Nones
+        inputs_desc = [InputDesc(t.dtype, t.shape.as_list(), 'input{}'.format(i))
+                       for i, t in enumerate(M_tmp.inputs)]
+        outputs_desc = [InputDesc(t.dtype, t.shape.as_list(), 'output{}'.format(i))
+                        for i, t in enumerate(M_tmp.outputs)]
+        nr_inputs = len(inputs_desc)
+    del G_tmp, M_tmp
+
+    model_caller = KerasModelCaller(get_model)

     def get_cost(*inputs):
+        assert len(inputs) == len(inputs_desc) + len(outputs_desc), \
+            "Input source size {} != {} + {}".format(len(inputs), len(inputs_desc), len(outputs_desc))
         ctx = get_current_tower_context()
+        assert ctx.is_main_training_tower or not ctx.has_own_variables
         input_tensors = list(inputs[:nr_inputs])
         target_tensors = list(inputs[nr_inputs:])
-        # Keras check and do weird things if target is a placeholder..
-        # Use tf.identity so it's not a placeholder.
-        target_tensors = [tf.identity(t) for t in target_tensors]
+        # TODO mapping between target tensors & output tensors

-        input_keras_tensors = [keras.layers.Input(tensor=t) for t in input_tensors]
-        outputs = model(input_keras_tensors)
+        outputs = model_caller(input_tensors)

-        M = keras.models.Model(input_tensors, outputs)
+        if isinstance(outputs, tf.Tensor):
+            outputs = [outputs]
+        assert len(outputs) == len(target_tensors), \
+            "len({}) != len({})".format(str(outputs), str(target_tensors))
+        assert len(outputs) == len(loss), \
+            "len({}) != len({})".format(str(outputs), str(loss))

-        with freeze_collection([tf.GraphKeys.TRAINABLE_VARIABLES]):
-            # Keras optimizer mistakenly creates TRAINABLE_VARIABLES ...
-            M.compile(
-                optimizer=optimizer, loss=loss,
-                target_tensors=target_tensors,
-                metrics=metrics)
+        # TODO more losses
+        with tf.name_scope('keras_loss'):
+            loss_fn = keras.losses.get(loss[0])
+            loss_opt = loss_fn(target_tensors[0], outputs[0])
+        loss_opt = tf.reduce_mean(loss_opt, name=loss[0])

-        # BN updates
-        if ctx.is_training:
-            for u in M.updates:
-                tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, u)
+        loss_reg = regularize_cost_from_collection()
+        if loss_reg is not None:
+            total_loss = tf.add(loss_opt, loss_reg, name='total_loss')
+            add_moving_summary(loss_opt, loss_reg, total_loss)
+        else:
+            add_moving_summary(loss_opt)
+            total_loss = tf.identity(loss_opt, name='total_loss')

-        add_moving_summary(tf.identity(M.total_loss, name='total_loss'))
-
-        assert len(M.metrics) == len(M.metrics_tensors)
-        for name, tensor in zip(M.metrics, M.metrics_tensors):
-            add_moving_summary(tf.identity(tensor, name=name))
-        # tensorpack requires TRAINABLE_VARIABLES created inside tower
-        if ctx.is_main_training_tower:
-            for p in M.weights:
-                tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, p)
-        return M.total_loss
+        if metrics and (ctx.is_main_training_tower or not ctx.is_training):
+            # for list: one metric for each output
+            metric_tensors = []
+            for oid, metric_name in enumerate(metrics):
+                output_tensor = outputs[oid]
+                target_tensor = target_tensors[oid]  # TODO may not have the same mapping?
+                with tf.name_scope('keras_metric'):  # TODO ns reuse
+                    metric_fn = metrics_module.get(metric_name)
+                    metric_tensor = metric_fn(target_tensor, output_tensor)
+                    metric_tensor = tf.reduce_mean(metric_tensor, name=metric_name)
+                # check name conflict here
+                metric_tensors.append(metric_tensor)
+            add_moving_summary(*metric_tensors)
+
+        return total_loss

     trainer.setup_graph(
         inputs_desc + outputs_desc,
         input,
         get_cost,
         lambda: optimizer)
-    if model.uses_learning_phase:
+    if model_caller.cached_model.uses_learning_phase:
         trainer.register_callback(KerasPhaseCallback(True))


 class KerasModel(object):
-    def __init__(self, model, input, trainer=None):
+    def __init__(self, get_model, input, trainer=None):
         """
         Args:
-            model (keras.model.Model):
+            get_model ( -> keras.model.Model):
             input (InputSource):
             trainer (Trainer): the default will check the number of available
                 GPUs and use them all.
         """
-        self.model = model
+        self.get_model = get_model
         if trainer is None:
             nr_gpu = get_nr_gpu()
             if nr_gpu <= 1:
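
The cost function above bypasses Model.compile() entirely: it resolves the loss and metric names to callables and reduces them to scalars by hand. The same resolution in isolation (y_true and y_pred are hypothetical label/prediction tensors):

    loss_fn = keras.losses.get('categorical_crossentropy')  # name -> callable
    loss_opt = tf.reduce_mean(loss_fn(y_true, y_pred))       # scalar for the optimizer

    metric_fn = metrics_module.get('categorical_accuracy')
    metric = tf.reduce_mean(metric_fn(y_true, y_pred))       # scalar for summaries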
@@ -130,15 +172,22 @@ class KerasModel(object):
         self.input = input
         self.trainer = trainer

-    def compile(self, optimizer, loss, metrics):
+    def compile(self, optimizer, loss, metrics=None):
         """
         Args:
             optimizer (tf.train.Optimizer):
-            loss, metrics: same as in `keras.model.Model.compile()`.
+            loss, metrics: string or list of strings
         """
-        self._metrics = metrics
+        if isinstance(loss, six.string_types):
+            loss = [loss]
+        if metrics is None:
+            metrics = []
+        if isinstance(metrics, six.string_types):
+            metrics = [metrics]
+
+        self._stats_to_inference = loss + metrics
         setup_keras_trainer(
-            self.trainer, model=self.model,
+            self.trainer, get_model=self.get_model,
             input=self.input,
             optimizer=optimizer,
             loss=loss,
@@ -151,14 +200,8 @@ class KerasModel(object):
             kwargs: same as `self.trainer.train_with_defaults`.
         """
         callbacks = kwargs.pop('callbacks', [])
-        callbacks.extend(self.get_default_callbacks())
         if validation_data is not None:
             callbacks.append(
                 InferenceRunner(
-                    validation_data, ScalarStats(self._metrics + ['total_loss'])))
+                    validation_data, ScalarStats(self._stats_to_inference + ['total_loss'])))
         self.trainer.train_with_defaults(callbacks=callbacks, **kwargs)
-
-    def get_default_callbacks(self):
-        return [
-            ModelSaver(keep_checkpoint_every_n_hours=0.2)
-        ]
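
One behavioral consequence of the last hunk: get_default_callbacks() and the ModelSaver it injected are gone, so checkpointing is now opt-in. This is why the example diff above passes the callback explicitly:

    M.fit(
        validation_data=dataset_test,
        steps_per_epoch=dataset_train.size(),
        callbacks=[ModelSaver()],  # previously added automatically by KerasModel
    )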
@@ -65,11 +65,6 @@ class InputDesc(
             return self._cached_placeholder
         return self.build_placeholder()

-    @staticmethod
-    def from_tensor(t):
-        return InputDesc(
-            t.dtype, t.shape.as_list(), t.name[:-2])
-

 @six.add_metaclass(ABCMeta)
 class ModelDescBase(object):
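
from_tensor is no longer needed: it derived the descriptor name from t.name[:-2] (stripping the ':0' suffix), which ties InputDesc to whatever name the graph happened to assign the tensor. The Keras glue in keras.py above now constructs descriptors with explicit, stable names instead:

    inputs_desc = [InputDesc(t.dtype, t.shape.as_list(), 'input{}'.format(i))
                   for i, t in enumerate(M_tmp.inputs)]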