Commit 9387c653 authored by Yuxin Wu

explicit 4D tensor for fc/W

parent 60e52b94
@@ -62,6 +62,11 @@ class ParamRestore(SessionInit):
                 logger.warn("Param {} not found in this graph".format(name))
                 continue
             logger.info("Restoring param {}".format(name))
+            varshape = tuple(var.get_shape().as_list())
+            if varshape != value.shape:
+                assert np.prod(varshape) == np.prod(value.shape)
+                logger.warn("Param {} is reshaped during loading!".format(name))
+                value = value.reshape(varshape)
             sess.run(var.assign(value))
 
 def dump_session_params(path):
......
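For context, here is a minimal standalone sketch of what the new restore-time check above does, in plain numpy; the restore_value helper and the shapes are made up for illustration, not from the repo. A saved array whose element count matches the graph variable is reshaped to fit rather than rejected:

import numpy as np

def restore_value(varshape, value):
    # Reshape `value` to `varshape` when only the layout differs;
    # the total number of elements must still agree.
    if tuple(varshape) != value.shape:
        assert np.prod(varshape) == np.prod(value.shape)
        value = value.reshape(varshape)
    return value

# e.g. a 4D caffe-converted fc/W restored into a 2D fc/W variable:
w4d = np.zeros((3, 3, 8, 16))               # (H, W, C, out)
w2d = restore_value((3 * 3 * 8, 16), w4d)   # silently flattened
print(w2d.shape)                            # (72, 16)

This is what lets the explicit 4D fc/W produced by the caffe converter below load into a graph whose fc/W variable is 2D.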
@@ -6,6 +6,7 @@
 from collections import namedtuple, defaultdict
 from abc import abstractmethod
 import numpy as np
+import copy
 import os
 from six.moves import zip
@@ -15,18 +16,25 @@ from . import logger
 def get_processor():
     ret = {}
-    def process_conv(layer_name, param):
+    def process_conv(layer_name, param, input_data_shape):
         assert len(param) == 2
         # caffe: ch_out, ch_in, h, w
         return {layer_name + '/W': param[0].data.transpose(2,3,1,0),
                 layer_name + '/b': param[1].data}
     ret['Convolution'] = process_conv
 
-    # XXX fc after spatial needs a different stuff
-    # XXX caffe has an 'transpose' option for fc/W
-    def process_fc(layer_name, param):
+    # TODO caffe has a 'transpose' option for fc/W
+    def process_fc(layer_name, param, input_data_shape):
         assert len(param) == 2
-        return {layer_name + '/W': param[0].data.transpose(),
+        if len(input_data_shape) == 3:
+            logger.info("{} is right after spatial data.".format(layer_name))
+            W = param[0].data
+            # original: out x (C x H x W)
+            W = W.reshape((-1,) + input_data_shape).transpose(2,3,1,0)
+            # becomes: (H x W x C) x out
+        else:
+            W = param[0].data.transpose()
+        return {layer_name + '/W': W,
                 layer_name + '/b': param[1].data}
     ret['InnerProduct'] = process_fc
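A toy numpy check of the fc/W transform above (the shapes are made up, not from the repo): caffe stores an InnerProduct weight as out x (C*H*W), and when the layer directly follows spatial data it is re-expressed as the explicit 4D (H, W, C, out) tensor named in the commit title:

import numpy as np

out, C, H, W = 10, 3, 4, 4
caffe_W = np.arange(out * C * H * W).reshape(out, C * H * W)  # caffe layout

# Same steps as process_fc above: recover the spatial axes, then
# reorder so the channel ordering matches tensorflow's NHWC, out last.
tf_W = caffe_W.reshape((-1,) + (C, H, W)).transpose(2, 3, 1, 0)
print(tf_W.shape)                    # (4, 4, 3, 10) == (H, W, C, out)

# Flattening the first three axes gives back a 2D fc/W, which is what
# the new ParamRestore reshape in the first hunk does at load time.
print(tf_W.reshape(-1, out).shape)   # (48, 10)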
@@ -46,17 +54,14 @@ def load_caffe(model_desc, model_file):
     layer_names = net._layer_names
     blob_names = net.blobs.keys()
     for layername, layer in zip(layer_names, net.layers):
-        if layer.type == 'InnerProduct':
+        try:
             prev_blob_name = blob_names[blob_names.index(layername)-1]
             prev_data_shape = net.blobs[prev_blob_name].data.shape[1:]
-            if len(prev_data_shape) == 3:
-                logger.info("{} is right after spatial data.".format(layername))
-                layer.blobs[0].data[:] = layer.blobs[0].data.reshape(
-                    (-1, ) + prev_data_shape).transpose(
-                    0,2,3,1).reshape(
-                    (-1, np.prod(prev_data_shape)))
+        except ValueError:
+            prev_data_shape = None
         if layer.type in param_processors:
-            param_dict.update(param_processors[layer.type](layername, layer.blobs))
+            param_dict.update(param_processors[layer.type](
+                layername, layer.blobs, prev_data_shape))
         else:
             assert len(layer.blobs) == 0, len(layer.blobs)
     logger.info("Model loaded from caffe. Params: " + \
......
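A rough sketch of the new lookup in load_caffe above, with dummy blob data and a hypothetical prev_shape helper: every layer now resolves the shape of the blob preceding it, and a ValueError from the name lookup yields None instead of restricting the logic to InnerProduct layers:

blob_names = ['data', 'conv1', 'pool1', 'fc6']
blob_shapes = {'data': (3, 224, 224), 'conv1': (64, 112, 112),
               'pool1': (512, 7, 7), 'fc6': (4096,)}

def prev_shape(layername):
    # Mirror of the try/except in load_caffe: a layer without a
    # same-named blob (e.g. an in-place ReLU) simply gets None.
    try:
        prev = blob_names[blob_names.index(layername) - 1]
        return blob_shapes[prev]
    except ValueError:
        return None

print(prev_shape('fc6'))     # (512, 7, 7) -> fc6 is right after spatial data
print(prev_shape('relu6'))   # None -> name not among the blobs

The processor then decides per layer: a 3D previous shape means the fc layer follows spatial data and its W gets the 4D treatment; anything else falls back to a plain transpose.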