Commit af2c0e9c, authored Jul 29, 2016 by Yuxin Wu
Parent: 3a431489

    session update

Showing 4 changed files with 58 additions and 27 deletions (+58 -27):

    tensorpack/tfutils/sessinit.py    +19  -26
    tensorpack/tfutils/sessupdate.py  +37   -0
    tensorpack/tfutils/summary.py      +1   -0
    tensorpack/train/trainer.py        +1   -1
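In short: ParamRestore now normalizes its param dict keys to variable names up front, logs the overlap between dict keys and graph variables, and delegates the actual assignments to the new SessionUpdate helper in tensorpack/tfutils/sessupdate.py; the "# TODO assert scalar" comment moves from trainer.py into summary.py, next to the code it refers to.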
tensorpack/tfutils/sessinit.py
...
@@ -11,6 +11,8 @@ import tensorflow as tf
 import six
 from ..utils import logger, EXTRA_SAVE_VARS_KEY
+from .common import get_op_var_name
+from .sessupdate import SessionUpdate
 
 __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
            'ParamRestore', 'ChainInit',
...
@@ -142,35 +144,26 @@ class ParamRestore(SessionInit):
         """
         :param param_dict: a dict of {name: value}
         """
-        self.prms = param_dict
+        self.prms = {get_op_var_name(n)[1]: v for n, v in six.iteritems(param_dict)}
 
     def _init(self, sess):
-        # allow restore non-trainable variables
         variables = tf.get_collection(tf.GraphKeys.VARIABLES)
-        var_dict = dict([v.name, v] for v in variables)
-        for name, value in six.iteritems(self.prms):
-            if not name.endswith(':0'):
-                name = name + ':0'
-            try:
-                var = var_dict[name]
-            except (ValueError, KeyError):
-                logger.warn("Param {} not found in this graph".format(name))
-                continue
-            del var_dict[name]
-            logger.info("Restoring param {}".format(name))
-            varshape = tuple(var.get_shape().as_list())
-            if varshape != value.shape:
-                # TODO only allow reshape when set(shape) is the same or different by 1
-                assert np.prod(varshape) == np.prod(value.shape), \
-                    "{}: {}!={}".format(name, varshape, value.shape)
-                logger.warn("Param {} is reshaped during loading!".format(name))
-                value = value.reshape(varshape)
-            # assign(value) creates ops with values being saved, doubling the size of metagraph
-            # assign(placeholder) works better here
-            p = tf.placeholder(value.dtype, shape=value.shape)
-            sess.run(var.assign(p), feed_dict={p: value})
-        if var_dict:
-            logger.warn("Some variables in the graph are not restored: {}".format(str(var_dict)))
+        variable_names = set([k.name for k in variables])
+        param_names = set(six.iterkeys(self.prms))
+        intersect = variable_names and param_names
+        logger.info("Params to restore: {}".format(
+            ', '.join(map(str, intersect))))
+        for k in variable_names - param_names:
+            logger.warn("Variable {} in the graph won't be restored!".format(k))
+        for k in param_names - variable_names:
+            logger.warn("Param {} not found in this graph!".format(k))
+        upd = SessionUpdate(sess, [v for v in variables if v.name in intersect])
+        logger.info("Restoring from param dict ...")
+        upd.update({name: value for name, value in six.iteritems(self.prms) if name in intersect})
 
 class ChainInit(SessionInit):
     """ Init a session by a list of SessionInit instances."""
...
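Two details of the rewritten _init are worth flagging. The constructor now normalizes dict keys through get_op_var_name(n)[1], so parameters can be given with or without the ':0' suffix (the old endswith(':0') check is gone). Also, "variable_names and param_names" is a boolean and, which evaluates to param_names whenever variable_names is non-empty; a true set intersection would be "variable_names & param_names", so a param missing from the graph slips past the filter and raises a KeyError inside SessionUpdate.update() instead of only logging a warning. A minimal sketch of the new flow (hypothetical variable name and shape, TF 0.x-era API, assuming the usual SessionInit.init() entry point that dispatches to _init()):

import numpy as np
import tensorflow as tf
from tensorpack.tfutils.sessinit import ParamRestore

W = tf.get_variable('fc/W', shape=[784, 10])
sess = tf.Session()
sess.run(tf.initialize_all_variables())

# the ':0' suffix is optional here; the constructor normalizes the key
restorer = ParamRestore({'fc/W': np.zeros((784, 10), dtype='float32')})
restorer.init(sess)   # builds a SessionUpdate over the name intersection
                      # and feeds the values in through placeholders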
tensorpack/tfutils/sessupdate.py (new file, 0 → 100644)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: sessupdate.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import numpy as np
import six
import tensorflow as tf

from ..utils import logger

__all__ = ['SessionUpdate']


class SessionUpdate(object):
    """ Update the variables in a session """
    def __init__(self, sess, vars_to_update):
        """
        :param vars_to_update: a collection of variables to update
        """
        self.sess = sess
        self.assign_ops = {}
        for v in vars_to_update:
            # build one reusable assign op per variable, fed through a
            # placeholder so the values never become constants in the graph
            p = tf.placeholder(v.dtype, shape=v.get_shape())
            self.assign_ops[v.name] = (p, v.assign(p))

    def update(self, prms):
        """
        :param prms: dict of {variable name: value}
        Any name in prms must be in the graph and in vars_to_update.
        """
        for name, value in six.iteritems(prms):
            p, op = self.assign_ops[name]
            varshape = tuple(p.get_shape().as_list())
            if varshape != value.shape:
                # TODO only allow reshape when shape different by empty axis
                assert np.prod(varshape) == np.prod(value.shape), \
                    "{}: {}!={}".format(name, varshape, value.shape)
                logger.warn("Param {} is reshaped during assigning".format(name))
                value = value.reshape(varshape)
            self.sess.run(op, feed_dict={p: value})
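The placeholder-plus-assign pattern is the same trick noted in the comment deleted from sessinit.py: assigning a concrete value creates ops with the values baked in, doubling the size of the metagraph, while assigning through a placeholder keeps the graph small and makes the ops reusable. A minimal usage sketch (hypothetical variable; note that keys must be full variable names, including the ':0' suffix):

import numpy as np
import tensorflow as tf
from tensorpack.tfutils.sessupdate import SessionUpdate

W = tf.Variable(np.ones((2, 3), dtype='float32'), name='W')
sess = tf.Session()
sess.run(tf.initialize_all_variables())   # TF 0.x-era initializer

upd = SessionUpdate(sess, [W])            # one (placeholder, assign op) pair per variable
upd.update({'W:0': np.zeros((2, 3), dtype='float32')})
print(sess.run(W))                        # all zeros now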
tensorpack/tfutils/summary.py
...
@@ -104,6 +104,7 @@ def summary_moving_average():
     vars_to_summary = tf.get_collection(MOVING_SUMMARY_VARS_KEY)
     avg_maintain_op = averager.apply(vars_to_summary)
     for idx, c in enumerate(vars_to_summary):
+        # TODO assert scalar
         name = re.sub('tower[p0-9]+/', '', c.op.name)
         tf.scalar_summary(name, averager.average(c))
     return avg_maintain_op
...
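The re.sub collapses per-tower name scopes so a variable gets the same summary name no matter which tower produced it (the 'p' in the character class presumably matching prediction towers). For example, with a hypothetical cost variable:

import re
re.sub('tower[p0-9]+/', '', 'tower0/total_cost')    # -> 'total_cost'
re.sub('tower[p0-9]+/', '', 'towerp1/total_cost')   # -> 'total_cost'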
tensorpack/train/trainer.py
...
@@ -62,7 +62,7 @@ class SimpleTrainer(Trainer):
         model = self.model
         self.input_vars = model.get_input_vars()
         model.build_graph(self.input_vars, True)
         cost_var = model.get_cost()
-        # TODO assert scalar
         add_moving_summary(cost_var)
         grads = self.config.optimizer.compute_gradients(cost_var)
...
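This removal pairs with the one-line addition in summary.py above: the scalar assertion TODO now sits next to the summary loop that actually assumes scalar tensors.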