Commit 4ee67733
authored Aug 07, 2016 by Yuxin Wu
parent ea093029

dorefa & resnet models
Showing 6 changed files, with 74 additions and 34 deletions (+74 / -34).
examples/Atari2600/common.py       +1  -1
examples/DoReFa-Net/README.md      +2  -2
examples/ResNet/README.md          +2  -4
tensorpack/callbacks/common.py     +1  -1
tensorpack/tfutils/sessinit.py     +14 -13
tensorpack/tfutils/varmanip.py     +54 -13
examples/Atari2600/common.py
@@ -21,7 +21,7 @@ def play_one_episode(player, func, verbose=False):
     def f(s):
         spc = player.get_action_space()
         act = func([[s]])[0][0].argmax()
-        if random.random() < 0.01:
+        if random.random() < 0.001:
             act = spc.sample()
         if verbose:
             print(act)
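The only change in this file lowers the chance of injecting a random action while evaluating an episode from 1% to 0.1%. A minimal standalone sketch of the same epsilon-random pattern (the function and names below are illustrative, not tensorpack's API):

```python
import random

EPSILON = 0.001  # was 0.01 before this commit

def choose_action(greedy_action, sample_random_action):
    """Return the greedy action, except with probability EPSILON take a random one."""
    if random.random() < EPSILON:
        return sample_random_action()
    return greedy_action

# Over a 10000-step evaluation episode this injects roughly 10 random actions
# instead of roughly 100.
```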
examples/DoReFa-Net/README.md
@@ -5,8 +5,8 @@ Code and model for the paper:
 We hosted a demo at CVPR16 on behalf of Megvii, Inc, running real-time half-VGG size DoReFa-Net on both ARM and FPGA.
 But we're not planning to release those runtime bit-op libraries for now. In these examples, bit operations are run in float32.
-Pretrained model for 1-2-6-AlexNet is available at
-[google drive](https://drive.google.com/a/megvii.com/folderview?id=0B308TeQzmFDLa0xOeVQwcXg1ZjQ).
+Pretrained model for 1-2-6-AlexNet is available
+[here](https://github.com/ppwwyyxx/tensorpack/releases/tag/alexnet-dorefa).
 It's provided in the format of numpy dictionary, so it should be very easy to port into other applications.
 
 ## Preparation:
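The pretrained weights are distributed as a numpy dictionary of {variable name: ndarray}. A rough sketch of how such a file could be inspected (the file name below is hypothetical, and newer numpy versions may require allow_pickle=True):

```python
import numpy as np

# Hypothetical file name; use the actual asset from the release page.
params = np.load('alexnet-126.npy').item()   # a 0-d object array wrapping a dict

for name, value in sorted(params.items()):
    print(name, value.shape)

# Each entry is a plain ndarray, so the weights can be handed to ParamRestore
# (see tensorpack/tfutils/sessinit.py below) or ported to another framework.
```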
examples/ResNet/README.md
@@ -9,7 +9,5 @@ The validation error here is computed on test set.
 
-<!--
-Download model:
-[Cifar10 n=18](https://drive.google.com/open?id=0B308TeQzmFDLeHpSaHAxWGV1WDg)
--->
+Download model:
+[Cifar10 ResNet-110 (n=18)](https://github.com/ppwwyyxx/tensorpack/releases/tag/cifar10-resnet-110)
tensorpack/callbacks/common.py
@@ -36,7 +36,7 @@ class ModelSaver(Callback):
         vars = tf.all_variables()
         var_dict = {}
         for v in vars:
-            name = v.op.name
+            name = v.name
             if re.match('tower[p1-9]', name):
                 #logger.info("Skip {} when saving model.".format(name))
                 continue
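The switch from `v.op.name` to `v.name` keeps the `:0` output suffix that TensorFlow appends to variable names, which is what the new save-name logic in varmanip.py below expects. A minimal illustration of the difference (TF 0.x-era API, matching the code in this commit):

```python
import tensorflow as tf

w = tf.Variable(tf.zeros([3, 3]), name='conv1/W')
print(w.name)     # 'conv1/W:0'  -> the save name from this commit on
print(w.op.name)  # 'conv1/W'    -> the old behaviour, without the ':0' suffix
```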
tensorpack/tfutils/sessinit.py
@@ -11,7 +11,7 @@ import six
 from ..utils import logger, EXTRA_SAVE_VARS_KEY
 from .common import get_op_var_name
-from .varmanip import SessionUpdate
+from .varmanip import SessionUpdate, get_savename_from_varname
 
 __all__ = ['SessionInit', 'NewSession', 'SaverRestore', 'ParamRestore', 'ChainInit',
@@ -112,19 +112,17 @@ class SaverRestore(SessionInit):
         var_dict = defaultdict(list)
         chkpt_vars_used = set()
         for v in vars_to_restore:
-            name = v.op.name
-            if 'towerp' in name:
-                logger.error("No variable should be under 'towerp' name scope".format(v.name))
-                # don't overwrite anything in the current prediction graph
-                continue
-            if 'tower' in name:
-                name = re.sub('tower[p0-9]+/', '', name)
-            if self.prefix and name.startswith(self.prefix):
-                name = name[len(self.prefix)+1:]
+            name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
+            # try to load both 'varname' and 'opname' from checkpoint
+            # because some old checkpoint might not have ':0'
             if name in vars_available:
                 var_dict[name].append(v)
                 chkpt_vars_used.add(name)
-                #vars_available.remove(name)
+            elif name.endswith(':0'):
+                name = name[:-2]
+                if name in vars_available:
+                    var_dict[name].append(v)
+                    chkpt_vars_used.add(name)
             else:
                 logger.warn("Variable {} in the graph not found in checkpoint!".format(v.op.name))
         if len(chkpt_vars_used) < len(vars_available):
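The rewritten SaverRestore loop asks the checkpoint for the full variable name first and then falls back to the op name without `:0`, since older checkpoints stored keys without the suffix. A standalone sketch of that lookup order (plain Python, no TensorFlow; the names are illustrative):

```python
def lookup_in_checkpoint(name, vars_available):
    """Return the checkpoint key matching `name`, trying 'conv1/W:0' then 'conv1/W'."""
    if name in vars_available:
        return name
    if name.endswith(':0') and name[:-2] in vars_available:
        return name[:-2]
    return None

print(lookup_in_checkpoint('conv1/W:0', {'conv1/W'}))    # 'conv1/W'   (old-style checkpoint)
print(lookup_in_checkpoint('conv1/W:0', {'conv1/W:0'}))  # 'conv1/W:0' (new-style checkpoint)
```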
@@ -141,12 +139,13 @@ class ParamRestore(SessionInit):
         """
         :param param_dict: a dict of {name: value}
         """
+        # use varname (with :0) for consistency
         self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)}
 
     def _init(self, sess):
         variables = tf.get_collection(tf.GraphKeys.VARIABLES)
-        variable_names = set([k.name for k in variables])
+        variable_names = set([get_savename_from_varname(k.name) for k in variables])
         param_names = set(six.iterkeys(self.prms))
 
         intersect = variable_names & param_names
@@ -159,7 +158,9 @@ class ParamRestore(SessionInit):
             logger.warn("Variable {} in the dict not found in the graph!".format(k))
-        upd = SessionUpdate(sess, [v for v in variables if v.name in intersect])
+        upd = SessionUpdate(sess, [v for v in variables if \
+                get_savename_from_varname(v.name) in intersect])
         logger.info("Restoring from dict ...")
         upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})
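ParamRestore is the class that consumes a numpy weight dictionary such as the DoReFa-Net release mentioned above. A hedged usage sketch (the file name and the surrounding setup are illustrative; ParamRestore itself only needs the dict):

```python
import numpy as np
from tensorpack.tfutils.sessinit import ParamRestore

# Hypothetical .npy file; any {variable name: ndarray} dict works.
param_dict = np.load('alexnet-126.npy').item()

# Typically passed as `session_init` in a tensorpack train/predict config.
init = ParamRestore(param_dict)
```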
tensorpack/tfutils/varmanip.py
@@ -5,10 +5,37 @@
 import six
 import tensorflow as tf
+from collections import defaultdict
 import re
 import numpy as np
 from ..utils import logger
 
-__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars']
+__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars',
+           'get_savename_from_varname']
+
+def get_savename_from_varname(varname, varname_prefix=None, savename_prefix=None):
+    """
+    :param varname: a variable name in the graph
+    :param varname_prefix: an optional prefix that may need to be removed in varname
+    :param savename_prefix: an optional prefix to append to all savename
+    :returns: the name used to save the variable
+    """
+    name = varname
+    if 'towerp' in name:
+        logger.error("No variable should be under 'towerp' name scope".format(v.name))
+        # don't overwrite anything in the current prediction graph
+        return None
+    if 'tower' in name:
+        name = re.sub('tower[p0-9]+/', '', name)
+    if varname_prefix is not None \
+            and name.startswith(varname_prefix):
+        name = name[len(varname_prefix)+1:]
+    if savename_prefix is not None:
+        name = savename_prefix + '/' + name
+    return name
 
 class SessionUpdate(object):
     """ Update the variables in a session """
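The new `get_savename_from_varname` centralizes the tower/prefix stripping that SaverRestore previously did inline. One caveat: the `logger.error(...)` line refers to `v.name`, which is not defined inside this function (it was carried over from the SaverRestore loop), so the standalone sketch below uses `varname` instead. A minimal re-implementation showing the intended name mapping (pure Python, illustrative names):

```python
import re

def savename(varname, varname_prefix=None):
    """Sketch of get_savename_from_varname's mapping, without the logging."""
    name = varname
    if 'towerp' in name:    # prediction-tower variables are never saved
        return None
    if 'tower' in name:     # training-tower replicas share one saved copy
        name = re.sub('tower[p0-9]+/', '', name)
    if varname_prefix is not None and name.startswith(varname_prefix):
        name = name[len(varname_prefix) + 1:]
    return name

print(savename('tower0/conv1/W:0'))            # 'conv1/W:0'
print(savename('towerp0/conv1/W:0'))           # None (skipped)
print(savename('mymodel/fc0/W:0', 'mymodel'))  # 'fc0/W:0'
```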
@@ -17,10 +44,14 @@ class SessionUpdate(object):
         :param vars_to_update: a collection of variables to update
         """
         self.sess = sess
-        self.assign_ops = {}
+        self.assign_ops = defaultdict(list)
         for v in vars_to_update:
-            p = tf.placeholder(v.dtype, shape=v.get_shape())
-            self.assign_ops[v.name] = (p, v.assign(p))
+            #p = tf.placeholder(v.dtype, shape=v.get_shape())
+            with tf.device('/cpu:0'):
+                p = tf.placeholder(v.dtype)
+            savename = get_savename_from_varname(v.name)
+            # multiple vars might share one savename
+            self.assign_ops[savename].append((p, v, v.assign(p)))
 
     def update(self, prms):
         """
@@ -28,15 +59,25 @@ class SessionUpdate(object):
         Any name in prms must be in the graph and in vars_to_update.
         """
         for name, value in six.iteritems(prms):
-            p, op = self.assign_ops[name]
-            varshape = tuple(p.get_shape().as_list())
-            if varshape != value.shape:
-                # TODO only allow reshape when shape different by empty axis
-                assert np.prod(varshape) == np.prod(value.shape), \
-                        "{}: {}!={}".format(name, varshape, value.shape)
-                logger.warn("Param {} is reshaped during assigning".format(name))
-                value = value.reshape(varshape)
-            self.sess.run(op, feed_dict={p: value})
+            assert name in self.assign_ops
+            for p, v, op in self.assign_ops[name]:
+                if 'fc0/W' in name:
+                    import IPython as IP; IP.embed(config=IP.terminal.ipapp.load_default_config())
+                varshape = tuple(v.get_shape().as_list())
+                if varshape != value.shape:
+                    # TODO only allow reshape when shape different by empty axis
+                    assert np.prod(varshape) == np.prod(value.shape), \
+                            "{}: {}!={}".format(name, varshape, value.shape)
+                    logger.warn("Param {} is reshaped during assigning".format(name))
+                    value = value.reshape(varshape)
+                if 'fc0/W' in name:
+                    import IPython as IP; IP.embed(config=IP.terminal.ipapp.load_default_config())
+                self.sess.run(op, feed_dict={p: value})
+                if 'fc0/W' in name:
+                    import IPython as IP; IP.embed(config=IP.terminal.ipapp.load_default_config())
 
 def dump_session_params(path):
     """ Dump value of all trainable + to_save variables to a dict and save to `path` as
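Two things stand out in the rewritten `update`. First, the three `if 'fc0/W' in name: import IPython ...` lines appear to be interactive-debugging leftovers that happened to be committed. Second, the reshape branch allows assigning a value whose shape differs from the variable's as long as the total element count matches. A small illustration of that rule with plain numpy (the shapes are made up):

```python
import numpy as np

varshape = (1, 1, 4096)      # shape of the variable in the graph (illustrative)
value = np.zeros((4096,))    # value coming from a saved dict

if value.shape != varshape:
    # allowed only when the element counts agree, e.g. size-1 axes were squeezed
    assert np.prod(varshape) == np.prod(value.shape)
    value = value.reshape(varshape)

print(value.shape)           # (1, 1, 4096)
```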