Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
df89a95f
Commit
df89a95f
authored
Mar 21, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
find bug about load_caffe
parent
817cd080
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
38 additions
and
11 deletions
+38
-11
examples/load_alexnet.py
examples/load_alexnet.py
+9
-6
tensorpack/predict.py
tensorpack/predict.py
+2
-2
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+4
-1
tensorpack/utils/loadcaffe.py
tensorpack/utils/loadcaffe.py
+23
-2
No files found.
examples/load_alexnet.py
View file @
df89a95f
...
...
@@ -115,7 +115,7 @@ def get_config():
max_epoch
=
100
,
)
def
run_test
(
path
):
def
run_test
(
path
,
input
):
param_dict
=
np
.
load
(
path
)
.
item
()
pred_config
=
PredictConfig
(
...
...
@@ -127,25 +127,28 @@ def run_test(path):
predict_func
=
get_predict_func
(
pred_config
)
import
cv2
im
=
cv2
.
imread
(
'cat.jpg'
)
im
=
cv2
.
imread
(
input
)
assert
im
is
not
None
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
im
=
cv2
.
resize
(
im
,
(
227
,
227
))
im
=
np
.
reshape
(
im
,
(
1
,
227
,
227
,
3
))
im
=
np
.
reshape
(
im
,
(
1
,
227
,
227
,
3
))
.
astype
(
'float32'
)
outputs
=
predict_func
([
im
])[
0
]
prob
=
outputs
[
0
]
print
prob
.
shape
ret
=
prob
.
argsort
()[
-
10
:][::
-
1
]
print
ret
assert
ret
[
0
]
==
285
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
# nargs='*' in multi mode
parser
.
add_argument
(
'--load'
,
help
=
'.npy model file generated by tensorpack.utils.loadcaffe'
,
required
=
True
)
parser
.
add_argument
(
'--input'
,
help
=
'an input image'
,
required
=
True
)
args
=
parser
.
parse_args
()
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
#start_train(get_config())
# run alexnet with given model (in npy format)
run_test
(
'alexnet.npy'
)
run_test
(
args
.
load
,
args
.
input
)
tensorpack/predict.py
View file @
df89a95f
...
...
@@ -10,9 +10,9 @@ import numpy as np
from
tqdm
import
tqdm
from
six.moves
import
zip
from
.utils
import
*
from
.utils.modelutils
import
describe_model
from
.tfutils
import
*
from
.utils
import
logger
from
.tfutils.modelutils
import
describe_model
from
.dataflow
import
DataFlow
,
BatchData
class
PredictConfig
(
object
):
...
...
tensorpack/tfutils/sessinit.py
View file @
df89a95f
...
...
@@ -54,6 +54,8 @@ class ParamRestore(SessionInit):
variables
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
var_dict
=
dict
([
v
.
name
,
v
]
for
v
in
variables
)
for
name
,
value
in
six
.
iteritems
(
self
.
prms
):
if
not
name
.
endswith
(
':0'
):
name
=
name
+
':0'
try
:
var
=
var_dict
[
name
]
except
(
ValueError
,
KeyError
):
...
...
@@ -67,7 +69,8 @@ def dump_session_params(path):
var
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
result
=
{}
for
v
in
var
:
result
[
v
.
name
]
=
v
.
eval
()
name
=
v
.
name
.
replace
(
":0"
,
""
)
result
[
name
]
=
v
.
eval
()
logger
.
info
(
"Params to save to {}:"
.
format
(
path
))
logger
.
info
(
str
(
result
.
keys
()))
np
.
save
(
path
,
result
)
tensorpack/utils/loadcaffe.py
View file @
df89a95f
...
...
@@ -5,6 +5,7 @@
from
collections
import
namedtuple
,
defaultdict
from
abc
import
abstractmethod
import
numpy
as
np
import
os
from
six.moves
import
zip
...
...
@@ -21,12 +22,14 @@ def get_processor():
layer_name
+
'/b'
:
param
[
1
]
.
data
}
ret
[
'Convolution'
]
=
process_conv
# XXX fc after spatial needs a different stuff
# XXX caffe has an 'transpose' option for fc/W
def
process_fc
(
layer_name
,
param
):
assert
len
(
param
)
==
2
return
{
layer_name
+
'/W'
:
param
[
0
]
.
data
.
transpose
(),
layer_name
+
'/b'
:
param
[
1
]
.
data
}
ret
[
'InnerProduct'
]
=
process_fc
return
ret
def
load_caffe
(
model_desc
,
model_file
):
...
...
@@ -38,9 +41,18 @@ def load_caffe(model_desc, model_file):
with
change_env
(
'GLOG_minloglevel'
,
'2'
):
import
caffe
caffe
.
set_mode_cpu
()
net
=
caffe
.
Net
(
model_desc
,
model_file
,
caffe
.
TEST
)
layer_names
=
net
.
_layer_names
for
layername
,
layer
in
zip
(
layer_names
,
net
.
layers
):
# XXX
if
layername
==
'fc6'
:
prev_data_shape
=
(
10
,
256
,
6
,
6
)
logger
.
info
(
"Special FC..."
)
layer
.
blobs
[
0
]
.
data
[:]
=
layer
.
blobs
[
0
]
.
data
.
reshape
(
(
-
1
,
)
+
prev_data_shape
[
1
:])
.
transpose
(
0
,
2
,
3
,
1
)
.
reshape
(
(
-
1
,
np
.
prod
(
prev_data_shape
[
1
:])))
if
layer
.
type
in
param_processors
:
param_dict
.
update
(
param_processors
[
layer
.
type
](
layername
,
layer
.
blobs
))
else
:
...
...
@@ -50,5 +62,14 @@ def load_caffe(model_desc, model_file):
return
param_dict
if
__name__
==
'__main__'
:
ret
=
load_caffe
(
'/home/wyx/Work/DL/caffe/models/VGG/VGG_ILSVRC_16_layers_deploy.prototxt'
,
'/home/wyx/Work/DL/caffe/models/VGG/VGG_ILSVRC_16_layers.caffemodel'
)
import
argparse
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'model'
)
parser
.
add_argument
(
'weights'
)
parser
.
add_argument
(
'output'
)
args
=
parser
.
parse_args
()
ret
=
load_caffe
(
args
.
model
,
args
.
weights
)
import
numpy
as
np
np
.
save
(
args
.
output
,
ret
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment