Commit dd7ddd94 authored by Yuxin Wu's avatar Yuxin Wu

Keras ImageNet example (#160)

parent ac15b641
## Keras + Tensorpack
Use Keras to define a model a train it with efficient tensorpack trainers.
### Simple Examples:
[mnist-keras.py](mnist-keras.py): a simple MNIST model written mostly in tensorpack style, but use Keras model as symbolic functions.
[mnist-keras-v2.py](mnist-keras-v2.py): the same MNIST model written in Keras style.
### ImageNet Example:
[imagenet-resnet-keras.py](imagenet-resnet-keras.py):
reproduce exactly the same setting of [tensorpack ResNet example](../ResNet) on ImageNet.
It has:
+ ResNet-50 model modified from [keras.applications](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/_impl/keras/applications/resnet50.py)
+ Multi-GPU data-parallel __training and validation__ which scales (With 8 V100s, still has >90% GPU utilization and finished training in 21 hours)
+ Good accuracy (same as the [tensorpack ResNet example](../ResNet))
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: imagenet-resnet-keras.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np
import tensorflow as tf
import argparse
from tensorpack import InputDesc, QueueInput, StagingInput, SyncMultiGPUTrainerReplicated
from tensorpack.dataflow import MapData, FakeData, MapDataComponent
from tensorpack.utils import logger
from tensorpack.utils.gpu import get_nr_gpu
from tensorpack.contrib.keras import KerasModel
from tensorpack.callbacks import *
from tensorpack.tfutils import get_current_tower_context
from tensorflow import keras
from tensorflow.python.keras.layers import *
from imagenet_utils import get_imagenet_dataflow, fbresnet_augmentor, ImageNetModel
TOTAL_BATCH_SIZE = 512
BASE_LR = 0.1 * (TOTAL_BATCH_SIZE // 256)
def bn(x, name, zero_init=False):
return BatchNormalization(
axis=1, name=name, fused=True,
momentum=0.9, epsilon=1e-5,
gamma_initializer='zeros' if zero_init else 'ones')(x)
def conv(x, filters, kernel, strides=1, name=None):
return Conv2D(filters, kernel, name=name,
strides=strides, use_bias=False, padding='same',
kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=2.0, mode='fan_out', distribution='normal'),
kernel_regularizer=tf.keras.regularizers.l2(5e-5))(x)
def identity_block(input_tensor, kernel_size, filters, stage, block):
filters1, filters2, filters3 = filters
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = conv(input_tensor, filters1, 1, name=conv_name_base + '2a')
x = bn(x, name=bn_name_base + '2a')
x = Activation('relu')(x)
x = conv(x, filters2, kernel_size, name=conv_name_base + '2b')
x = bn(x, name=bn_name_base + '2b')
x = Activation('relu')(x)
x = conv(x, filters3, (1, 1), name=conv_name_base + '2c')
x = bn(x, name=bn_name_base + '2c', zero_init=True)
x = tf.keras.layers.add([x, input_tensor])
x = Activation('relu')(x)
return x
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
filters1, filters2, filters3 = filters
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = conv(input_tensor, filters1, (1, 1), name=conv_name_base + '2a')
x = bn(x, name=bn_name_base + '2a')
x = Activation('relu')(x)
x = conv(x, filters2, kernel_size, strides=strides, name=conv_name_base + '2b')
x = bn(x, name=bn_name_base + '2b')
x = Activation('relu')(x)
x = conv(x, filters3, (1, 1), name=conv_name_base + '2c')
x = bn(x, name=bn_name_base + '2c', zero_init=True)
shortcut = conv(input_tensor,
filters3, (1, 1), strides=strides,
name=conv_name_base + '1')
shortcut = bn(shortcut, name=bn_name_base + '1')
x = tf.keras.layers.add([x, shortcut])
x = Activation('relu')(x)
return x
def resnet50(inputs):
input = tf.layers.Input(tensor=inputs[0])
def image_preprocess(image):
image = ImageNetModel.image_preprocess(image)
image = tf.transpose(image, [0, 3, 1, 2])
return image
x = Lambda(image_preprocess)(input)
x = conv(x, 64, (7, 7), strides=(2, 2), name='conv0')
x = bn(x, name='bn_conv1')
x = Activation('relu')(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
x = GlobalAveragePooling2D(name='avg_pool')(x)
x = Flatten()(x)
x = Dense(1000, activation='softmax', name='fc1000',
kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=2.0, mode='fan_in'),
kernel_regularizer=tf.keras.regularizers.l2(5e-5))(x)
M = tf.keras.models.Model(input, x, name='resnet50')
return M
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--fake', help='use fakedata to test or benchmark this model', action='store_true')
args = parser.parse_args()
logger.set_logger_dir("train_log/dynamic-bn")
tf.keras.backend.set_image_data_format('channels_first')
nr_gpu = get_nr_gpu()
if args.fake:
df_train = FakeData([[64, 224, 224, 3], [64, 1000]], 5000, random=False, dtype='uint8')
df_val = FakeData([[64, 224, 224, 3], [64, 1000]], 5000, random=False)
else:
batch_size = TOTAL_BATCH_SIZE // nr_gpu
assert args.data is not None
df_train = get_imagenet_dataflow(
args.data, 'train', batch_size, fbresnet_augmentor(True))
df_val = get_imagenet_dataflow(
args.data, 'val', batch_size, fbresnet_augmentor(False))
def one_hot(label):
return np.eye(1000)[label]
df_train = MapDataComponent(df_train, one_hot, 1)
df_val = MapDataComponent(df_val, one_hot, 1)
input = QueueInput(df_train)
input = StagingInput(input)
M = KerasModel(
resnet50,
inputs_desc=[InputDesc(tf.uint8, [None, 224, 224, 3], 'images')],
targets_desc=[InputDesc(tf.float32, [None, 1000], 'labels')],
input=input,
trainer=SyncMultiGPUTrainerReplicated(nr_gpu))
lr = tf.get_variable('learning_rate', initializer=0.1, trainable=False)
tf.summary.scalar('lr', lr)
M.compile(
optimizer=tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True),
loss='categorical_crossentropy',
metrics='categorical_accuracy'
)
callbacks = [
ModelSaver(),
ScheduledHyperParamSetter(
'learning_rate',
[(0, 0.1), (3, BASE_LR)], interp='linear'), # warmup
ScheduledHyperParamSetter(
'learning_rate',
[(30, BASE_LR * 0.1), (60, BASE_LR * 1e-2), (85, BASE_LR * 1e-3), (100, BASE_LR * 1e-4)]),
GPUUtilizationTracker()
]
if not args.fake:
callbacks.append(
DataParallelInferenceRunner(
df_val, ScalarStats(['categorical_accuracy']), list(range(nr_gpu))))
M.fit(
steps_per_epoch=100 if args.fake else 1281167 // TOTAL_BATCH_SIZE,
max_epoch=105,
callbacks=callbacks
)
../ResNet/imagenet_utils.py
\ No newline at end of file
...@@ -10,6 +10,8 @@ KL = keras.layers ...@@ -10,6 +10,8 @@ KL = keras.layers
""" """
This is an mnist example demonstrating how to use Keras symbolic function inside tensorpack. This is an mnist example demonstrating how to use Keras symbolic function inside tensorpack.
This way you can define models in Keras-style, and benefit from the more efficeint trainers in tensorpack. This way you can define models in Keras-style, and benefit from the more efficeint trainers in tensorpack.
Note: this example does not work for replicated-style data-parallel trainers.
""" """
......
...@@ -11,7 +11,7 @@ from ..models.regularize import regularize_cost_from_collection ...@@ -11,7 +11,7 @@ from ..models.regularize import regularize_cost_from_collection
from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
from ..train.trainers import DistributedTrainerBase from ..train.trainers import DistributedTrainerBase
from ..callbacks import ( from ..callbacks import (
Callback, InferenceRunner, CallbackToHook, Callback, InferenceRunnerBase, InferenceRunner, CallbackToHook,
ScalarStats) ScalarStats)
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
...@@ -19,6 +19,7 @@ from ..tfutils.tower import get_current_tower_context ...@@ -19,6 +19,7 @@ from ..tfutils.tower import get_current_tower_context
from ..tfutils.scope_utils import cached_name_scope from ..tfutils.scope_utils import cached_name_scope
from ..tfutils.summary import add_moving_summary from ..tfutils.summary import add_moving_summary
from ..utils.gpu import get_nr_gpu from ..utils.gpu import get_nr_gpu
from ..utils import logger
__all__ = ['KerasPhaseCallback', 'setup_keras_trainer', 'KerasModel'] __all__ = ['KerasPhaseCallback', 'setup_keras_trainer', 'KerasModel']
...@@ -56,6 +57,8 @@ class KerasModelCaller(object): ...@@ -56,6 +57,8 @@ class KerasModelCaller(object):
if reuse: if reuse:
# use the cached Keras model to mimic reuse # use the cached Keras model to mimic reuse
# NOTE: ctx.is_training won't be useful inside model,
# because inference will always use the cached Keras model
return self.cached_model.call(input_tensors) return self.cached_model.call(input_tensors)
else: else:
# create new Keras model if not reuse # create new Keras model if not reuse
...@@ -74,10 +77,12 @@ class KerasPhaseCallback(Callback): ...@@ -74,10 +77,12 @@ class KerasPhaseCallback(Callback):
self._learning_phase = keras.backend.learning_phase() self._learning_phase = keras.backend.learning_phase()
def _setup_graph(self): def _setup_graph(self):
logger.info("Using Keras leraning phase {} in the graph!".format(
self._learning_phase.name))
cbs = self.trainer._callbacks.cbs cbs = self.trainer._callbacks.cbs
for cb in cbs: for cb in cbs:
# XXX HACK # XXX HACK
if isinstance(cb, InferenceRunner): if isinstance(cb, InferenceRunnerBase):
h = CallbackToHook(KerasPhaseCallback(False)) h = CallbackToHook(KerasPhaseCallback(False))
cb.register_hook(h) cb.register_hook(h)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment