Commit 5c241e09 authored by Yuxin Wu's avatar Yuxin Wu

Faster dataflow for resnet inference (#139)

parent eecb5803
...@@ -31,7 +31,7 @@ df = BatchData(df, 128) ...@@ -31,7 +31,7 @@ df = BatchData(df, 128)
# start 3 processes to run the dataflow in parallel # start 3 processes to run the dataflow in parallel
df = PrefetchDataZMQ(df, 3) df = PrefetchDataZMQ(df, 3)
```` ````
You can find more complicated DataFlow in the [ResNet training script](../examples/ResNet/imagenet-resnet.py) You can find more complicated DataFlow in the [ResNet training script](../examples/ResNet/imagenet_resnet_utils.py)
with all the data preprocessing. with all the data preprocessing.
Unless you are working with standard data types (image folders, LMDB, etc), Unless you are working with standard data types (image folders, LMDB, etc),
......
...@@ -66,7 +66,7 @@ We will now add the cheapest pre-processing now to get an ndarray in the end ins ...@@ -66,7 +66,7 @@ We will now add the cheapest pre-processing now to get an ndarray in the end ins
ds = AugmentImageComponent(ds, [imgaug.Resize(224)]) ds = AugmentImageComponent(ds, [imgaug.Resize(224)])
ds = BatchData(ds, 256) ds = BatchData(ds, 256)
``` ```
You'll start to observe a slowdown after adding more pre-processing (such as that in the [ResNet example](../examples/ResNet/imagenet-resnet.py)). You'll start to observe a slowdown after adding more pre-processing (such as that in the [ResNet example](../examples/ResNet/imagenet_resnet_utils.py)).
Now it's time to add threads or processes: Now it's time to add threads or processes:
```eval_rst ```eval_rst
.. code-block:: python .. code-block:: python
......
...@@ -18,7 +18,8 @@ from tensorpack.utils.gpu import get_nr_gpu ...@@ -18,7 +18,8 @@ from tensorpack.utils.gpu import get_nr_gpu
from imagenet_resnet_utils import ( from imagenet_resnet_utils import (
fbresnet_augmentor, apply_preactivation, resnet_shortcut, resnet_backbone, fbresnet_augmentor, apply_preactivation, resnet_shortcut, resnet_backbone,
eval_on_ILSVRC12, image_preprocess, compute_loss_and_error) eval_on_ILSVRC12, image_preprocess, compute_loss_and_error,
get_imagenet_dataflow)
TOTAL_BATCH_SIZE = 256 TOTAL_BATCH_SIZE = 256
INPUT_SHAPE = 224 INPUT_SHAPE = 224
...@@ -66,19 +67,12 @@ class Model(ModelDesc): ...@@ -66,19 +67,12 @@ class Model(ModelDesc):
return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True) return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
def get_data(train_or_test): def get_data(name):
isTrain = train_or_test == 'train' isTrain = name == 'train'
datadir = args.data datadir = args.data
ds = dataset.ILSVRC12(datadir, train_or_test,
shuffle=isTrain, dir_structure='train')
augmentors = fbresnet_augmentor(isTrain) augmentors = fbresnet_augmentor(isTrain)
return get_imagenet_dataflow(
ds = AugmentImageComponent(ds, augmentors, copy=False) datadir, name, BATCH_SIZE, augmentors, dir_structure='original')
if isTrain:
ds = PrefetchDataZMQ(ds, min(25, multiprocessing.cpu_count()))
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
return ds
def get_config(): def get_config():
......
...@@ -6,19 +6,23 @@ import sys ...@@ -6,19 +6,23 @@ import sys
import argparse import argparse
import numpy as np import numpy as np
import os import os
import multiprocessing
import tensorflow as tf import tensorflow as tf
from tensorpack import * from tensorpack import InputDesc, ModelDesc, logger
from tensorpack.tfutils.symbolic_functions import * from tensorpack.models import *
from tensorpack.tfutils.summary import * from tensorpack.callbacks import *
from tensorpack.dataflow import dataset from tensorpack.train import TrainConfig, SyncMultiGPUTrainerParameterServer
from tensorpack.dataflow import imgaug, FakeData
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils import argscope, SaverRestore
from tensorpack.utils.gpu import get_nr_gpu from tensorpack.utils.gpu import get_nr_gpu
from imagenet_resnet_utils import ( from imagenet_resnet_utils import (
fbresnet_augmentor, resnet_basicblock, resnet_bottleneck, resnet_backbone, fbresnet_augmentor, resnet_basicblock, resnet_bottleneck, resnet_backbone,
eval_on_ILSVRC12, image_preprocess, compute_loss_and_error) eval_on_ILSVRC12, image_preprocess, compute_loss_and_error,
get_imagenet_dataflow)
TOTAL_BATCH_SIZE = 256 TOTAL_BATCH_SIZE = 256
INPUT_SHAPE = 224 INPUT_SHAPE = 224
...@@ -63,24 +67,17 @@ class Model(ModelDesc): ...@@ -63,24 +67,17 @@ class Model(ModelDesc):
self.cost = tf.add_n([loss, wd_loss], name='cost') self.cost = tf.add_n([loss, wd_loss], name='cost')
def _get_optimizer(self): def _get_optimizer(self):
lr = get_scalar_var('learning_rate', 0.1, summary=True) lr = symbf.get_scalar_var('learning_rate', 0.1, summary=True)
return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True) return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
def get_data(train_or_test): def get_data(name):
isTrain = train_or_test == 'train' isTrain = name == 'train'
datadir = args.data
ds = dataset.ILSVRC12(datadir, train_or_test,
shuffle=isTrain, dir_structure='original')
augmentors = fbresnet_augmentor(isTrain) augmentors = fbresnet_augmentor(isTrain)
augmentors.append(imgaug.ToUint8()) augmentors.append(imgaug.ToUint8())
datadir = args.data
ds = AugmentImageComponent(ds, augmentors, copy=False) return get_imagenet_dataflow(
if isTrain: datadir, name, BATCH_SIZE, augmentors, dir_structure='original')
ds = PrefetchDataZMQ(ds, min(20, multiprocessing.cpu_count()))
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
return ds
def get_config(fake=False, data_format='NCHW'): def get_config(fake=False, data_format='NCHW'):
......
...@@ -4,12 +4,16 @@ ...@@ -4,12 +4,16 @@
import numpy as np import numpy as np
import cv2 import cv2
import multiprocessing
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib.layers import variance_scaling_initializer from tensorflow.contrib.layers import variance_scaling_initializer
import tensorpack as tp import tensorpack as tp
from tensorpack import imgaug from tensorpack import imgaug, dataset
from tensorpack.dataflow import (
AugmentImageComponent, PrefetchDataZMQ,
BatchData, ThreadedMapData)
from tensorpack.utils.stats import RatioCounter from tensorpack.utils.stats import RatioCounter
from tensorpack.tfutils.argscope import argscope, get_arg_scope from tensorpack.tfutils.argscope import argscope, get_arg_scope
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
...@@ -75,6 +79,37 @@ def fbresnet_augmentor(isTrain): ...@@ -75,6 +79,37 @@ def fbresnet_augmentor(isTrain):
return augmentors return augmentors
def get_imagenet_dataflow(
        datadir, name, batch_size,
        augmentors, dir_structure='original'):
    """
    Build the batched ILSVRC12 dataflow for one split.

    See explanations in the tutorial:
    http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html

    Args:
        datadir: dataset location, forwarded unchanged to
            ``dataset.ILSVRC12`` / ``dataset.ILSVRC12Files``.
        name (str): which split to read; one of 'train', 'val', 'test'.
        batch_size (int): number of datapoints grouped by ``BatchData``.
        augmentors (list): image augmentors applied to every image.
        dir_structure (str): forwarded to the dataset reader of either
            branch. Defaults to 'original'.

    Returns:
        The assembled dataflow.
        For 'train': shuffled ``ILSVRC12`` -> augmentation -> parallel
        prefetch with ``PrefetchDataZMQ`` -> batches (``remainder=False``).
        Otherwise: unshuffled ``ILSVRC12Files`` -> ``ThreadedMapData``
        (read + augment) -> batches (``remainder=True``) -> a single
        ``PrefetchDataZMQ`` worker.

    Raises:
        AssertionError: if ``name`` is not a known split, or ``datadir``
            is None.
    """
    assert name in ('train', 'val', 'test'), name
    assert datadir is not None, "datadir must be provided"
    isTrain = name == 'train'
    # Cap the amount of parallelism at 30, even on bigger machines.
    cpu = min(30, multiprocessing.cpu_count())

    if isTrain:
        # Forward dir_structure here as well: the per-script get_data()
        # helpers this function replaces passed it for the training split
        # too, so silently dropping it would ignore the caller's argument.
        ds = dataset.ILSVRC12(
            datadir, name, shuffle=True, dir_structure=dir_structure)
        ds = AugmentImageComponent(ds, augmentors, copy=False)
        ds = PrefetchDataZMQ(ds, cpu)
        ds = BatchData(ds, batch_size, remainder=False)
        return ds

    ds = dataset.ILSVRC12Files(
        datadir, name, shuffle=False, dir_structure=dir_structure)
    aug = imgaug.AugmentorList(augmentors)

    def mapf(dp):
        # Each datapoint from ILSVRC12Files is unpacked as (filename, class).
        fname, cls = dp
        im = cv2.imread(fname, cv2.IMREAD_COLOR)
        if im is None:
            # cv2.imread reports an unreadable/missing file by returning
            # None instead of raising; fail here with the file name rather
            # than (presumably) deeper inside the augmentors.
            raise IOError("Failed to read image: {}".format(fname))
        im = aug.augment(im)
        return im, cls

    ds = ThreadedMapData(ds, cpu, mapf, buffer_size=2000, strict=True)
    ds = BatchData(ds, batch_size, remainder=True)
    ds = PrefetchDataZMQ(ds, 1)
    return ds
def resnet_shortcut(l, n_out, stride): def resnet_shortcut(l, n_out, stride):
data_format = get_arg_scope()['Conv2D']['data_format'] data_format = get_arg_scope()['Conv2D']['data_format']
n_in = l.get_shape().as_list()[1 if data_format == 'NCHW' else 3] n_in = l.get_shape().as_list()[1 if data_format == 'NCHW' else 3]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment