Commit ac02c62f authored by Yuxin Wu

[Keras] use inputs_desc/targets_desc explicitly, to avoid hacks (#160)

parent f1ee1833
......@@ -10,7 +10,7 @@ from tensorflow import keras
KL = keras.layers
from tensorpack.input_source import QueueInput
from tensorpack import InputDesc, QueueInput
from tensorpack.dataflow import dataset, BatchData, MapData
from tensorpack.utils import logger
from tensorpack.contrib.keras import KerasModel
......@@ -22,8 +22,7 @@ IMAGE_SIZE = 28
def get_data():
def f(dp):
im = dp[0][:, :, None]
onehot = np.zeros(10, dtype='int32')
onehot[dp[1]] = 1
onehot = np.eye(10)[dp[1]]
return [im, onehot]
train = BatchData(MapData(dataset.Mnist('train'), f), 128)
......@@ -34,11 +33,14 @@ def get_data():
if __name__ == '__main__':
logger.auto_set_dir()
def model_func(input_tensors):
def model_func(inputs):
"""
Keras model has to be created inside this function to be used with tensorpack.
"""
M = keras.models.Sequential()
M.add(KL.InputLayer(
input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1],
input_tensor=input_tensors[0]))
# input_tensor has to be used here for the tensorpack trainer to function properly.
# Just use inputs[1], inputs[2] if you have multiple inputs.
M.add(KL.InputLayer(input_tensor=inputs[0]))
M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
M.add(KL.MaxPooling2D())
M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
......@@ -51,18 +53,19 @@ if __name__ == '__main__':
M.add(KL.Dropout(0.5))
M.add(KL.Dense(10, activation=None, kernel_regularizer=keras.regularizers.l2(1e-5)))
M.add(KL.Activation('softmax'))
return M
dataset_train, dataset_test = get_data()
# from tensorpack import *
# trainer = SyncMultiGPUTrainerReplicated(2)
M = KerasModel(model_func, QueueInput(dataset_train))
M = KerasModel(
model_func,
inputs_desc=[InputDesc(tf.float32, [None, IMAGE_SIZE, IMAGE_SIZE, 1], 'images')],
targets_desc=[InputDesc(tf.float32, [None, 10], 'labels')],
input=QueueInput(dataset_train))
M.compile(
optimizer=tf.train.AdamOptimizer(1e-3),
loss='categorical_crossentropy',
metrics=['categorical_accuracy']
metrics='categorical_accuracy'
)
M.fit(
validation_data=dataset_test,
......
......@@ -8,8 +8,8 @@ from tensorflow import keras
from tensorflow.python.keras import metrics as metrics_module
from ..models.regularize import regularize_cost_from_collection
from ..graph_builder import InputDesc
from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer, DistributedTrainerBase
from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
from ..train.trainers import DistributedTrainerBase
from ..callbacks import (
Callback, InferenceRunner, CallbackToHook,
ScalarStats)
......@@ -45,6 +45,10 @@ class KerasModelCaller(object):
self.cached_model = None
def __call__(self, input_tensors):
"""
Returns:
output tensors of this tower, evaluated with the input tensors.
"""
reuse = tf.get_variable_scope().reuse
if self.cached_model is None:
assert not reuse
......@@ -52,26 +56,13 @@ class KerasModelCaller(object):
return self.cached_model.outputs
if reuse:
# use the cached Keras model to mimic reuse
return self.cached_model.call(input_tensors)
else:
# create new Keras model if not reuse
M = self.get_model(input_tensors)
return M.outputs
def call_virtual(self):
class NoneTensorProxy(object):
def __getitem__(self, index):
return None
def __len__(self):
raise NotImplementedError(
"Do not call `len(inputs)` because it's only a virtual object "
"for the moment! Use `inputs[index]` directly!")
G_tmp = tf.Graph() # we need a model instance to know metadata about inputs/outputs
with G_tmp.as_default():
return self.get_model(NoneTensorProxy())
# Keras needs an extra input if learning_phase is used by the model
# This cb will be used by
......@@ -97,8 +88,9 @@ class KerasPhaseCallback(Callback):
def setup_keras_trainer(
trainer, get_model, input,
optimizer, loss, metrics):
trainer, get_model,
inputs_desc, targets_desc,
input, optimizer, loss, metrics):
"""
Args:
trainer (SingleCostTrainer):
......@@ -113,17 +105,11 @@ def setup_keras_trainer(
assert isinstance(metrics, list), metrics
model_caller = KerasModelCaller(get_model)
M_tmp = model_caller.call_virtual()
inputs_desc = [InputDesc(t.dtype, t.shape.as_list(), 'input{}'.format(i))
for i, t in enumerate(M_tmp.inputs)]
outputs_desc = [InputDesc(t.dtype, t.shape.as_list(), 'output{}'.format(i))
for i, t in enumerate(M_tmp.outputs)]
nr_inputs = len(inputs_desc)
def get_cost(*inputs):
assert len(inputs) == len(inputs_desc) + len(outputs_desc), \
"Input source size {} != {} + {}".format(len(inputs), len(inputs_desc), len(outputs_desc))
assert len(inputs) == len(inputs_desc) + len(targets_desc), \
"Input source size {} != {} + {}".format(len(inputs), len(inputs_desc), len(targets_desc))
ctx = get_current_tower_context()
input_tensors = list(inputs[:nr_inputs])
target_tensors = list(inputs[nr_inputs:])
......@@ -173,7 +159,7 @@ def setup_keras_trainer(
return total_loss
trainer.setup_graph(
inputs_desc + outputs_desc,
inputs_desc + targets_desc,
input,
get_cost,
lambda: optimizer)
......@@ -182,20 +168,26 @@ def setup_keras_trainer(
class KerasModel(object):
def __init__(self, get_model, input, trainer=None):
def __init__(self, get_model, inputs_desc, targets_desc,
input, trainer=None):
"""
Args:
get_model ( -> keras.model.Model):
inputs_desc ([InputDesc]):
targets_desc ([InputDesc]):
input (InputSource):
trainer (Trainer): the default will check the number of available
GPUs and use them all.
"""
self.get_model = get_model
self.inputs_desc = inputs_desc
self.targets_desc = targets_desc
if trainer is None:
nr_gpu = get_nr_gpu()
if nr_gpu <= 1:
trainer = SimpleTrainer()
else:
# the default multigpu trainer
trainer = SyncMultiGPUTrainerParameterServer(nr_gpu)
assert isinstance(trainer, Trainer), trainer
assert not isinstance(trainer, DistributedTrainerBase)
......@@ -219,6 +211,7 @@ class KerasModel(object):
self._stats_to_inference = loss + metrics + [TOTAL_LOSS_NAME]
setup_keras_trainer(
self.trainer, get_model=self.get_model,
inputs_desc=self.inputs_desc, targets_desc=self.targets_desc,
input=self.input,
optimizer=optimizer,
loss=loss,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment