Commit 6926d22a authored by Yuxin Wu

MultiGPU version of "mnist-keras.py"

parent f7fecef9
@@ -15,9 +15,14 @@ are the only two tools I know that can scale the training of a large Keras model
### Simple Examples:
-[mnist-keras.py](mnist-keras.py): a simple MNIST model written mostly in tensorpack style, but use Keras model as symbolic functions.
+There are two flavors where you can use a Keras model inside tensorpack:
-[mnist-keras-v2.py](mnist-keras-v2.py): the same MNIST model written in Keras style.
+1. Write the tower function similar to a standard tensorpack program, but use some Keras layers in
+between. See [mnist-keras.py](mnist-keras.py) on how to do this.
+It does not support all tensorpack trainers.
+2. The entire model to train is a Keras model (and there will be no `ModelDesc`, etc).
+See [mnist-keras-v2.py](mnist-keras-v2.py).
### ImageNet Example:
......
@@ -10,6 +10,7 @@ from tensorpack import *
from tensorpack.contrib.keras import KerasPhaseCallback
from tensorpack.dataflow import dataset
from tensorpack.utils.argtools import memoized
from tensorpack.utils.gpu import get_num_gpu
KL = keras.layers
@@ -23,8 +24,9 @@ Note: this example does not work for replicated-style data-parallel trainers.
IMAGE_SIZE = 28
-@memoized # this is necessary for sonnet/Keras to work under tensorpack
+@memoized # this is necessary for sonnet/keras to work under tensorpack
def get_keras_model():
with tf.name_scope('/'):
M = keras.models.Sequential()
M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
M.add(KL.MaxPooling2D())
@@ -96,4 +98,11 @@ if __name__ == '__main__':
max_epoch=100,
)
+    if get_num_gpu() <= 1:
+        # single GPU:
+        launch_train_with_config(cfg, QueueInputTrainer())
+    else:
+        # multi GPU:
+        launch_train_with_config(cfg, SyncMultiGPUTrainerParameterServer(2))
+        # "Replicated" multi-gpu trainer is not supported for Keras model
+        # since Keras does not respect variable scopes.
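
The trainer choice in this hunk can be sketched in isolation. The helper below is hypothetical (`choose_trainer` is not a tensorpack function) and returns the trainer class name as a string so that it runs without TensorFlow installed. It also assumes, unlike the hard-coded `2` in the diff, that one might want to pass the detected GPU count through; that is an assumption about intent, not something the commit does.

```python
def choose_trainer(num_gpu):
    """Hypothetical helper: map a GPU count to (trainer class name, tower count).

    Mirrors the branch in the diff: one GPU or fewer uses the queue-based
    single-tower trainer; more than one uses the parameter-server style
    synchronous trainer, because the replicated trainer is stated to be
    unsupported for Keras models.
    """
    if num_gpu <= 1:
        return ("QueueInputTrainer", 1)
    # Assumption: use every detected GPU instead of the fixed 2 in the diff.
    return ("SyncMultiGPUTrainerParameterServer", num_gpu)


single_choice = choose_trainer(1)
multi_choice = choose_trainer(4)
```

In a real script the returned name would be replaced by constructing the corresponding tensorpack trainer and passing it to `launch_train_with_config`, as the diff does.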
@@ -40,7 +40,7 @@ class KerasModelCaller(object):
self.get_model = get_model
self.cached_model = None
-def __call__(self, input_tensors):
+def __call__(self, *input_tensors):
"""
Args:
input_tensors ([tf.Tensor])
@@ -106,6 +106,8 @@ class KerasModelCaller(object):
post_process_model(model)
if isinstance(outputs, list) and len(outputs) == 1:
return outputs[0]
return outputs
@@ -167,7 +169,7 @@ def setup_keras_trainer(
target_tensors = list(inputs[nr_inputs:])
# TODO mapping between target tensors & output tensors
-outputs = model_caller(input_tensors)
+outputs = model_caller(*input_tensors)
if isinstance(outputs, tf.Tensor):
outputs = [outputs]
if isinstance(outputs, tf.Tensor):
outputs = [outputs]
......
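
The `KerasModelCaller` changes (variadic `__call__`, plus unwrapping a single-element output list) can be illustrated without TensorFlow. This is a minimal sketch, not the tensorpack implementation: `ModelCallerSketch` and `make_double_model` are hypothetical names, and the caching shown is a simplification assumed from the `cached_model` attribute visible in the diff.

```python
class ModelCallerSketch:
    """Hypothetical stand-in for KerasModelCaller, using plain Python values."""

    def __init__(self, get_model):
        self.get_model = get_model
        self.cached_model = None

    def __call__(self, *inputs):
        # Build the model on first use and reuse it afterwards (simplified).
        if self.cached_model is None:
            self.cached_model = self.get_model()
        outputs = self.cached_model(list(inputs))
        # As in the diff: a single-element list is unwrapped to its element.
        if isinstance(outputs, list) and len(outputs) == 1:
            return outputs[0]
        return outputs


def make_double_model():
    # Toy "model": doubles every input it receives and returns a list.
    return lambda xs: [2 * x for x in xs]


caller = ModelCallerSketch(make_double_model)
single = caller(3)           # one input: the result is unwrapped
pair = caller(*[1, 2])       # a list of inputs is splatted at the call site
```

With the variadic signature, a tower function can call the object directly on its input tensors, while code that already holds a list (as `setup_keras_trainer` does) splats it with `*`.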