Commit 973a29f9 authored by Yuxin Wu

split imagenet_resnet_utils into two files

parent 15260844
@@ -31,7 +31,7 @@ df = BatchData(df, 128)
 # start 3 processes to run the dataflow in parallel
 df = PrefetchDataZMQ(df, 3)
 ````
-You can find more complicated DataFlow in the [ResNet training script](../examples/ResNet/imagenet_resnet_utils.py)
+You can find more complicated DataFlow in the [ResNet training script](../examples/ResNet/imagenet_utils.py)
 with all the data preprocessing.
 Unless you are working with standard data types (image folders, LMDB, etc),
......
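For orientation, here is a minimal sketch of the kind of pipeline the linked script assembles. It is not part of this commit; the dataset path and the single `Resize` augmentor are placeholders for the real preprocessing done in `imagenet_utils.py`.

```python
# Illustrative only -- not from this commit. Assumes an ILSVRC12 directory
# layout that tensorpack's dataset.ILSVRC12 understands.
from tensorpack import imgaug, dataset
from tensorpack.dataflow import AugmentImageComponent, BatchData, PrefetchDataZMQ

ds = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)  # yields [image, label]
ds = AugmentImageComponent(ds, [imgaug.Resize(224)])               # cheap per-image preprocessing
ds = BatchData(ds, 128)                                            # batch into [128, 224, 224, 3]
ds = PrefetchDataZMQ(ds, 3)                                        # run the pipeline in 3 processes
```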
@@ -66,7 +66,7 @@ We will now add the cheapest pre-processing now to get an ndarray in the end ins
 ds = AugmentImageComponent(ds, [imgaug.Resize(224)])
 ds = BatchData(ds, 256)
 ```
-You'll start to observe slow down after adding more pre-processing (such as those in the [ResNet example](../examples/ResNet/imagenet_resnet_utils.py)).
+You'll start to observe slow down after adding more pre-processing (such as those in the [ResNet example](../examples/ResNet/imagenet_utils.py)).
 Now it's time to add threads or processes:
 ```eval_rst
 .. code-block:: python
......
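The tutorial's own code block is cut off in this view. As a hedged sketch of the two parallelization routes it refers to (both primitives are imported by the new `imagenet_utils.py`), assuming the files-only `ILSVRC12Files` loader so that JPEG decoding can be pushed into mapper threads:

```python
# Illustrative sketch, not part of the commit.
import cv2
from tensorpack import dataset
from tensorpack.dataflow import PrefetchDataZMQ, ThreadedMapData

# produce [filename, label] pairs; decoding is deferred to the mapper threads
ds = dataset.ILSVRC12Files('/path/to/ILSVRC12', 'train', shuffle=True)
# threads: overlap image decoding with the rest of the pipeline
ds = ThreadedMapData(
    ds, nr_thread=25,
    map_func=lambda dp: [cv2.imread(dp[0], cv2.IMREAD_COLOR), dp[1]],
    buffer_size=1000)
# processes: move the whole pipeline out of the main process over ZMQ
ds = PrefetchDataZMQ(ds, nr_proc=1)
```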
@@ -17,12 +17,13 @@ from tensorpack.dataflow import imgaug, FakeData
 from tensorpack.tfutils import argscope, get_model_loader
 from tensorpack.utils.gpu import get_nr_gpu
-from imagenet_resnet_utils import (
-    fbresnet_augmentor, get_imagenet_dataflow,
+from imagenet_utils import (
+    fbresnet_augmentor, get_imagenet_dataflow, ImageNetModel,
+    eval_on_ILSVRC12)
+from resnet_model import (
     preresnet_group, preresnet_basicblock, preresnet_bottleneck,
     resnet_group, resnet_basicblock, resnet_bottleneck, se_resnet_bottleneck,
-    resnet_backbone, ImageNetModel,
-    eval_on_ILSVRC12)
+    resnet_backbone)
 TOTAL_BATCH_SIZE = 256
......
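After the split, a training script pulls data and evaluation utilities from `imagenet_utils` and the architecture pieces from `resnet_model`. A hedged sketch of that usage follows; the `get_logits` method name and the positional arguments of `get_imagenet_dataflow` are assumptions based on the surrounding code, not shown in this hunk.

```python
# Illustrative sketch, not part of the commit.
from imagenet_utils import fbresnet_augmentor, get_imagenet_dataflow, ImageNetModel
from resnet_model import resnet_group, resnet_bottleneck, resnet_backbone


class Model(ImageNetModel):
    def get_logits(self, image):
        # ResNet-50: [3, 4, 6, 3] bottleneck blocks per group
        return resnet_backbone(image, [3, 4, 6, 3], resnet_group, resnet_bottleneck)


# training dataflow built entirely from imagenet_utils
df = get_imagenet_dataflow(
    '/path/to/ILSVRC12', 'train', 64, fbresnet_augmentor(isTrain=True))
```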
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-# File: imagenet_resnet_utils.py
+# File: imagenet_utils.py
-import numpy as np
 import cv2
+import numpy as np
 import multiprocessing
-from abc import abstractmethod
 import tensorflow as tf
-from tensorflow.contrib.layers import variance_scaling_initializer
+from abc import abstractmethod
 from tensorpack import imgaug, dataset, ModelDesc, InputDesc
 from tensorpack.dataflow import (
     AugmentImageComponent, PrefetchDataZMQ,
     BatchData, ThreadedMapData)
-from tensorpack.predict import PredictConfig, SimpleDatasetPredictor
 from tensorpack.utils.stats import RatioCounter
-from tensorpack.tfutils.argscope import argscope, get_arg_scope
+from tensorpack.models import regularize_cost
 from tensorpack.tfutils.summary import add_moving_summary
-from tensorpack.models import (
-    Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm, BNReLU, FullyConnected,
-    LinearWrap, regularize_cost)
+from tensorpack.predict import PredictConfig, SimpleDatasetPredictor
 class GoogleNetResize(imgaug.ImageAugmentor):
     """
@@ -60,7 +55,7 @@ def fbresnet_augmentor(isTrain):
     [imgaug.BrightnessScale((0.6, 1.4), clip=False),
      imgaug.Contrast((0.6, 1.4), clip=False),
      imgaug.Saturation(0.4, rgb=False),
-     # rgb-bgr conversion
+     # rgb-bgr conversion for the constants copied from fb.resnet.torch
      imgaug.Lighting(0.1,
                      eigval=np.asarray(
                          [0.2175, 0.0188, 0.0045][::-1]) * 255.0,
@@ -111,118 +106,6 @@ def get_imagenet_dataflow(
     return ds
-
-def resnet_shortcut(l, n_out, stride, nl=tf.identity):
-    data_format = get_arg_scope()['Conv2D']['data_format']
-    n_in = l.get_shape().as_list()[1 if data_format == 'NCHW' else 3]
-    if n_in != n_out:  # change dimension when channel is not the same
-        return Conv2D('convshortcut', l, n_out, 1, stride=stride, nl=nl)
-    else:
-        return l
-
-def apply_preactivation(l, preact):
-    if preact == 'bnrelu':
-        # this is used only for preact-resnet
-        shortcut = l  # preserve identity mapping
-        l = BNReLU('preact', l)
-    else:
-        shortcut = l
-    return l, shortcut
-
-def get_bn(zero_init=False):
-    """
-    Zero init gamma is good for resnet. See https://arxiv.org/abs/1706.02677.
-    """
-    if zero_init:
-        return lambda x, name: BatchNorm('bn', x, gamma_init=tf.zeros_initializer())
-    else:
-        return lambda x, name: BatchNorm('bn', x)
-
-def preresnet_basicblock(l, ch_out, stride, preact):
-    l, shortcut = apply_preactivation(l, preact)
-    l = Conv2D('conv1', l, ch_out, 3, stride=stride, nl=BNReLU)
-    l = Conv2D('conv2', l, ch_out, 3)
-    return l + resnet_shortcut(shortcut, ch_out, stride)
-
-def preresnet_bottleneck(l, ch_out, stride, preact):
-    # stride is applied on the second conv, following fb.resnet.torch
-    l, shortcut = apply_preactivation(l, preact)
-    l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
-    l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
-    l = Conv2D('conv3', l, ch_out * 4, 1)
-    return l + resnet_shortcut(shortcut, ch_out * 4, stride)
-
-def preresnet_group(l, name, block_func, features, count, stride):
-    with tf.variable_scope(name):
-        for i in range(0, count):
-            with tf.variable_scope('block{}'.format(i)):
-                # first block doesn't need activation
-                l = block_func(l, features,
-                               stride if i == 0 else 1,
-                               'no_preact' if i == 0 else 'bnrelu')
-        # end of each group need an extra activation
-        l = BNReLU('bnlast', l)
-    return l
-
-def resnet_basicblock(l, ch_out, stride, preact):
-    l, shortcut = apply_preactivation(l, preact)
-    l = Conv2D('conv1', l, ch_out, 3, stride=stride, nl=BNReLU)
-    l = Conv2D('conv2', l, ch_out, 3, nl=get_bn(zero_init=True))
-    return l + resnet_shortcut(shortcut, ch_out, stride, nl=get_bn(zero_init=False))
-
-def resnet_bottleneck(l, ch_out, stride, preact):
-    l, shortcut = apply_preactivation(l, preact)
-    l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
-    l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
-    l = Conv2D('conv3', l, ch_out * 4, 1, nl=get_bn(zero_init=True))
-    return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False))
-
-def se_resnet_bottleneck(l, ch_out, stride, preact):
-    l, shortcut = apply_preactivation(l, preact)
-    l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
-    l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
-    l = Conv2D('conv3', l, ch_out * 4, 1, nl=get_bn(zero_init=True))
-    squeeze = GlobalAvgPooling('gap', l)
-    squeeze = FullyConnected('fc1', squeeze, ch_out // 4, nl=tf.nn.relu)
-    squeeze = FullyConnected('fc2', squeeze, ch_out * 4, nl=tf.nn.sigmoid)
-    l = l * tf.reshape(squeeze, [-1, ch_out * 4, 1, 1])
-    return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False))
-
-def resnet_group(l, name, block_func, features, count, stride):
-    with tf.variable_scope(name):
-        for i in range(0, count):
-            with tf.variable_scope('block{}'.format(i)):
-                l = block_func(l, features,
-                               stride if i == 0 else 1, 'no_preact')
-                # end of each block need an activation
-                l = tf.nn.relu(l)
-    return l
-
-def resnet_backbone(image, num_blocks, group_func, block_func):
-    with argscope(Conv2D, nl=tf.identity, use_bias=False,
-                  W_init=variance_scaling_initializer(mode='FAN_OUT')):
-        logits = (LinearWrap(image)
-                  .Conv2D('conv0', 64, 7, stride=2, nl=BNReLU)
-                  .MaxPooling('pool0', shape=3, stride=2, padding='SAME')
-                  .apply(group_func, 'group0', block_func, 64, num_blocks[0], 1)
-                  .apply(group_func, 'group1', block_func, 128, num_blocks[1], 2)
-                  .apply(group_func, 'group2', block_func, 256, num_blocks[2], 2)
-                  .apply(group_func, 'group3', block_func, 512, num_blocks[3], 2)
-                  .GlobalAvgPooling('gap')
-                  .FullyConnected('linear', 1000, nl=tf.identity)())
-    return logits
 
 def eval_on_ILSVRC12(model, sessinit, dataflow):
     pred_config = PredictConfig(
         model=model,
......
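The body of `eval_on_ILSVRC12` is elided above. For context, a hedged sketch of how it is typically invoked, using only names that appear elsewhere in this diff; the checkpoint path and the arguments to `get_imagenet_dataflow` are placeholders, not taken from the commit.

```python
# Illustrative sketch, not part of the commit.
from tensorpack.tfutils import get_model_loader
from imagenet_utils import fbresnet_augmentor, get_imagenet_dataflow, eval_on_ILSVRC12

# `Model` would be an ImageNetModel subclass like the one sketched earlier
model = Model()
val_df = get_imagenet_dataflow(
    '/path/to/ILSVRC12', 'val', 128, fbresnet_augmentor(isTrain=False))
eval_on_ILSVRC12(model, get_model_loader('/path/to/checkpoint'), val_df)
```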
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# File: resnet_model.py
+
+import tensorflow as tf
+from tensorflow.contrib.layers import variance_scaling_initializer
+
+from tensorpack.tfutils.argscope import argscope, get_arg_scope
+from tensorpack.models import (
+    Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm, BNReLU, FullyConnected,
+    LinearWrap)
+
+def resnet_shortcut(l, n_out, stride, nl=tf.identity):
+    data_format = get_arg_scope()['Conv2D']['data_format']
+    n_in = l.get_shape().as_list()[1 if data_format == 'NCHW' else 3]
+    if n_in != n_out:  # change dimension when channel is not the same
+        return Conv2D('convshortcut', l, n_out, 1, stride=stride, nl=nl)
+    else:
+        return l
+
+def apply_preactivation(l, preact):
+    if preact == 'bnrelu':
+        # this is used only for preact-resnet
+        shortcut = l  # preserve identity mapping
+        l = BNReLU('preact', l)
+    else:
+        shortcut = l
+    return l, shortcut
+
+def get_bn(zero_init=False):
+    """
+    Zero init gamma is good for resnet. See https://arxiv.org/abs/1706.02677.
+    """
+    if zero_init:
+        return lambda x, name: BatchNorm('bn', x, gamma_init=tf.zeros_initializer())
+    else:
+        return lambda x, name: BatchNorm('bn', x)
+
+def preresnet_basicblock(l, ch_out, stride, preact):
+    l, shortcut = apply_preactivation(l, preact)
+    l = Conv2D('conv1', l, ch_out, 3, stride=stride, nl=BNReLU)
+    l = Conv2D('conv2', l, ch_out, 3)
+    return l + resnet_shortcut(shortcut, ch_out, stride)
+
+def preresnet_bottleneck(l, ch_out, stride, preact):
+    # stride is applied on the second conv, following fb.resnet.torch
+    l, shortcut = apply_preactivation(l, preact)
+    l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
+    l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
+    l = Conv2D('conv3', l, ch_out * 4, 1)
+    return l + resnet_shortcut(shortcut, ch_out * 4, stride)
+
+def preresnet_group(l, name, block_func, features, count, stride):
+    with tf.variable_scope(name):
+        for i in range(0, count):
+            with tf.variable_scope('block{}'.format(i)):
+                # first block doesn't need activation
+                l = block_func(l, features,
+                               stride if i == 0 else 1,
+                               'no_preact' if i == 0 else 'bnrelu')
+        # end of each group need an extra activation
+        l = BNReLU('bnlast', l)
+    return l
+
+def resnet_basicblock(l, ch_out, stride, preact):
+    l, shortcut = apply_preactivation(l, preact)
+    l = Conv2D('conv1', l, ch_out, 3, stride=stride, nl=BNReLU)
+    l = Conv2D('conv2', l, ch_out, 3, nl=get_bn(zero_init=True))
+    return l + resnet_shortcut(shortcut, ch_out, stride, nl=get_bn(zero_init=False))
+
+def resnet_bottleneck(l, ch_out, stride, preact):
+    l, shortcut = apply_preactivation(l, preact)
+    l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
+    l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
+    l = Conv2D('conv3', l, ch_out * 4, 1, nl=get_bn(zero_init=True))
+    return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False))
+
+def se_resnet_bottleneck(l, ch_out, stride, preact):
+    l, shortcut = apply_preactivation(l, preact)
+    l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
+    l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
+    l = Conv2D('conv3', l, ch_out * 4, 1, nl=get_bn(zero_init=True))
+    squeeze = GlobalAvgPooling('gap', l)
+    squeeze = FullyConnected('fc1', squeeze, ch_out // 4, nl=tf.nn.relu)
+    squeeze = FullyConnected('fc2', squeeze, ch_out * 4, nl=tf.nn.sigmoid)
+    l = l * tf.reshape(squeeze, [-1, ch_out * 4, 1, 1])
+    return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False))
+
+def resnet_group(l, name, block_func, features, count, stride):
+    with tf.variable_scope(name):
+        for i in range(0, count):
+            with tf.variable_scope('block{}'.format(i)):
+                l = block_func(l, features,
+                               stride if i == 0 else 1, 'no_preact')
+                # end of each block need an activation
+                l = tf.nn.relu(l)
+    return l
+
+def resnet_backbone(image, num_blocks, group_func, block_func):
+    with argscope(Conv2D, nl=tf.identity, use_bias=False,
+                  W_init=variance_scaling_initializer(mode='FAN_OUT')):
+        logits = (LinearWrap(image)
+                  .Conv2D('conv0', 64, 7, stride=2, nl=BNReLU)
+                  .MaxPooling('pool0', shape=3, stride=2, padding='SAME')
+                  .apply(group_func, 'group0', block_func, 64, num_blocks[0], 1)
+                  .apply(group_func, 'group1', block_func, 128, num_blocks[1], 2)
+                  .apply(group_func, 'group2', block_func, 256, num_blocks[2], 2)
+                  .apply(group_func, 'group3', block_func, 512, num_blocks[3], 2)
+                  .GlobalAvgPooling('gap')
+                  .FullyConnected('linear', 1000, nl=tf.identity)())
+    return logits
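The new module only defines the building blocks; wiring them into a full network is left to the caller. A hedged usage sketch follows: the depth-to-blocks table is the standard ResNet configuration, not something this file defines, and the explicit `data_format` argscope is needed because `resnet_shortcut` looks it up (and the SE block's reshape assumes NCHW).

```python
# Illustrative sketch, not part of the commit.
from tensorpack.models import Conv2D
from tensorpack.tfutils.argscope import argscope

from resnet_model import (
    resnet_group, resnet_basicblock, resnet_bottleneck, resnet_backbone)

# standard ResNet depth -> (blocks per group, block function)
RESNET_CONFIG = {
    18: ([2, 2, 2, 2], resnet_basicblock),
    34: ([3, 4, 6, 3], resnet_basicblock),
    50: ([3, 4, 6, 3], resnet_bottleneck),
    101: ([3, 4, 23, 3], resnet_bottleneck),
    152: ([3, 8, 36, 3], resnet_bottleneck),
}


def build_resnet_logits(image, depth=50):
    """image: an NCHW float32 tensor; returns 1000-way logits."""
    num_blocks, block_func = RESNET_CONFIG[depth]
    # resnet_shortcut reads data_format from the Conv2D argscope, so set it here
    with argscope(Conv2D, data_format='NCHW'):
        return resnet_backbone(image, num_blocks, resnet_group, block_func)
```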