Commit e8674dca authored by Yuxin Wu

[Keras] Use get_model function instead of letting users create the model directly (#160)
parent 365c56d2
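In short, the user no longer hands KerasModel a prebuilt model; they hand it a function that receives the input tensors and returns a keras.models.Model, so the library decides when and where to instantiate it (once per training tower, reusing variables as needed). A minimal sketch of the new calling convention, mirroring the example below (dataset_train stands for the training DataFlow; exact arguments are illustrative, not a spec):

from tensorflow import keras
from tensorpack.input_source import QueueInput
from tensorpack.contrib.keras import KerasModel
KL = keras.layers

def model_func(input_tensors):
    # the InputLayer is wired to the tensors tensorpack provides
    M = keras.models.Sequential()
    M.add(KL.InputLayer(input_shape=[28, 28, 1], input_tensor=input_tensors[0]))
    # ... conv / pooling / dense layers as in the example ...
    M.add(KL.Flatten())
    M.add(KL.Dense(10, activation='softmax'))
    return M

M = KerasModel(model_func, QueueInput(dataset_train))  # was: KerasModel(M, QueueInput(dataset_train))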
@@ -14,6 +14,7 @@ from tensorpack.input_source import QueueInput
from tensorpack.dataflow import dataset, BatchData, MapData
from tensorpack.utils import logger
from tensorpack.contrib.keras import KerasModel
from tensorpack.callbacks import ModelSaver
IMAGE_SIZE = 28
@@ -32,28 +33,41 @@ def get_data():
if __name__ == '__main__':
logger.auto_set_dir()
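# the model is now built inside a function that receives the input tensors,
# so the library controls when and where it is instantiated (see the KerasModel call below)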
def model_func(input_tensors):
M = keras.models.Sequential()
M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
M.add(KL.InputLayer(
input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1],
input_tensor=input_tensors[0]))
M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
M.add(KL.MaxPooling2D())
M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
M.add(KL.MaxPooling2D())
M.add(KL.Conv2D(32, 3, padding='same', activation='relu'))
M.add(KL.Flatten())
M.add(KL.Dense(512, activation='relu', kernel_regularizer=keras.regularizers.l2(1e-5)))
M.add(KL.Dropout(0.5))
M.add(KL.Dense(10, activation=None, kernel_regularizer=keras.regularizers.l2(1e-5)))
M.add(KL.Activation('softmax'))
return M
dataset_train, dataset_test = get_data()
M = KerasModel(M, QueueInput(dataset_train))
# from tensorpack import *
# trainer = SyncMultiGPUTrainerReplicated(2)
M = KerasModel(model_func, QueueInput(dataset_train))
M.compile(
optimizer=tf.train.AdamOptimizer(1e-3),
loss='categorical_crossentropy',
metrics=['accuracy']
metrics=['categorical_accuracy']
)
M.fit(
validation_data=dataset_test,
steps_per_epoch=dataset_train.size(),
callbacks=[
ModelSaver()
]
)
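The commented-out lines above hint at multi-GPU training. A hedged sketch of supplying an explicit trainer through KerasModel's optional trainer argument (the replicated trainer is named only in that comment, and keras.py still carries a TODO about UPDATE_OPS under replicated training, so treat this as untested):

from tensorpack import SyncMultiGPUTrainerReplicated
M = KerasModel(model_func, QueueInput(dataset_train),
               trainer=SyncMultiGPUTrainerReplicated(2))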
@@ -3,15 +3,18 @@
# File: keras.py
import tensorflow as tf
from six.moves import zip
import six
from tensorflow import keras
from tensorflow.python.keras import metrics as metrics_module
from ..models.regularize import regularize_cost_from_collection
from ..graph_builder import InputDesc
from ..tfutils.tower import get_current_tower_context
from ..tfutils.collection import freeze_collection
# from ..tfutils.collection import freeze_collection # TODO freeze UPDATE_OPS in replicated
from ..callbacks import (
Callback, InferenceRunner, CallbackToHook,
ScalarStats, ModelSaver)
ScalarStats)
from ..tfutils.summary import add_moving_summary
from ..utils.gpu import get_nr_gpu
from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
@@ -20,6 +23,30 @@ from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
__all__ = ['KerasPhaseCallback', 'setup_keras_trainer', 'KerasModel']
class KerasModelCaller(object):
"""
Keras models don't support variable scope reuse.
This is a hack to mimic reuse.
"""
def __init__(self, get_model):
self.get_model = get_model
self.cached_model = None
def __call__(self, input_tensors):
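# First call: build the model via get_model and cache it. Later calls under a
# reusing variable scope re-apply the cached model (sharing its weights);
# calls without reuse build a fresh model.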
reuse = tf.get_variable_scope().reuse
if self.cached_model is None:
assert not reuse
self.cached_model = self.get_model(input_tensors)
return self.cached_model.outputs
if reuse:
return self.cached_model.call(input_tensors)
else:
M = self.get_model(input_tensors)
return M.outputs
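# A sketch of how the caller is meant to be used across towers
# (the names `get_model`, `tensors0`, `tensors1` are illustrative only):
#
#   caller = KerasModelCaller(get_model)
#   with tf.variable_scope('tower'):
#       out0 = caller(tensors0)   # first call: builds and caches the model
#   with tf.variable_scope('tower', reuse=True):
#       out1 = caller(tensors1)   # reuse: re-applies the cached model, sharing weights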
# Keras needs an extra input if learning_phase is used by the model
# This callback will be used by
# 1. trainer with isTrain=True
@@ -44,81 +71,96 @@ class KerasPhaseCallback(Callback):
def setup_keras_trainer(
trainer, model, input,
trainer, get_model, input,
optimizer, loss, metrics=None):
"""
Args:
trainer (SingleCostTrainer):
model (keras.models.Model):
get_model ( -> keras.models.Model):
input (InputSource):
optimizer (tf.train.Optimizer):
loss, metrics: same as in `keras.models.Model.compile()`.
loss, metrics: list of strings
"""
assert isinstance(optimizer, tf.train.Optimizer), optimizer
inputs_desc = [InputDesc.from_tensor(t) for t in model.inputs]
outputs_desc = [InputDesc.from_tensor(t) for t in model.outputs]
G_tmp = tf.Graph() # we need the model instance to know metadata about inputs/outputs
with G_tmp.as_default():
M_tmp = get_model([None]) # TODO use a proxy with Nones
inputs_desc = [InputDesc(t.dtype, t.shape.as_list(), 'input{}'.format(i))
for i, t in enumerate(M_tmp.inputs)]
outputs_desc = [InputDesc(t.dtype, t.shape.as_list(), 'output{}'.format(i))
for i, t in enumerate(M_tmp.outputs)]
nr_inputs = len(inputs_desc)
del G_tmp, M_tmp
# clear the collection
del tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)[:]
model_caller = KerasModelCaller(get_model)
def get_cost(*inputs):
assert len(inputs) == len(inputs_desc) + len(outputs_desc), \
"Input source size {} != {} + {}".format(len(inputs), len(inputs_desc), len(outputs_desc))
ctx = get_current_tower_context()
assert ctx.is_main_training_tower or not ctx.has_own_variables
input_tensors = list(inputs[:nr_inputs])
target_tensors = list(inputs[nr_inputs:])
# Keras checks and does weird things if the target is a placeholder.
# Use tf.identity so it's not a placeholder.
target_tensors = [tf.identity(t) for t in target_tensors]
input_keras_tensors = [keras.layers.Input(tensor=t) for t in input_tensors]
outputs = model(input_keras_tensors)
M = keras.models.Model(input_tensors, outputs)
with freeze_collection([tf.GraphKeys.TRAINABLE_VARIABLES]):
# Keras optimizer mistakenly creates TRAINABLE_VARIABLES ...
M.compile(
optimizer=optimizer, loss=loss,
target_tensors=target_tensors,
metrics=metrics)
# BN updates
if ctx.is_training:
for u in M.updates:
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, u)
add_moving_summary(tf.identity(M.total_loss, name='total_loss'))
assert len(M.metrics) == len(M.metrics_tensors)
for name, tensor in zip(M.metrics, M.metrics_tensors):
add_moving_summary(tf.identity(tensor, name=name))
# tensorpack requires TRAINABLE_VARIABLES created inside tower
if ctx.is_main_training_tower:
for p in M.weights:
tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, p)
return M.total_loss
# TODO mapping between target tensors & output tensors
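# build (or reuse) the user's Keras model on this tower's input tensors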
outputs = model_caller(input_tensors)
if isinstance(outputs, tf.Tensor):
outputs = [outputs]
assert len(outputs) == len(target_tensors), \
"len({}) != len({})".format(str(outputs), str(target_tensors))
assert len(outputs) == len(loss), \
"len({}) != len({})".format(str(outputs), str(loss))
# TODO more losses
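# for now only a single loss is supported: apply it to the first
# output/target pair and average over the batch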
with tf.name_scope('keras_loss'):
loss_fn = keras.losses.get(loss[0])
loss_opt = loss_fn(target_tensors[0], outputs[0])
loss_opt = tf.reduce_mean(loss_opt, name=loss[0])
loss_reg = regularize_cost_from_collection()
if loss_reg is not None:
total_loss = tf.add(loss_opt, loss_reg, name='total_loss')
add_moving_summary(loss_opt, loss_reg, total_loss)
else:
add_moving_summary(loss_opt)
total_loss = tf.identity(loss_opt, name='total_loss')
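# compute metrics only on the main training tower, or when not training (i.e. during inference)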
if metrics and (ctx.is_main_training_tower or not ctx.is_training):
# when metrics is a list, entry i is the metric for output i
metric_tensors = []
for oid, metric_name in enumerate(metrics):
output_tensor = outputs[oid]
target_tensor = target_tensors[oid] # TODO may not have the same mapping?
with tf.name_scope('keras_metric'): # TODO ns reuse
metric_fn = metrics_module.get(metric_name)
metric_tensor = metric_fn(target_tensor, output_tensor)
metric_tensor = tf.reduce_mean(metric_tensor, name=metric_name)
# check name conflict here
metric_tensors.append(metric_tensor)
add_moving_summary(*metric_tensors)
return total_loss
trainer.setup_graph(
inputs_desc + outputs_desc,
input,
get_cost,
lambda: optimizer)
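# setup_graph above has called get_cost at least once, so the cached model
# now exists and can be inspected for learning_phase usage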
if model.uses_learning_phase:
if model_caller.cached_model.uses_learning_phase:
trainer.register_callback(KerasPhaseCallback(True))
class KerasModel(object):
def __init__(self, model, input, trainer=None):
def __init__(self, get_model, input, trainer=None):
"""
Args:
model (keras.models.Model):
get_model ( -> keras.models.Model):
input (InputSource):
trainer (Trainer): the default will check the number of available
GPUs and use them all.
"""
self.model = model
self.get_model = get_model
if trainer is None:
nr_gpu = get_nr_gpu()
if nr_gpu <= 1:
@@ -130,15 +172,22 @@ class KerasModel(object):
self.input = input
self.trainer = trainer
def compile(self, optimizer, loss, metrics):
def compile(self, optimizer, loss, metrics=None):
"""
Args:
optimizer (tf.train.Optimizer):
loss, metrics: same as in `keras.models.Model.compile()`.
loss, metrics: string or list of strings
"""
self._metrics = metrics
if isinstance(loss, six.string_types):
loss = [loss]
if metrics is None:
metrics = []
if isinstance(metrics, six.string_types):
metrics = [metrics]
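# these names, plus 'total_loss', are the scalars monitored on validation data by fit()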
self._stats_to_inference = loss + metrics
setup_keras_trainer(
self.trainer, model=self.model,
self.trainer, get_model=self.get_model,
input=self.input,
optimizer=optimizer,
loss=loss,
@@ -151,14 +200,8 @@ class KerasModel(object):
kwargs: same as `self.trainer.train_with_defaults`.
"""
callbacks = kwargs.pop('callbacks', [])
callbacks.extend(self.get_default_callbacks())
if validation_data is not None:
callbacks.append(
InferenceRunner(
validation_data, ScalarStats(self._metrics + ['total_loss'])))
validation_data, ScalarStats(self._stats_to_inference + ['total_loss'])))
self.trainer.train_with_defaults(callbacks=callbacks, **kwargs)
def get_default_callbacks(self):
return [
ModelSaver(keep_checkpoint_every_n_hours=0.2)
]
@@ -65,11 +65,6 @@ class InputDesc(
return self._cached_placeholder
return self.build_placeholder()
@staticmethod
def from_tensor(t):
return InputDesc(
t.dtype, t.shape.as_list(), t.name[:-2])
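# NOTE: from_tensor is removed; contrib/keras.py now constructs
# InputDesc(dtype, shape, name) directly instead.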
@six.add_metaclass(ABCMeta)
class ModelDescBase(object):