Commit 491e7144 authored by Yuxin Wu

loadcaffe supports BN

parent 59cd3c77
@@ -19,59 +19,91 @@ __all__ = ['load_caffe', 'get_caffe_pb']
CAFFE_PROTO_URL = "https://github.com/BVLC/caffe/raw/master/src/caffe/proto/caffe.proto"
class CaffeLayerProcessor(object):
    def __init__(self, net):
        self.net = net
        self.layer_names = net._layer_names
        self.param_dict = {}
        self.processors = {
            'Convolution': self.proc_conv,
            'InnerProduct': self.proc_fc,
            'BatchNorm': self.proc_bn,
            'Scale': self.proc_scale
        }

    def process(self):
        for idx, layer in enumerate(self.net.layers):
            param = layer.blobs
            name = self.layer_names[idx]
            if layer.type in self.processors:
                logger.info("Processing layer {} of type {}".format(
                    name, layer.type))
                dic = self.processors[layer.type](idx, name, param)
                self.param_dict.update(dic)
            elif len(layer.blobs) != 0:
                logger.warn(
                    "{} layer contains parameters but is not supported!".format(layer.type))
        return self.param_dict

    def proc_conv(self, idx, name, param):
        assert len(param) <= 2
        assert param[0].data.ndim == 4
        # caffe: ch_out, ch_in, h, w
        W = param[0].data.transpose(2,3,1,0)
        if len(param) == 1:
            # convolution without bias
            return {name + '/W': W}
        else:
            return {name + '/W': W,
                    name + '/b': param[1].data}
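    # Illustration (shapes hypothetical): proc_conv turns a (64, 3, 7, 7)
    # Caffe blob (ch_out, ch_in, h, w) into a (7, 7, 3, 64) kernel
    # (h, w, ch_in, ch_out), the layout tensorflow-style convolutions expect.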
    def proc_fc(self, idx, name, param):
        # TODO caffe has a 'transpose' option for fc/W
        assert len(param) == 2
        prev_layer_name = self.net.bottom_names[name][0]
        prev_layer_output = self.net.blobs[prev_layer_name].data
        if prev_layer_output.ndim == 4:
            logger.info("FC layer {} takes spatial data.".format(name))
            W = param[0].data
            # original: out x (C x H x W)
            W = W.reshape((-1,) + prev_layer_output.shape[1:]).transpose(2,3,1,0)
            # becomes: (H x W x C) x out
        else:
            W = param[0].data.transpose()
        return {name + '/W': W,
                name + '/b': param[1].data}
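    # Illustration: for a VGG-style fc6, W has Caffe shape (4096, 25088) and
    # sits on a (512, 7, 7) pool5 blob; the reshape gives (4096, 512, 7, 7)
    # and the transpose gives (7, 7, 512, 4096), so flattening the input in
    # (h, w, c) order lines up with the rows of W.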
    def proc_bn(self, idx, name, param):
        # param[2] holds Caffe's moving-average scale factor; the stored
        # mean/variance can only be used directly when it equals 1
        assert param[2].data[0] == 1.0
        return {name + '/mean/EMA': param[0].data,
                name + '/variance/EMA': param[1].data}

    def proc_scale(self, idx, name, param):
        bottom_name = self.net.bottom_names[name][0]
        # find the BN layer before this scaling
        for i, layer in enumerate(self.net.layers):
            if layer.type == 'BatchNorm':
                name2 = self.layer_names[i]
                bottom_name2 = self.net.bottom_names[name2][0]
                if bottom_name2 == bottom_name:
                    # scaling and BN share the same bottom, so merge them
                    logger.info("Merge {} and {} into one BatchNorm layer".format(
                        name, name2))
                    return {name + '/beta': param[1].data,
                            name + '/gamma': param[0].data}
        # this scaling layer is assumed to be part of some BN
        logger.error("Could not find a BN layer corresponding to this Scale layer!")
        raise ValueError()
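# Background for proc_bn/proc_scale: Caffe expresses batch normalization as a
# BatchNorm layer (running mean/variance only) followed by a Scale layer
# (gamma/beta). At inference the pair computes, per channel,
#     y = gamma * (x - mean) / sqrt(var + eps) + beta
# so a Scale layer that shares its bottom blob with a BatchNorm layer is
# folded into a single set of BatchNorm parameters above.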
def load_caffe(model_desc, model_file):
    """
    return a dict of params
    """
    with change_env('GLOG_minloglevel', '2'):
        import caffe
        caffe.set_mode_cpu()
        net = caffe.Net(model_desc, model_file, caffe.TEST)
        param_dict = CaffeLayerProcessor(net).process()
    logger.info("Model loaded from caffe. Params: " +
                " ".join(sorted(param_dict.keys())))
    return param_dict
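A minimal usage sketch, assuming hypothetical file and layer names; the NumPy lines only illustrate how the exported parameters recombine into one batch-norm step:

import numpy as np

params = load_caffe('resnet50_deploy.prototxt', 'resnet50.caffemodel')  # hypothetical files
print(params['conv1/W'].shape)            # conv kernels arrive as (h, w, ch_in, ch_out)

# apply a merged BN to a per-channel activation vector x (64 channels here):
x = np.random.rand(64).astype('float32')
mean = params['bn_conv1/mean/EMA']        # keys under the BatchNorm layer's name
var = params['bn_conv1/variance/EMA']
gamma = params['scale_conv1/gamma']       # keys under the Scale layer's name
beta = params['scale_conv1/beta']
y = gamma * (x - mean) / np.sqrt(var + 1e-5) + beta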