Commit df89a95f authored by Yuxin Wu's avatar Yuxin Wu

find bug about load_caffe

parent 817cd080
......@@ -115,7 +115,7 @@ def get_config():
max_epoch=100,
)
def run_test(path):
def run_test(path, input):
param_dict = np.load(path).item()
pred_config = PredictConfig(
......@@ -127,25 +127,28 @@ def run_test(path):
predict_func = get_predict_func(pred_config)
import cv2
im = cv2.imread('cat.jpg')
im = cv2.imread(input)
assert im is not None
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im = cv2.resize(im, (227, 227))
im = np.reshape(im, (1, 227, 227, 3))
im = np.reshape(im, (1, 227, 227, 3)).astype('float32')
outputs = predict_func([im])[0]
prob = outputs[0]
print prob.shape
ret = prob.argsort()[-10:][::-1]
print ret
assert ret[0] == 285
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
parser.add_argument('--load',
help='.npy model file generated by tensorpack.utils.loadcaffe',
required=True)
parser.add_argument('--input', help='an input image', required=True)
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
#start_train(get_config())
# run alexnet with given model (in npy format)
run_test('alexnet.npy')
run_test(args.load, args.input)
......@@ -10,9 +10,9 @@ import numpy as np
from tqdm import tqdm
from six.moves import zip
from .utils import *
from .utils.modelutils import describe_model
from .tfutils import *
from .utils import logger
from .tfutils.modelutils import describe_model
from .dataflow import DataFlow, BatchData
class PredictConfig(object):
......
......@@ -54,6 +54,8 @@ class ParamRestore(SessionInit):
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var_dict = dict([v.name, v] for v in variables)
for name, value in six.iteritems(self.prms):
if not name.endswith(':0'):
name = name + ':0'
try:
var = var_dict[name]
except (ValueError, KeyError):
......@@ -67,7 +69,8 @@ def dump_session_params(path):
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
result = {}
for v in var:
result[v.name] = v.eval()
name = v.name.replace(":0", "")
result[name] = v.eval()
logger.info("Params to save to {}:".format(path))
logger.info(str(result.keys()))
np.save(path, result)
......@@ -5,6 +5,7 @@
from collections import namedtuple, defaultdict
from abc import abstractmethod
import numpy as np
import os
from six.moves import zip
......@@ -21,12 +22,14 @@ def get_processor():
layer_name + '/b': param[1].data}
ret['Convolution'] = process_conv
# XXX fc after spatial needs a different stuff
# XXX caffe has an 'transpose' option for fc/W
def process_fc(layer_name, param):
assert len(param) == 2
return {layer_name + '/W': param[0].data.transpose(),
layer_name + '/b': param[1].data}
ret['InnerProduct'] = process_fc
return ret
def load_caffe(model_desc, model_file):
......@@ -38,9 +41,18 @@ def load_caffe(model_desc, model_file):
with change_env('GLOG_minloglevel', '2'):
import caffe
caffe.set_mode_cpu()
net = caffe.Net(model_desc, model_file, caffe.TEST)
layer_names = net._layer_names
for layername, layer in zip(layer_names, net.layers):
# XXX
if layername == 'fc6':
prev_data_shape = (10,256,6,6)
logger.info("Special FC...")
layer.blobs[0].data[:] = layer.blobs[0].data.reshape(
(-1, ) + prev_data_shape[1:]).transpose(
0,2,3,1).reshape(
(-1, np.prod(prev_data_shape[1:])))
if layer.type in param_processors:
param_dict.update(param_processors[layer.type](layername, layer.blobs))
else:
......@@ -50,5 +62,14 @@ def load_caffe(model_desc, model_file):
return param_dict
if __name__ == '__main__':
ret = load_caffe('/home/wyx/Work/DL/caffe/models/VGG/VGG_ILSVRC_16_layers_deploy.prototxt',
'/home/wyx/Work/DL/caffe/models/VGG/VGG_ILSVRC_16_layers.caffemodel')
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('model')
parser.add_argument('weights')
parser.add_argument('output')
args = parser.parse_args()
ret = load_caffe(args.model, args.weights)
import numpy as np
np.save(args.output, ret)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment