Commit 6926d22a authored by Yuxin Wu

MultiGPU version of "mnist-keras.py"

parent f7fecef9
@@ -15,9 +15,14 @@ are the only two tools I know that can scale the training of a large Keras model
### Simple Examples:
-[mnist-keras.py](mnist-keras.py): a simple MNIST model written mostly in tensorpack style, but use Keras model as symbolic functions.
-[mnist-keras-v2.py](mnist-keras-v2.py): the same MNIST model written in Keras style.
+There are two flavors where you can use a Keras model inside tensorpack:
+1. Write the tower function similar to a standard tensorpack program, but use some Keras layers in
+   between. See [mnist-keras.py](mnist-keras.py) on how to do this.
+   It does not support all tensorpack trainers.
+2. The entire model to train is a Keras model (and there will be no `ModelDesc`, etc).
+   See [mnist-keras-v2.py](mnist-keras-v2.py).
### ImageNet Example:
@@ -36,7 +41,7 @@ It has:
Keras does not respect variable scopes or variable
collections, which contradicts with tensorpack trainers.
Therefore Keras support is __experimental__.
These simple examples can run within tensorpack smoothly, but note that a future
version of Keras or a complicated model may break them (unlikely, though).
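As a rough sketch of flavor 1 described above (Keras layers used inside an otherwise normal tensorpack tower function), the `ModelDesc` below illustrates the usual tensorpack idiom; it is not code from this commit, exact signatures vary across tensorpack versions, and `get_keras_model()` refers to the memoized Keras model defined in mnist-keras.py (see the diff below):

```python
import tensorflow as tf
from tensorpack import ModelDesc

class Model(ModelDesc):
    def inputs(self):
        # TF1-style placeholders for the image and the label
        return [tf.placeholder(tf.float32, (None, 28, 28, 1), 'input'),
                tf.placeholder(tf.int32, (None,), 'label')]

    def build_graph(self, image, label):
        M = get_keras_model()   # the memoized Keras Sequential from mnist-keras.py
        logits = M(image)       # Keras layers called inside the tower function
        cost = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label),
            name='cross_entropy_loss')
        return cost             # tensorpack uses the returned tensor as the total cost

    def optimizer(self):
        return tf.train.AdamOptimizer(1e-3)
```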
@@ -10,6 +10,7 @@ from tensorpack import *
from tensorpack.contrib.keras import KerasPhaseCallback
from tensorpack.dataflow import dataset
from tensorpack.utils.argtools import memoized
+from tensorpack.utils.gpu import get_num_gpu
KL = keras.layers
@@ -23,19 +24,20 @@ Note: this example does not work for replicated-style data-parallel trainers.
IMAGE_SIZE = 28
@memoized  # this is necessary for sonnet/Keras to work under tensorpack
def get_keras_model():
-    M = keras.models.Sequential()
-    M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
-    M.add(KL.MaxPooling2D())
-    M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
-    M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
-    M.add(KL.MaxPooling2D())
-    M.add(KL.Conv2D(32, 3, padding='same', activation='relu'))
-    M.add(KL.Flatten())
-    M.add(KL.Dense(512, activation='relu', kernel_regularizer=keras.regularizers.l2(1e-5)))
-    M.add(KL.Dropout(0.5))
-    M.add(KL.Dense(10, activation=None, kernel_regularizer=keras.regularizers.l2(1e-5)))
+    with tf.name_scope('/'):
+        M = keras.models.Sequential()
+        M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
+        M.add(KL.MaxPooling2D())
+        M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
+        M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
+        M.add(KL.MaxPooling2D())
+        M.add(KL.Conv2D(32, 3, padding='same', activation='relu'))
+        M.add(KL.Flatten())
+        M.add(KL.Dense(512, activation='relu', kernel_regularizer=keras.regularizers.l2(1e-5)))
+        M.add(KL.Dropout(0.5))
+        M.add(KL.Dense(10, activation=None, kernel_regularizer=keras.regularizers.l2(1e-5)))
    return M
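The `@memoized` decorator is what lets every GPU tower share the same model: each tower calls `get_keras_model()`, but only the first call builds the `Sequential`, and later calls return the cached object, so the towers share one set of weights (the new `tf.name_scope('/')` wrapper presumably keeps the layer names from being nested under a tower's name scope). A tiny illustrative check, not part of the commit:

```python
# Illustrative only: relies on get_keras_model() from mnist-keras.py above.
M0 = get_keras_model()   # first call actually builds the Keras Sequential
M1 = get_keras_model()   # later calls return the cached object; no new layers/variables
assert M0 is M1          # hence every GPU tower reuses the same weights
```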
@@ -96,4 +98,11 @@ if __name__ == '__main__':
        max_epoch=100,
    )
-    launch_train_with_config(cfg, QueueInputTrainer())
+    if get_num_gpu() <= 1:
+        # single GPU:
+        launch_train_with_config(cfg, QueueInputTrainer())
+    else:
+        # multi GPU:
+        launch_train_with_config(cfg, SyncMultiGPUTrainerParameterServer(2))
+        # "Replicated" multi-gpu trainer is not supported for Keras model
+        # since Keras does not respect variable scopes.
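The commit hard-codes 2 parameter-server towers. A hypothetical variant using the same APIs as in the diff, only deriving the tower count from the number of visible GPUs:

```python
# Hypothetical variant, not part of the commit: pick the trainer from the GPU count.
num_gpu = get_num_gpu()
if num_gpu <= 1:
    trainer = QueueInputTrainer()
else:
    trainer = SyncMultiGPUTrainerParameterServer(num_gpu)
launch_train_with_config(cfg, trainer)
```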
@@ -40,7 +40,7 @@ class KerasModelCaller(object):
        self.get_model = get_model
        self.cached_model = None
-    def __call__(self, input_tensors):
+    def __call__(self, *input_tensors):
        """
        Args:
            input_tensors ([tf.Tensor])
@@ -106,6 +106,8 @@ class KerasModelCaller(object):
        post_process_model(model)
+        if isinstance(outputs, list) and len(outputs) == 1:
+            return outputs[0]
        return outputs
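With these two changes, `KerasModelCaller` is called with positional tensors and unwraps a single-output model's result into a bare tensor. A hypothetical call site (variable names are illustrative; `model_caller` wraps a one-output Keras model, as in `setup_keras_trainer` below):

```python
# Hypothetical call site showing the new convention:
logits = model_caller(image)            # positional tensors instead of a list argument
assert isinstance(logits, tf.Tensor)    # a single output is no longer wrapped in a list
```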
@@ -167,7 +169,7 @@ def setup_keras_trainer(
    target_tensors = list(inputs[nr_inputs:])
    # TODO mapping between target tensors & output tensors
-    outputs = model_caller(input_tensors)
+    outputs = model_caller(*input_tensors)
    if isinstance(outputs, tf.Tensor):
        outputs = [outputs]
...