Commit 760a8112 authored by Yuxin Wu

clean-up load-resnet.py

parent 753afd0a
@@ -193,7 +193,7 @@ class ImageNetModel(ModelDesc):
             image: 4D tensor of 224x224 in ``self.data_format``

         Returns:
-            Bx1000 logits
+            Nx1000 logits
         """

     def _get_optimizer(self):
@@ -5,6 +5,7 @@
 # Yuxin Wu <ppwwyyxx@gmail.com>
 import cv2
+import functools
 import tensorflow as tf
 import argparse
 import os
@@ -22,9 +23,14 @@ from tensorpack.tfutils.summary import *
 from tensorpack.dataflow.dataset import ILSVRCMeta, ILSVRC12
 from imagenet_utils import eval_on_ILSVRC12, get_imagenet_dataflow
-from resnet_model import resnet_group, apply_preactivation, resnet_shortcut, get_bn
+from resnet_model import resnet_group, resnet_bottleneck

-MODEL_DEPTH = None
+DEPTH = None
+CFG = {
+    50: ([3, 4, 6, 3]),
+    101: ([3, 4, 23, 3]),
+    152: ([3, 8, 36, 3])
+}


 class Model(ModelDesc):
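(For reference, the CFG table gives the number of bottleneck blocks in each of the four residual groups. Each bottleneck contributes three convolutions, so the depth works out as 1 stem conv + 3 x (sum of blocks) + 1 fully-connected layer; e.g. for the 50 entry: 1 + 3 x (3 + 4 + 6 + 3) + 1 = 50, and likewise 101 and 152 for the other rows.)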
@@ -34,20 +40,9 @@ class Model(ModelDesc):
     def _build_graph(self, inputs):
         image, label = inputs
+        blocks = CFG[DEPTH]
-        def bottleneck(l, ch_out, stride, preact):
-            l, shortcut = apply_preactivation(l, preact)
-            l = Conv2D('conv1', l, ch_out, 1, stride=stride, nl=BNReLU)
-            l = Conv2D('conv2', l, ch_out, 3, nl=BNReLU)
-            l = Conv2D('conv3', l, ch_out * 4, 1, nl=get_bn())
-            return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn())
-
-        cfg = {
-            50: ([3, 4, 6, 3]),
-            101: ([3, 4, 23, 3]),
-            152: ([3, 8, 36, 3])
-        }
-        defs = cfg[MODEL_DEPTH]
+        bottleneck = functools.partial(resnet_bottleneck, stride_first=True)
         # tensorflow with padding=SAME will by default pad [2,3] here.
         # but caffe conv with stride will pad [3,3]
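The two comments above describe a real Caffe/TensorFlow mismatch for conv0 (7x7 kernel, stride 2, 224x224 input). A minimal sketch of the arithmetic and of the explicit-padding workaround that padding='VALID' on conv0 implies; the pad call below is illustrative and not a line of this commit:

import tensorflow as tf

# Caffe pads 3 pixels on every side:  floor((224 + 3 + 3 - 7) / 2) + 1 = 112 outputs.
# TF padding='SAME' adds only 5 pixels in total, split asymmetrically as [2, 3],
# so the 7x7 windows are shifted by one pixel relative to the Caffe-trained weights.
# Padding the image explicitly and running conv0 with padding='VALID' matches Caffe:
image = tf.placeholder(tf.float32, [None, 224, 224, 3])   # NHWC
padded = tf.pad(image, [[0, 0], [3, 3], [3, 3], [0, 0]])   # 3 px per border -> 230x230
# A VALID 7x7 / stride-2 conv on 230x230 then yields the expected 112x112 feature map.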
@@ -60,10 +55,10 @@
             logits = (LinearWrap(image)
                       .Conv2D('conv0', 64, 7, stride=2, nl=BNReLU, padding='VALID')
                       .MaxPooling('pool0', shape=3, stride=2, padding='SAME')
-                      .apply(resnet_group, 'group0', bottleneck, 64, defs[0], 1)
-                      .apply(resnet_group, 'group1', bottleneck, 128, defs[1], 2)
-                      .apply(resnet_group, 'group2', bottleneck, 256, defs[2], 2)
-                      .apply(resnet_group, 'group3', bottleneck, 512, defs[3], 2)
+                      .apply(resnet_group, 'group0', bottleneck, 64, blocks[0], 1)
+                      .apply(resnet_group, 'group1', bottleneck, 128, blocks[1], 2)
+                      .apply(resnet_group, 'group2', bottleneck, 256, blocks[2], 2)
+                      .apply(resnet_group, 'group3', bottleneck, 512, blocks[3], 2)
                       .GlobalAvgPooling('gap')
                       .FullyConnected('linear', 1000, nl=tf.identity)())
         prob = tf.nn.softmax(logits, name='prob')
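For orientation, the spatial resolution along the chain above for a 224x224 input: conv0 (stride 2) gives 112x112, pool0 gives 56x56, group0 keeps 56x56 (stride 1), and group1/group2/group3 halve it to 28, 14 and 7 before global average pooling and the 1000-way linear layer.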
@@ -143,43 +138,49 @@ def name_conversion(caffe_layer_name):
         layer_id = re.search('_branch[0-9]([a-z])/', caffe_layer_name).group(1)
         layer_id = ord(layer_id) - ord('a') + 1
-    TYPE_DICT = {'res': 'conv', 'bn': 'bn'}
+    TYPE_DICT = {'res': 'conv{}', 'bn': 'conv{}/bn'}
+    layer_type = TYPE_DICT[layer_type].format(layer_id if layer_branch == 2 else 'shortcut')
     tf_name = caffe_layer_name[caffe_layer_name.index('/'):]
-    if layer_type == 'res':
-        layer_type = 'conv{}'.format(layer_id if layer_branch == 2 else 'shortcut')
-    else:
-        layer_type = 'conv{}/bn'.format(layer_id if layer_branch == 2 else 'shortcut')
     tf_name = 'group{}/block{}/{}'.format(
         int(layer_group) - 2, layer_block, layer_type) + tf_name
     return tf_name
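To make the rewritten mapping concrete: a Caffe parameter name such as res2a_branch2a/W should come out as group0/block0/conv1/W (Caffe's group 2 becomes group0, block a becomes block0, branch2a becomes conv1), a projection-shortcut weight like res3a_branch1/W as group1/block0/convshortcut/W, and a batch-norm entry like bn2a_branch2a/gamma as group0/block0/conv1/bn/gamma. These example keys are illustrative; the exact names depend on the .npy produced by tensorpack.utils.loadcaffe.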
+def convert_param_name(param):
+    resnet_param = {}
+    for k, v in six.iteritems(param):
+        try:
+            newname = name_conversion(k)
+        except:
+            logger.error("Exception when processing caffe layer {}".format(k))
+            raise
+        logger.info("Name Transform: " + k + ' --> ' + newname)
+        resnet_param[newname] = v
+    return resnet_param
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--load', required=True,
                         help='.npy model file generated by tensorpack.utils.loadcaffe')
     parser.add_argument('-d', '--depth', help='resnet depth', required=True, type=int, choices=[50, 101, 152])
     parser.add_argument('--input', help='an input image')
+    parser.add_argument('--convert', help='npz output file to save the converted model')
     parser.add_argument('--eval', help='ILSVRC dir to run validation on')
     args = parser.parse_args()
-    assert args.input or args.eval, "Choose either input or eval!"
-    MODEL_DEPTH = args.depth
+    DEPTH = args.depth

     param = np.load(args.load, encoding='latin1').item()
-    resnet_param = {}
-    for k, v in six.iteritems(param):
-        try:
-            newname = name_conversion(k)
-        except:
-            logger.error("Exception when processing caffe layer {}".format(k))
-            raise
-        logger.info("Name Transform: " + k + ' --> ' + newname)
-        resnet_param[newname] = v
+    param = convert_param_name(param)
+
+    if args.convert:
+        assert args.convert.endswith('.npz')
+        np.savez_compressed(args.convert, **param)

     if args.eval:
         ds = get_imagenet_dataflow(args.eval, 'val', 128, get_inference_augmentor())
-        eval_on_ILSVRC12(Model(), DictRestore(resnet_param), ds)
-    else:
-        run_test(resnet_param, args.input)
+        eval_on_ILSVRC12(Model(), DictRestore(param), ds)
+    elif args.input:
+        run_test(param, args.input)
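After this change the script has three independent modes driven by the flags above: --convert writes the renamed parameters to an .npz file, --eval runs ILSVRC12 validation via eval_on_ILSVRC12, and --input classifies a single image with run_test. A typical invocation (file names illustrative) would be: python load-resnet.py --load ResNet-101.npy -d 101 --eval /path/to/ILSVRC12.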
@@ -77,10 +77,13 @@ def resnet_basicblock(l, ch_out, stride, preact):
     return l + resnet_shortcut(shortcut, ch_out, stride, nl=get_bn(zero_init=False))


-def resnet_bottleneck(l, ch_out, stride, preact):
+def resnet_bottleneck(l, ch_out, stride, preact, stride_first=False):
+    """
+    stride_first: original resnet put stride on first conv. fb.resnet.torch put stride on second conv.
+    """
     l, shortcut = apply_preactivation(l, preact)
-    l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
-    l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
+    l = Conv2D('conv1', l, ch_out, 1, stride=stride if stride_first else 1, nl=BNReLU)
+    l = Conv2D('conv2', l, ch_out, 3, stride=1 if stride_first else stride, nl=BNReLU)
     l = Conv2D('conv3', l, ch_out * 4, 1, nl=get_bn(zero_init=True))
     return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False))
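The new stride_first switch lets one bottleneck implementation serve both conventions: the original paper / Caffe models stride in the first 1x1 conv, while fb.resnet.torch strides in the 3x3 conv. A minimal sketch of how a caller picks a variant, mirroring the functools.partial call in load-resnet.py above (import shown for completeness):

import functools
from resnet_model import resnet_bottleneck

# Caffe-style bottleneck: stride on the first 1x1 conv ('conv1'),
# which is what the converted ResNet-50/101/152 weights expect.
caffe_bottleneck = functools.partial(resnet_bottleneck, stride_first=True)

# Default (stride_first=False): stride on the 3x3 conv ('conv2'),
# the fb.resnet.torch convention used elsewhere in this example.
torch_bottleneck = resnet_bottleneck

# Either callable can then be passed to resnet_group as the block function.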