Commit 9387c653 authored by Yuxin Wu

explicit 4D tensor for fc/W

parent 60e52b94
......@@ -62,6 +62,11 @@ class ParamRestore(SessionInit):
logger.warn("Param {} not found in this graph".format(name))
continue
logger.info("Restoring param {}".format(name))
varshape = tuple(var.get_shape().as_list())
if varshape != value.shape:
assert np.prod(varshape) == np.prod(value.shape)
logger.warn("Param {} is reshaped during loading!".format(name))
value = value.reshape(varshape)
sess.run(var.assign(value))
def dump_session_params(path):
......
......@@ -6,6 +6,7 @@
from collections import namedtuple, defaultdict
from abc import abstractmethod
import numpy as np
import copy
import os
from six.moves import zip
......@@ -15,18 +16,25 @@ from . import logger
def get_processor():
ret = {}
def process_conv(layer_name, param):
def process_conv(layer_name, param, input_data_shape):
assert len(param) == 2
# caffe: ch_out, ch_in, h, w
return {layer_name + '/W': param[0].data.transpose(2,3,1,0),
layer_name + '/b': param[1].data}
ret['Convolution'] = process_conv
# XXX fc after spatial needs different handling
# XXX caffe has a 'transpose' option for fc/W
def process_fc(layer_name, param):
# TODO caffe has a 'transpose' option for fc/W
def process_fc(layer_name, param, input_data_shape):
assert len(param) == 2
return {layer_name + '/W': param[0].data.transpose(),
if len(input_data_shape) == 3:
logger.info("{} is right after spatial data.".format(layer_name))
W = param[0].data
# original: outx(CxHxW)
W = W.reshape((-1,) + input_data_shape).transpose(2,3,1,0)
# become: (HxWxC)xout
else:
W = param[0].data.transpose()
return {layer_name + '/W': W,
layer_name + '/b': param[1].data}
ret['InnerProduct'] = process_fc
......@@ -46,17 +54,14 @@ def load_caffe(model_desc, model_file):
layer_names = net._layer_names
blob_names = net.blobs.keys()
for layername, layer in zip(layer_names, net.layers):
if layer.type == 'InnerProduct':
try:
prev_blob_name = blob_names[blob_names.index(layername)-1]
prev_data_shape = net.blobs[prev_blob_name].data.shape[1:]
if len(prev_data_shape) == 3:
logger.info("{} is right after spatial data.".format(layername))
layer.blobs[0].data[:] = layer.blobs[0].data.reshape(
(-1, ) + prev_data_shape).transpose(
0,2,3,1).reshape(
(-1, np.prod(prev_data_shape)))
except ValueError:
prev_data_shape = None
if layer.type in param_processors:
param_dict.update(param_processors[layer.type](layername, layer.blobs))
param_dict.update(param_processors[layer.type](
layername, layer.blobs, prev_data_shape))
else:
assert len(layer.blobs) == 0, len(layer.blobs)
logger.info("Model loaded from caffe. Params: " + \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment