Commit df89a95f authored by Yuxin Wu

find bug about load_caffe

parent 817cd080
...@@ -115,7 +115,7 @@ def get_config(): ...@@ -115,7 +115,7 @@ def get_config():
max_epoch=100, max_epoch=100,
) )
def run_test(path): def run_test(path, input):
param_dict = np.load(path).item() param_dict = np.load(path).item()
pred_config = PredictConfig( pred_config = PredictConfig(
...@@ -127,25 +127,28 @@ def run_test(path): ...@@ -127,25 +127,28 @@ def run_test(path):
predict_func = get_predict_func(pred_config) predict_func = get_predict_func(pred_config)
import cv2 import cv2
im = cv2.imread('cat.jpg') im = cv2.imread(input)
assert im is not None
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im = cv2.resize(im, (227, 227)) im = cv2.resize(im, (227, 227))
im = np.reshape(im, (1, 227, 227, 3)) im = np.reshape(im, (1, 227, 227, 3)).astype('float32')
outputs = predict_func([im])[0] outputs = predict_func([im])[0]
prob = outputs[0] prob = outputs[0]
print prob.shape print prob.shape
ret = prob.argsort()[-10:][::-1] ret = prob.argsort()[-10:][::-1]
print ret print ret
assert ret[0] == 285
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
parser.add_argument('--load',
help='.npy model file generated by tensorpack.utils.loadcaffe',
required=True)
parser.add_argument('--input', help='an input image', required=True)
args = parser.parse_args() args = parser.parse_args()
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
#start_train(get_config()) #start_train(get_config())
# run alexnet with given model (in npy format) # run alexnet with given model (in npy format)
run_test('alexnet.npy') run_test(args.load, args.input)
...@@ -10,9 +10,9 @@ import numpy as np ...@@ -10,9 +10,9 @@ import numpy as np
from tqdm import tqdm from tqdm import tqdm
from six.moves import zip from six.moves import zip
from .utils import * from .tfutils import *
from .utils.modelutils import describe_model
from .utils import logger from .utils import logger
from .tfutils.modelutils import describe_model
from .dataflow import DataFlow, BatchData from .dataflow import DataFlow, BatchData
class PredictConfig(object): class PredictConfig(object):
......
...@@ -54,6 +54,8 @@ class ParamRestore(SessionInit): ...@@ -54,6 +54,8 @@ class ParamRestore(SessionInit):
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
var_dict = dict([v.name, v] for v in variables) var_dict = dict([v.name, v] for v in variables)
for name, value in six.iteritems(self.prms): for name, value in six.iteritems(self.prms):
if not name.endswith(':0'):
name = name + ':0'
try: try:
var = var_dict[name] var = var_dict[name]
except (ValueError, KeyError): except (ValueError, KeyError):
...@@ -67,7 +69,8 @@ def dump_session_params(path): ...@@ -67,7 +69,8 @@ def dump_session_params(path):
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
result = {} result = {}
for v in var: for v in var:
result[v.name] = v.eval() name = v.name.replace(":0", "")
result[name] = v.eval()
logger.info("Params to save to {}:".format(path)) logger.info("Params to save to {}:".format(path))
logger.info(str(result.keys())) logger.info(str(result.keys()))
np.save(path, result) np.save(path, result)
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
from collections import namedtuple, defaultdict from collections import namedtuple, defaultdict
from abc import abstractmethod from abc import abstractmethod
import numpy as np
import os import os
from six.moves import zip from six.moves import zip
...@@ -21,12 +22,14 @@ def get_processor(): ...@@ -21,12 +22,14 @@ def get_processor():
layer_name + '/b': param[1].data} layer_name + '/b': param[1].data}
ret['Convolution'] = process_conv ret['Convolution'] = process_conv
# XXX fc after spatial needs a different stuff
# XXX caffe has an 'transpose' option for fc/W # XXX caffe has an 'transpose' option for fc/W
def process_fc(layer_name, param): def process_fc(layer_name, param):
assert len(param) == 2 assert len(param) == 2
return {layer_name + '/W': param[0].data.transpose(), return {layer_name + '/W': param[0].data.transpose(),
layer_name + '/b': param[1].data} layer_name + '/b': param[1].data}
ret['InnerProduct'] = process_fc ret['InnerProduct'] = process_fc
return ret return ret
def load_caffe(model_desc, model_file): def load_caffe(model_desc, model_file):
...@@ -38,9 +41,18 @@ def load_caffe(model_desc, model_file): ...@@ -38,9 +41,18 @@ def load_caffe(model_desc, model_file):
with change_env('GLOG_minloglevel', '2'): with change_env('GLOG_minloglevel', '2'):
import caffe import caffe
caffe.set_mode_cpu()
net = caffe.Net(model_desc, model_file, caffe.TEST) net = caffe.Net(model_desc, model_file, caffe.TEST)
layer_names = net._layer_names layer_names = net._layer_names
for layername, layer in zip(layer_names, net.layers): for layername, layer in zip(layer_names, net.layers):
# XXX
if layername == 'fc6':
prev_data_shape = (10,256,6,6)
logger.info("Special FC...")
layer.blobs[0].data[:] = layer.blobs[0].data.reshape(
(-1, ) + prev_data_shape[1:]).transpose(
0,2,3,1).reshape(
(-1, np.prod(prev_data_shape[1:])))
if layer.type in param_processors: if layer.type in param_processors:
param_dict.update(param_processors[layer.type](layername, layer.blobs)) param_dict.update(param_processors[layer.type](layername, layer.blobs))
else: else:
...@@ -50,5 +62,14 @@ def load_caffe(model_desc, model_file): ...@@ -50,5 +62,14 @@ def load_caffe(model_desc, model_file):
return param_dict return param_dict
if __name__ == '__main__': if __name__ == '__main__':
ret = load_caffe('/home/wyx/Work/DL/caffe/models/VGG/VGG_ILSVRC_16_layers_deploy.prototxt', import argparse
'/home/wyx/Work/DL/caffe/models/VGG/VGG_ILSVRC_16_layers.caffemodel') parser = argparse.ArgumentParser()
parser.add_argument('model')
parser.add_argument('weights')
parser.add_argument('output')
args = parser.parse_args()
ret = load_caffe(args.model, args.weights)
import numpy as np
np.save(args.output, ret)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment