Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
5beab907
Commit
5beab907
authored
Apr 09, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[breaking] rename ParamRestore to DictRestore
parent
6ac34dfb
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
26 additions
and
17 deletions
+26
-17
docs/tutorial/faq.md
docs/tutorial/faq.md
+2
-1
examples/ConvolutionalPoseMachines/load-cpm.py
examples/ConvolutionalPoseMachines/load-cpm.py
+1
-1
examples/DoReFa-Net/alexnet-dorefa.py
examples/DoReFa-Net/alexnet-dorefa.py
+1
-1
examples/DoReFa-Net/resnet-dorefa.py
examples/DoReFa-Net/resnet-dorefa.py
+1
-1
examples/ResNet/load-resnet.py
examples/ResNet/load-resnet.py
+2
-2
examples/load-alexnet.py
examples/load-alexnet.py
+1
-1
examples/load-vgg16.py
examples/load-vgg16.py
+1
-1
scripts/dump-model-params.py
scripts/dump-model-params.py
+1
-1
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+4
-1
tensorpack/dataflow/dftools.py
tensorpack/dataflow/dftools.py
+1
-1
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+10
-4
tensorpack/tfutils/varmanip.py
tensorpack/tfutils/varmanip.py
+1
-1
tensorpack/utils/argtools.py
tensorpack/utils/argtools.py
+0
-1
No files found.
docs/tutorial/faq.md
View file @
5beab907
...
...
@@ -12,6 +12,7 @@ for more details.
It you think:
1.
The framework has limitation so your XYZ cannot be supported, OR
2.
Your XYZ is very common, or very well-defined, so it would be nice to include it.
Then it's a good time to open an issue.
## How to dump/inspect a model
...
...
@@ -25,7 +26,7 @@ expects a path without the extension.
You can dump a cleaner version of the model (with only model/trainable variables), with
`scripts/dump-model-params.py`
, as a simple
`var-name: value`
dict saved in npy format.
I
t expects a metagraph file which is also saved by
`ModelSaver`
.
The scrip
t expects a metagraph file which is also saved by
`ModelSaver`
.
## How to load a model / do transfer learning
...
...
examples/ConvolutionalPoseMachines/load-cpm.py
View file @
5beab907
...
...
@@ -108,7 +108,7 @@ def run_test(model_path, img_file):
param_dict
=
np
.
load
(
model_path
,
encoding
=
'latin1'
)
.
item
()
predict_func
=
OfflinePredictor
(
PredictConfig
(
model
=
Model
(),
session_init
=
Param
Restore
(
param_dict
),
session_init
=
Dict
Restore
(
param_dict
),
input_names
=
[
'input'
],
output_names
=
[
'resized_map'
]
))
...
...
examples/DoReFa-Net/alexnet-dorefa.py
View file @
5beab907
...
...
@@ -308,7 +308,7 @@ if __name__ == '__main__':
if
args
.
run
:
assert
args
.
load
.
endswith
(
'.npy'
)
run_image
(
Model
(),
Param
Restore
(
np
.
load
(
args
.
load
,
encoding
=
'latin1'
)
.
item
()),
args
.
run
)
run_image
(
Model
(),
Dict
Restore
(
np
.
load
(
args
.
load
,
encoding
=
'latin1'
)
.
item
()),
args
.
run
)
sys
.
exit
()
assert
args
.
gpu
is
not
None
,
"Need to specify a list of gpu for training!"
...
...
examples/DoReFa-Net/resnet-dorefa.py
View file @
5beab907
...
...
@@ -190,5 +190,5 @@ if __name__ == '__main__':
eval_on_ILSVRC12
(
args
.
load
,
args
.
data
)
elif
args
.
run
:
assert
args
.
load
.
endswith
(
'.npy'
)
run_image
(
Model
(),
Param
Restore
(
run_image
(
Model
(),
Dict
Restore
(
np
.
load
(
args
.
load
,
encoding
=
'latin1'
)
.
item
()),
args
.
run
)
examples/ResNet/load-resnet.py
View file @
5beab907
...
...
@@ -114,7 +114,7 @@ def get_inference_augmentor():
def
run_test
(
params
,
input
):
pred_config
=
PredictConfig
(
model
=
Model
(),
session_init
=
Param
Restore
(
params
),
session_init
=
Dict
Restore
(
params
),
input_names
=
[
'input'
],
output_names
=
[
'prob'
]
)
...
...
@@ -139,7 +139,7 @@ def eval_on_ILSVRC12(params, data_dir):
ds
=
BatchData
(
ds
,
128
,
remainder
=
True
)
pred_config
=
PredictConfig
(
model
=
Model
(),
session_init
=
Param
Restore
(
params
),
session_init
=
Dict
Restore
(
params
),
input_names
=
[
'input'
,
'label'
],
output_names
=
[
'wrong-top1'
,
'wrong-top5'
]
)
...
...
examples/load-alexnet.py
View file @
5beab907
...
...
@@ -56,7 +56,7 @@ def run_test(path, input):
param_dict
=
np
.
load
(
path
,
encoding
=
'latin1'
)
.
item
()
predictor
=
OfflinePredictor
(
PredictConfig
(
model
=
Model
(),
session_init
=
Param
Restore
(
param_dict
),
session_init
=
Dict
Restore
(
param_dict
),
input_names
=
[
'input'
],
output_names
=
[
'prob'
]
))
...
...
examples/load-vgg16.py
View file @
5beab907
...
...
@@ -66,7 +66,7 @@ def run_test(path, input):
param_dict
=
np
.
load
(
path
,
encoding
=
'latin1'
)
.
item
()
predict_func
=
OfflinePredictor
(
PredictConfig
(
model
=
Model
(),
session_init
=
Param
Restore
(
param_dict
),
session_init
=
Dict
Restore
(
param_dict
),
input_names
=
[
'input'
],
output_names
=
[
'prob'
]
# prob:0 is the probability distribution
))
...
...
scripts/dump-model-params.py
View file @
5beab907
...
...
@@ -32,7 +32,7 @@ with tf.Graph().as_default() as G:
# loading...
if
args
.
model
.
endswith
(
'.npy'
):
init
=
sessinit
.
Param
Restore
(
np
.
load
(
args
.
model
)
.
item
())
init
=
sessinit
.
Dict
Restore
(
np
.
load
(
args
.
model
)
.
item
())
else
:
init
=
sessinit
.
SaverRestore
(
args
.
model
)
sess
=
tf
.
Session
(
config
=
tf
.
ConfigProto
(
allow_soft_placement
=
True
))
...
...
tensorpack/dataflow/common.py
View file @
5beab907
...
...
@@ -123,7 +123,10 @@ class BatchData(ProxyDataFlow):
elif
type
(
dt
)
==
float
:
tp
=
'float32'
else
:
tp
=
dt
.
dtype
try
:
tp
=
dt
.
dtype
except
:
raise
TypeError
(
"Unsupported type to batch: {}"
.
format
(
type
(
dt
)))
try
:
result
.
append
(
np
.
asarray
([
x
[
k
]
for
x
in
data_holder
],
dtype
=
tp
))
...
...
tensorpack/dataflow/dftools.py
View file @
5beab907
...
...
@@ -4,8 +4,8 @@
import
sys
import
os
import
cv2
import
multiprocessing
as
mp
import
cv2
from
six.moves
import
range
from
.base
import
DataFlow
...
...
tensorpack/tfutils/sessinit.py
View file @
5beab907
...
...
@@ -8,12 +8,13 @@ import tensorflow as tf
import
six
from
..utils
import
logger
from
..utils.develop
import
deprecated
from
.common
import
get_op_tensor_name
from
.varmanip
import
(
SessionUpdate
,
get_savename_from_varname
,
is_training_name
,
get_checkpoint_path
)
__all__
=
[
'SessionInit'
,
'SaverRestore'
,
'SaverRestoreRelaxed'
,
'ParamRestore'
,
'ChainInit'
,
'ParamRestore'
,
'
DictRestore'
,
'
ChainInit'
,
'JustCurrentSession'
,
'get_model_loader'
]
...
...
@@ -156,7 +157,7 @@ class SaverRestoreRelaxed(SaverRestore):
self
.
_match_vars
(
f
)
class
Param
Restore
(
SessionInit
):
class
Dict
Restore
(
SessionInit
):
"""
Restore variables from a dictionary.
"""
...
...
@@ -190,6 +191,11 @@ class ParamRestore(SessionInit):
upd
.
update
({
name
:
value
for
name
,
value
in
six
.
iteritems
(
self
.
prms
)
if
name
in
intersect
})
@
deprecated
(
"Use `DictRestore` instead!"
,
"2017-06-01"
)
def
ParamRestore
(
d
):
return
DictRestore
(
d
)
class
ChainInit
(
SessionInit
):
""" Initialize a session by a list of :class:`SessionInit` instance, executed one by one.
This can be useful for, e.g., loading several models from different files
...
...
@@ -221,11 +227,11 @@ def get_model_loader(filename):
Get a corresponding model loader by looking at the file name.
Returns:
SessInit: either a :class:`
Param
Restore` (if name ends with 'npy') or
SessInit: either a :class:`
Dict
Restore` (if name ends with 'npy') or
:class:`SaverRestore` (otherwise).
"""
if
filename
.
endswith
(
'.npy'
):
assert
os
.
path
.
isfile
(
filename
),
filename
return
Param
Restore
(
np
.
load
(
filename
,
encoding
=
'latin1'
)
.
item
())
return
Dict
Restore
(
np
.
load
(
filename
,
encoding
=
'latin1'
)
.
item
())
else
:
return
SaverRestore
(
filename
)
tensorpack/tfutils/varmanip.py
View file @
5beab907
...
...
@@ -119,7 +119,7 @@ class SessionUpdate(object):
def
dump_session_params
(
path
):
"""
Dump value of all TRAINABLE + MODEL variables to a dict, and save as
npy format (loadable by :class:`
Param
Restore`).
npy format (loadable by :class:`
Dict
Restore`).
Args:
path(str): the path to save the parameters.
...
...
tensorpack/utils/argtools.py
View file @
5beab907
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: argtools.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
inspect
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment