Commit 3e512ea6 authored by Yuxin Wu's avatar Yuxin Wu

better handle caffe fc layout

parent c759f211
......@@ -44,15 +44,17 @@ def load_caffe(model_desc, model_file):
caffe.set_mode_cpu()
net = caffe.Net(model_desc, model_file, caffe.TEST)
layer_names = net._layer_names
blob_names = net.blobs.keys()
for layername, layer in zip(layer_names, net.layers):
# XXX
if layername == 'fc6':
prev_data_shape = (10,256,6,6)
logger.info("Special FC...")
layer.blobs[0].data[:] = layer.blobs[0].data.reshape(
(-1, ) + prev_data_shape[1:]).transpose(
0,2,3,1).reshape(
(-1, np.prod(prev_data_shape[1:])))
if layer.type == 'InnerProduct':
prev_blob_name = blob_names[blob_names.index(layername)-1]
prev_data_shape = net.blobs[prev_blob_name].data.shape[1:]
if len(prev_data_shape) == 3:
logger.info("{} is right after spatial data.".format(layername))
layer.blobs[0].data[:] = layer.blobs[0].data.reshape(
(-1, ) + prev_data_shape).transpose(
0,2,3,1).reshape(
(-1, np.prod(prev_data_shape)))
if layer.type in param_processors:
param_dict.update(param_processors[layer.type](layername, layer.blobs))
else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment