Commit 3e512ea6 authored by Yuxin Wu's avatar Yuxin Wu

better handle caffe fc layout

parent c759f211
...@@ -44,15 +44,17 @@ def load_caffe(model_desc, model_file): ...@@ -44,15 +44,17 @@ def load_caffe(model_desc, model_file):
caffe.set_mode_cpu() caffe.set_mode_cpu()
net = caffe.Net(model_desc, model_file, caffe.TEST) net = caffe.Net(model_desc, model_file, caffe.TEST)
layer_names = net._layer_names layer_names = net._layer_names
blob_names = net.blobs.keys()
for layername, layer in zip(layer_names, net.layers): for layername, layer in zip(layer_names, net.layers):
# XXX if layer.type == 'InnerProduct':
if layername == 'fc6': prev_blob_name = blob_names[blob_names.index(layername)-1]
prev_data_shape = (10,256,6,6) prev_data_shape = net.blobs[prev_blob_name].data.shape[1:]
logger.info("Special FC...") if len(prev_data_shape) == 3:
layer.blobs[0].data[:] = layer.blobs[0].data.reshape( logger.info("{} is right after spatial data.".format(layername))
(-1, ) + prev_data_shape[1:]).transpose( layer.blobs[0].data[:] = layer.blobs[0].data.reshape(
0,2,3,1).reshape( (-1, ) + prev_data_shape).transpose(
(-1, np.prod(prev_data_shape[1:]))) 0,2,3,1).reshape(
(-1, np.prod(prev_data_shape)))
if layer.type in param_processors: if layer.type in param_processors:
param_dict.update(param_processors[layer.type](layername, layer.blobs)) param_dict.update(param_processors[layer.type](layername, layer.blobs))
else: else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment