Commit fcdeafbf authored by Yuxin Wu's avatar Yuxin Wu

simplify code

parent fc81be3f
...@@ -10,10 +10,12 @@ import tensorflow as tf ...@@ -10,10 +10,12 @@ import tensorflow as tf
import argparse import argparse
import os, re import os, re
import numpy as np import numpy as np
import six
from six.moves import zip from six.moves import zip
from tensorflow.contrib.layers import variance_scaling_initializer from tensorflow.contrib.layers import variance_scaling_initializer
from tensorpack import * from tensorpack import *
from tensorpack.utils import logger
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.dataflow.dataset import ILSVRCMeta from tensorpack.dataflow.dataset import ILSVRCMeta
...@@ -32,14 +34,14 @@ class Model(ModelDesc): ...@@ -32,14 +34,14 @@ class Model(ModelDesc):
def _build_graph(self, input_vars): def _build_graph(self, input_vars):
image = input_vars[0] image = input_vars[0]
def caffe_shortcut(l, n_in, n_out, stride): def shortcut(l, n_in, n_out, stride):
if n_in != n_out: if n_in != n_out:
l = Conv2D('convshortcut', l, n_out, 1, stride=stride) l = Conv2D('convshortcut', l, n_out, 1, stride=stride)
return BatchNorm('bnshortcut', l) return BatchNorm('bnshortcut', l)
else: else:
return l return l
def caffe_bottleneck(l, ch_out, stride, preact): def bottleneck(l, ch_out, stride, preact):
ch_in = l.get_shape().as_list()[-1] ch_in = l.get_shape().as_list()[-1]
input = l input = l
if preact == 'both_preact': if preact == 'both_preact':
...@@ -53,71 +55,61 @@ class Model(ModelDesc): ...@@ -53,71 +55,61 @@ class Model(ModelDesc):
l = tf.nn.relu(l) l = tf.nn.relu(l)
l = Conv2D('conv3', l, ch_out * 4, 1) l = Conv2D('conv3', l, ch_out * 4, 1)
l = BatchNorm('bn3', l) # put bn at the bottom l = BatchNorm('bn3', l) # put bn at the bottom
return l + caffe_shortcut(input, ch_in, ch_out * 4, stride) return l + shortcut(input, ch_in, ch_out * 4, stride)
def layer(l, layername, block_func, features, count, stride, first=False): def layer(l, layername, features, count, stride, first=False):
with tf.variable_scope(layername): with tf.variable_scope(layername):
with tf.variable_scope('block0'): with tf.variable_scope('block0'):
l = block_func(l, features, stride, l = bottleneck(l, features, stride,
'no_preact' if first else 'both_preact') 'no_preact' if first else 'both_preact')
for i in range(1, count): for i in range(1, count):
with tf.variable_scope('block{}'.format(i)): with tf.variable_scope('block{}'.format(i)):
l = block_func(l, features, 1, 'both_preact') l = bottleneck(l, features, 1, 'both_preact')
return l return l
cfg = { cfg = {
50: ([3,4,6,3], caffe_bottleneck), 50: ([3,4,6,3]),
101: ([3,4,23,3], caffe_bottleneck), 101: ([3,4,23,3]),
152: ([3,8,36,3], caffe_bottleneck) 152: ([3,8,36,3])
} }
defs = cfg[MODEL_DEPTH]
defs, block_func = cfg[MODEL_DEPTH]
with argscope(Conv2D, nl=tf.identity, use_bias=False, with argscope(Conv2D, nl=tf.identity, use_bias=False,
W_init=variance_scaling_initializer(mode='FAN_OUT')): W_init=variance_scaling_initializer(mode='FAN_OUT')):
fc1000l = (LinearWrap(image) fc1000 = (LinearWrap(image)
.Conv2D('conv0', 64, 7, stride=2, nl=BNReLU ) .Conv2D('conv0', 64, 7, stride=2, nl=BNReLU)
.MaxPooling('pool0', shape=3, stride=2, padding='SAME') .MaxPooling('pool0', shape=3, stride=2, padding='SAME')
.apply(layer, 'group0', block_func, 64, defs[0], 1, first=True) .apply(layer, 'group0', 64, defs[0], 1, first=True)
.apply(layer, 'group1', block_func, 128, defs[1], 2) .apply(layer, 'group1', 128, defs[1], 2)
.apply(layer, 'group2', block_func, 256, defs[2], 2) .apply(layer, 'group2', 256, defs[2], 2)
.apply(layer, 'group3', block_func, 512, defs[3], 2) .apply(layer, 'group3', 512, defs[3], 2)
.tf.nn.relu() .tf.nn.relu()
.GlobalAvgPooling('gap') .GlobalAvgPooling('gap')
.FullyConnected('fc1000', 1000, nl=tf.identity)()) .FullyConnected('fc1000', 1000, nl=tf.identity)())
prob = tf.nn.softmax(fc1000, name='prob_output')
def run_test(params, input):
    """Classify one image with the ResNet model and print the top-10 results.

    Args:
        params: dict mapping tensorpack variable names to numpy arrays
            (the converted caffe weights, as produced by ``name_conversion``).
        input: path to the image file to classify.

    Prints the top-10 predicted class indices and their ILSVRC synset words.
    """
    # ImageNet per-channel means (stats in [0, 1]); scaled by 255 below to
    # match the 0-255 pixel range returned by cv2.
    image_mean = np.array([0.485, 0.456, 0.406], dtype='float32')
    pred_config = PredictConfig(
        model=Model(),
        input_var_names=['input'],
        session_init=ParamRestore(params),
        output_var_names=['prob_output']
    )
    predict_func = get_predict_func(pred_config)

    im = cv2.imread(input)
    if im is None:
        # cv2.imread silently returns None on a bad path; fail loudly instead
        # of crashing later inside cv2.resize with an opaque error.
        raise IOError("Cannot read image file: {}".format(input))
    # NOTE(review): cv2 loads channels as BGR while these means look like RGB
    # statistics — confirm the intended channel order against the trained model.
    im = cv2.resize(im, (224, 224)) - image_mean * 255
    im = np.reshape(im, (1, 224, 224, 3)).astype('float32')

    prob = predict_func([im])[0]
    ret = prob[0].argsort()[-10:][::-1]  # indices of the 10 largest probs, descending
    print(ret)
    meta = ILSVRCMeta().get_synset_words_1000()
    print([meta[k] for k in ret])
def name_conversion(caffe_layer_name):
def caffeResNet2tensorpackResNet(caffe_layer_name): # beginning & end mapping
# begining & ending stage NAME_MAP = {'bn_conv1/beta': 'conv0/bn/beta',
name_map = {'bn_conv1/beta': 'conv0/bn/beta',
'bn_conv1/gamma': 'conv0/bn/gamma', 'bn_conv1/gamma': 'conv0/bn/gamma',
'bn_conv1/mean/EMA': 'conv0/bn/mean/EMA', 'bn_conv1/mean/EMA': 'conv0/bn/mean/EMA',
'bn_conv1/variance/EMA': 'conv0/bn/variance/EMA', 'bn_conv1/variance/EMA': 'conv0/bn/variance/EMA',
...@@ -125,71 +117,34 @@ def caffeResNet2tensorpackResNet(caffe_layer_name): ...@@ -125,71 +117,34 @@ def caffeResNet2tensorpackResNet(caffe_layer_name):
'conv1/b': 'conv0/b', 'conv1/b': 'conv0/b',
'fc1000/W': 'fc1000/W', 'fc1000/W': 'fc1000/W',
'fc1000/b': 'fc1000/b'} 'fc1000/b': 'fc1000/b'}
if caffe_layer_ in name_map: if caffe_layer_name in NAME_MAP:
print(caffe_layer_name + ' --> ' + name_map[caffe_layer_name]) return NAME_MAP[caffe_layer_name]
return name_map[caffe_layer_name]
print(caffe_layer_name)
layer_id = None
layer_type = None
layer_block = None
layer_branch = None
layer_group = None
s = re.search('([a-z]*)([0-9]*)([a-z]*)_branch([0-9])([a-z])', caffe_layer_name, re.IGNORECASE)
if s == None:
s = re.search('([a-z]*)([0-9]*)([a-z]*)_branch([0-9])', caffe_layer_name, re.IGNORECASE)
else:
layer_id = s.group(5)
if s.group(0) == caffe_layer_name[0:caffe_layer_name.index('/')]: s = re.search('([a-z]+)([0-9]+)([a-z]+)_', caffe_layer_name)
layer_type = s.group(1) if s is None:
layer_group = s.group(2) s = re.search('([a-z]+)([0-9]+)([a-z]+)([0-9]+)_', caffe_layer_name)
layer_block = ord(s.group(3)) - ord('a')
layer_branch = s.group(4)
else:
# print('s group ' + s.group(0))
s = re.search('([a-z]*)([0-9]*)([a-z]*)([0-9]*)_branch([0-9])([a-z])', caffe_layer_name, re.IGNORECASE)
if s == None:
s = re.search('([a-z]*)([0-9]*)([a-z]*)([0-9]*)_branch([0-9])', caffe_layer_name, re.IGNORECASE)
else:
layer_id = s.group(6)
layer_type = s.group(1)
layer_group = s.group(2)
layer_block_part1 = s.group(3) layer_block_part1 = s.group(3)
layer_block_part2 = s.group(4) layer_block_part2 = s.group(4)
if layer_block_part1 == 'a': assert layer_block_part1 in ['a', 'b']
layer_block = 0 layer_block = 0 if layer_block_part1 == 'a' else int(layer_block_part2)
elif layer_block_part1 == 'b': else:
layer_block = int(layer_block_part2) layer_block = ord(s.group(3)) - ord('a')
else: layer_type = s.group(1)
print('model block error!') layer_group = s.group(2)
layer_branch = s.group(5)
if s.group(0) != caffe_layer_name[0:caffe_layer_name.index('/')]:
print('model depth error!')
# TODO error handling
layer_branch = int(re.search('_branch([0-9])', caffe_layer_name).group(1))
assert layer_branch in [1, 2]
if layer_branch == 2:
layer_id = re.search('_branch[0-9]([a-z])/', caffe_layer_name).group(1)
layer_id = ord(layer_id) - ord('a') + 1
type_dict = {'res': '/conv', 'bn':'/bn', 'scale':'/bn'} TYPE_DICT = {'res':'conv', 'bn':'bn'}
shortcut_dict = {'res': '/convshortcut', 'bn':'/bnshortcut', 'scale':'/bnshortcut'}
tf_name = caffe_layer_name[caffe_layer_name.index('/'):] tf_name = caffe_layer_name[caffe_layer_name.index('/'):]
layer_type = TYPE_DICT[layer_type] + \
if layer_branch == '2': (str(layer_id) if layer_branch == 2 else 'shortcut')
tf_name = 'group' + str( int(layer_group) - int('2') ) + \ tf_name = 'group{}/block{}/{}'.format(
'/block' + str( layer_block ) + \ int(layer_group) - 2, layer_block, layer_type) + tf_name
type_dict[layer_type] + str( ord(layer_id) - ord('a') + 1) + tf_name
elif layer_branch == '1':
tf_name = 'group' + str( int(layer_group) - int('2') ) + \
'/block' + str(layer_block) + \
shortcut_dict[layer_type] + tf_name
else:
print('renaming error!')
# TODO error handling
print(caffe_layer_name + ' --> ' + tf_name)
return tf_name return tf_name
if __name__ == '__main__': if __name__ == '__main__':
...@@ -206,4 +161,16 @@ if __name__ == '__main__': ...@@ -206,4 +161,16 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# run resNet with given model (in npy format) # run resNet with given model (in npy format)
MODEL_DEPTH = args.depth MODEL_DEPTH = args.depth
run_test(args.load, args.input)
param = np.load(args.load, encoding='latin1').item()
resnet_param = {}
for k, v in six.iteritems(param):
try:
newname = name_conversion(k)
except:
logger.error("Exception when processing caffe layer {}".format(k))
raise
logger.info("Name Transform: " + k + ' --> ' + newname)
resnet_param[newname] = v
run_test(resnet_param, args.input)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment