Commit ac02c62f authored by Yuxin Wu

[Keras] use inputs_desc/targets_desc explicitly, to avoid hacks (#160)

parent f1ee1833
...@@ -10,7 +10,7 @@ from tensorflow import keras ...@@ -10,7 +10,7 @@ from tensorflow import keras
KL = keras.layers KL = keras.layers
from tensorpack.input_source import QueueInput from tensorpack import InputDesc, QueueInput
from tensorpack.dataflow import dataset, BatchData, MapData from tensorpack.dataflow import dataset, BatchData, MapData
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.contrib.keras import KerasModel from tensorpack.contrib.keras import KerasModel
...@@ -22,8 +22,7 @@ IMAGE_SIZE = 28 ...@@ -22,8 +22,7 @@ IMAGE_SIZE = 28
def get_data(): def get_data():
def f(dp): def f(dp):
im = dp[0][:, :, None] im = dp[0][:, :, None]
onehot = np.zeros(10, dtype='int32') onehot = np.eye(10)[dp[1]]
onehot[dp[1]] = 1
return [im, onehot] return [im, onehot]
train = BatchData(MapData(dataset.Mnist('train'), f), 128) train = BatchData(MapData(dataset.Mnist('train'), f), 128)
...@@ -34,11 +33,14 @@ def get_data(): ...@@ -34,11 +33,14 @@ def get_data():
if __name__ == '__main__': if __name__ == '__main__':
logger.auto_set_dir() logger.auto_set_dir()
def model_func(input_tensors): def model_func(inputs):
"""
Keras model has to be created inside this function to be used with tensorpack.
"""
M = keras.models.Sequential() M = keras.models.Sequential()
M.add(KL.InputLayer( # input_tensor have to be used here for tensorpack trainer to function properly.
input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], # Just use inputs[1], inputs[2] if you have multiple inputs.
input_tensor=input_tensors[0])) M.add(KL.InputLayer(input_tensor=inputs[0]))
M.add(KL.Conv2D(32, 3, activation='relu', padding='same')) M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
M.add(KL.MaxPooling2D()) M.add(KL.MaxPooling2D())
M.add(KL.Conv2D(32, 3, activation='relu', padding='same')) M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
...@@ -51,18 +53,19 @@ if __name__ == '__main__': ...@@ -51,18 +53,19 @@ if __name__ == '__main__':
M.add(KL.Dropout(0.5)) M.add(KL.Dropout(0.5))
M.add(KL.Dense(10, activation=None, kernel_regularizer=keras.regularizers.l2(1e-5))) M.add(KL.Dense(10, activation=None, kernel_regularizer=keras.regularizers.l2(1e-5)))
M.add(KL.Activation('softmax')) M.add(KL.Activation('softmax'))
return M return M
dataset_train, dataset_test = get_data() dataset_train, dataset_test = get_data()
# from tensorpack import * M = KerasModel(
# trainer = SyncMultiGPUTrainerReplicated(2) model_func,
M = KerasModel(model_func, QueueInput(dataset_train)) inputs_desc=[InputDesc(tf.float32, [None, IMAGE_SIZE, IMAGE_SIZE, 1], 'images')],
targets_desc=[InputDesc(tf.float32, [None, 10], 'labels')],
input=QueueInput(dataset_train))
M.compile( M.compile(
optimizer=tf.train.AdamOptimizer(1e-3), optimizer=tf.train.AdamOptimizer(1e-3),
loss='categorical_crossentropy', loss='categorical_crossentropy',
metrics=['categorical_accuracy'] metrics='categorical_accuracy'
) )
M.fit( M.fit(
validation_data=dataset_test, validation_data=dataset_test,
......
...@@ -8,8 +8,8 @@ from tensorflow import keras ...@@ -8,8 +8,8 @@ from tensorflow import keras
from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras import metrics as metrics_module
from ..models.regularize import regularize_cost_from_collection from ..models.regularize import regularize_cost_from_collection
from ..graph_builder import InputDesc from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer, DistributedTrainerBase from ..train.trainers import DistributedTrainerBase
from ..callbacks import ( from ..callbacks import (
Callback, InferenceRunner, CallbackToHook, Callback, InferenceRunner, CallbackToHook,
ScalarStats) ScalarStats)
...@@ -45,6 +45,10 @@ class KerasModelCaller(object): ...@@ -45,6 +45,10 @@ class KerasModelCaller(object):
self.cached_model = None self.cached_model = None
def __call__(self, input_tensors): def __call__(self, input_tensors):
"""
Returns:
output tensors of this tower, evaluated with the input tensors.
"""
reuse = tf.get_variable_scope().reuse reuse = tf.get_variable_scope().reuse
if self.cached_model is None: if self.cached_model is None:
assert not reuse assert not reuse
...@@ -52,26 +56,13 @@ class KerasModelCaller(object): ...@@ -52,26 +56,13 @@ class KerasModelCaller(object):
return self.cached_model.outputs return self.cached_model.outputs
if reuse: if reuse:
# use the cached Keras model to mimic reuse
return self.cached_model.call(input_tensors) return self.cached_model.call(input_tensors)
else: else:
# create new Keras model if not reuse
M = self.get_model(input_tensors) M = self.get_model(input_tensors)
return M.outputs return M.outputs
def call_virtual(self):
class NoneTensorProxy(object):
def __getitem__(self, index):
return None
def __len__(self):
raise NotImplementedError(
"Do not call `len(inputs)` because it's only a virtual object "
"for the moment! Use `inputs[index]` directly!")
G_tmp = tf.Graph() # we need a model instance to know metadata about inputs/outputs
with G_tmp.as_default():
return self.get_model(NoneTensorProxy())
# Keras needs an extra input if learning_phase is used by the model # Keras needs an extra input if learning_phase is used by the model
# This cb will be used by # This cb will be used by
...@@ -97,8 +88,9 @@ class KerasPhaseCallback(Callback): ...@@ -97,8 +88,9 @@ class KerasPhaseCallback(Callback):
def setup_keras_trainer( def setup_keras_trainer(
trainer, get_model, input, trainer, get_model,
optimizer, loss, metrics): inputs_desc, targets_desc,
input, optimizer, loss, metrics):
""" """
Args: Args:
trainer (SingleCostTrainer): trainer (SingleCostTrainer):
...@@ -113,17 +105,11 @@ def setup_keras_trainer( ...@@ -113,17 +105,11 @@ def setup_keras_trainer(
assert isinstance(metrics, list), metrics assert isinstance(metrics, list), metrics
model_caller = KerasModelCaller(get_model) model_caller = KerasModelCaller(get_model)
M_tmp = model_caller.call_virtual()
inputs_desc = [InputDesc(t.dtype, t.shape.as_list(), 'input{}'.format(i))
for i, t in enumerate(M_tmp.inputs)]
outputs_desc = [InputDesc(t.dtype, t.shape.as_list(), 'output{}'.format(i))
for i, t in enumerate(M_tmp.outputs)]
nr_inputs = len(inputs_desc) nr_inputs = len(inputs_desc)
def get_cost(*inputs): def get_cost(*inputs):
assert len(inputs) == len(inputs_desc) + len(outputs_desc), \ assert len(inputs) == len(inputs_desc) + len(targets_desc), \
"Input source size {} != {} + {}".format(len(inputs), len(inputs_desc), len(outputs_desc)) "Input source size {} != {} + {}".format(len(inputs), len(inputs_desc), len(targets_desc))
ctx = get_current_tower_context() ctx = get_current_tower_context()
input_tensors = list(inputs[:nr_inputs]) input_tensors = list(inputs[:nr_inputs])
target_tensors = list(inputs[nr_inputs:]) target_tensors = list(inputs[nr_inputs:])
...@@ -173,7 +159,7 @@ def setup_keras_trainer( ...@@ -173,7 +159,7 @@ def setup_keras_trainer(
return total_loss return total_loss
trainer.setup_graph( trainer.setup_graph(
inputs_desc + outputs_desc, inputs_desc + targets_desc,
input, input,
get_cost, get_cost,
lambda: optimizer) lambda: optimizer)
...@@ -182,20 +168,26 @@ def setup_keras_trainer( ...@@ -182,20 +168,26 @@ def setup_keras_trainer(
class KerasModel(object): class KerasModel(object):
def __init__(self, get_model, input, trainer=None): def __init__(self, get_model, inputs_desc, targets_desc,
input, trainer=None):
""" """
Args: Args:
get_model ( -> keras.model.Model): get_model ( -> keras.model.Model):
inputs_desc ([InputDesc]):
targets_desc ([InputDesc]):
input (InputSource): input (InputSource):
trainer (Trainer): the default will check the number of available trainer (Trainer): the default will check the number of available
GPUs and use them all. GPUs and use them all.
""" """
self.get_model = get_model self.get_model = get_model
self.inputs_desc = inputs_desc
self.targets_desc = targets_desc
if trainer is None: if trainer is None:
nr_gpu = get_nr_gpu() nr_gpu = get_nr_gpu()
if nr_gpu <= 1: if nr_gpu <= 1:
trainer = SimpleTrainer() trainer = SimpleTrainer()
else: else:
# the default multigpu trainer
trainer = SyncMultiGPUTrainerParameterServer(nr_gpu) trainer = SyncMultiGPUTrainerParameterServer(nr_gpu)
assert isinstance(trainer, Trainer), trainer assert isinstance(trainer, Trainer), trainer
assert not isinstance(trainer, DistributedTrainerBase) assert not isinstance(trainer, DistributedTrainerBase)
...@@ -219,6 +211,7 @@ class KerasModel(object): ...@@ -219,6 +211,7 @@ class KerasModel(object):
self._stats_to_inference = loss + metrics + [TOTAL_LOSS_NAME] self._stats_to_inference = loss + metrics + [TOTAL_LOSS_NAME]
setup_keras_trainer( setup_keras_trainer(
self.trainer, get_model=self.get_model, self.trainer, get_model=self.get_model,
inputs_desc=self.inputs_desc, targets_desc=self.targets_desc,
input=self.input, input=self.input,
optimizer=optimizer, optimizer=optimizer,
loss=loss, loss=loss,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment