seminar-breakout / Commits / f60989d3

Commit f60989d3, authored Feb 15, 2017 by Yuxin Wu
Parent: 3f238a01

    fix bn and rewrite saverrestore with var.load

Showing 5 changed files with 135 additions and 70 deletions:

    tensorpack/models/batch_norm.py     +3   -3
    tensorpack/tfutils/common.py        +2   -2
    tensorpack/tfutils/optimizer.py     +57  -0
    tensorpack/tfutils/sessinit.py      +21  -54
    tensorpack/tfutils/varmanip.py      +52  -11
tensorpack/models/batch_norm.py

@@ -118,7 +118,7 @@ def get_bn_variables(x, use_scale, use_bias):
                                   initializer=tf.constant_initializer(), trainable=False)
     moving_var = tf.get_variable('variance/EMA', [n_out],
                                  initializer=tf.constant_initializer(), trainable=False)
-    return beta, gamma, moving_mean, moving_var
+    return x, beta, gamma, moving_mean, moving_var


 def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):

@@ -171,7 +171,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         with the official inceptionv3 example).
     """
     shape = x.get_shape().as_list()
-    beta, gamma, moving_mean, moving_var = get_bn_variables(x, use_scale, use_bias)
+    x, beta, gamma, moving_mean, moving_var = get_bn_variables(x, use_scale, use_bias)

     ctx = get_current_tower_context()
     if use_local_stat is None:

@@ -231,7 +231,7 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     """
     shape = x.get_shape().as_list()
-    beta, gamma, moving_mean, moving_var = get_bn_variables(x, use_scale, use_bias)
+    x, beta, gamma, moving_mean, moving_var = get_bn_variables(x, use_scale, use_bias)

     ctx = get_current_tower_context()
     use_local_stat = ctx.is_training
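The batch-norm fix makes get_bn_variables return the (possibly modified) input tensor x together with the BN parameters, and BatchNorm / BatchRenorm unpack it at the call site. The moving statistics remain ordinary non-trainable variables ('mean/EMA', 'variance/EMA' under the layer scope), so they can be inspected from a checkpoint with the same reader API the rewritten SaverRestore uses below. A small sketch, with a hypothetical checkpoint path and variable name:

import tensorflow as tf

# Sketch only: path and variable name are hypothetical examples.
reader = tf.train.NewCheckpointReader('/path/to/train_log/model-10000')
moving_var = reader.get_tensor('conv0/bn/variance/EMA')  # EMA variance kept by BatchNorm
print(moving_var.shape)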
tensorpack/tfutils/common.py

@@ -60,8 +60,8 @@ def get_global_step_var():
     with tf.variable_scope(scope, reuse=False), \
             tf.name_scope(None):
         var = tf.get_variable(GLOBAL_STEP_OP_NAME,
-                              initializer=0,
-                              trainable=False, dtype=tf.int32)
+                              initializer=tf.constant(0, dtype=tf.int64),
+                              trainable=False, dtype=tf.int64)
     return var
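The global step is now created as an int64 variable; wrapping the initial value in tf.constant(0, dtype=tf.int64) keeps the initializer's dtype consistent with the requested dtype (a bare Python 0 would be inferred as int32). A standalone sketch of the same pattern in plain TF 1.x, with an illustrative variable name:

import tensorflow as tf

# Mirrors the change above outside tensorpack; 'global_step' is just an example name.
global_step = tf.get_variable('global_step',
                              initializer=tf.constant(0, dtype=tf.int64),
                              trainable=False, dtype=tf.int64)
increment_op = tf.assign_add(global_step, 1)  # callers bump it as int64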
tensorpack/tfutils/optimizer.py

@@ -6,6 +6,7 @@
 import tensorflow as tf
 from contextlib import contextmanager
 from .gradproc import apply_grad_processors as apply_gradproc
+from .gradproc import FilterNoneGrad

 __all__ = ['apply_grad_processors', 'ProxyOptimizer',
            'PostProcessOptimizer', 'VariableAssignmentOptimizer']

@@ -115,3 +116,59 @@ class VariableAssignmentOptimizer(PostProcessOptimizer):
                 return t
             return tf.assign(v, t, use_locking=False).op
         super(VariableAssignmentOptimizer, self).__init__(opt, f)
+
+
+class AccumGradOptimizer(ProxyOptimizer):
+    def __init__(self, opt, niter):
+        super(AccumGradOptimizer, self).__init__(opt)
+        self._niter = niter
+        self._name = "AccumGrad"
+        self._counter = None
+
+    def _create_accum_slots(self, var_list):
+        slots = []
+        for v in var_list:
+            s = self._zeros_slot(v, "accum", self._name)
+            slots.append(s)
+        return slots
+
+    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+        assert global_step is None, \
+            "AccumGradOptimizer doesn't support the option global_step! " \
+            "Please maintain it yourself."
+        grads_and_vars = FilterNoneGrad().process(grads_and_vars)
+        vs = []
+        for g, v in grads_and_vars:
+            assert isinstance(g, tf.Tensor) and isinstance(v, tf.Variable), \
+                "AccumGradOptimizer only works for dense update! " \
+                "Types of v and g are {} and {}".format(type(v), type(g))
+            vs.append(v)
+
+        with tf.control_dependencies(None):
+            slots = self._create_accum_slots(vs)
+            slots_and_vars = [(s, gv[1]) for s, gv in zip(slots, grads_and_vars)]
+
+            # Create the counter on the same device as the first variable.
+            with tf.variable_scope(self._name), \
+                    tf.colocate_with(vs[0]):
+                counter = tf.Variable(
+                    0, name="counter", trainable=False, dtype=tf.int32)
+
+        ops = []
+        for s, gv in zip(slots, grads_and_vars):
+            g, v = gv
+            ops.append(tf.assign_add(s, g))
+        update_counter = tf.assign_add(counter, 1, name='update_counter')
+        update_slot_op = tf.group(update_counter, *ops, name='update_slot')
+
+        def update_grad():
+            update_op = self._opt.apply_gradients(slots_and_vars)
+            with tf.control_dependencies([update_op]):
+                clear_ops = [tf.assign(s, 0.0) for s in slots]
+            return tf.group(*clear_ops, name='update_grad')
+
+        pred = tf.equal(tf.mod(counter, self._niter), 0)
+        with tf.control_dependencies([update_slot_op]):
+            if name is None:
+                name = 'cond_update_grad'
+            op = tf.cond(pred, update_grad, lambda: tf.no_op(), name=name)
+        return op
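The new AccumGradOptimizer accumulates gradients into per-variable slots and only runs the wrapped optimizer once every niter calls, clearing the slots afterwards. A sketch of the intended usage (TF 1.x; it assumes the class is imported directly from this module, since it is not added to __all__ in this commit):

import tensorflow as tf
from tensorpack.tfutils.optimizer import AccumGradOptimizer

x = tf.get_variable('x', shape=[10], initializer=tf.zeros_initializer())
loss = tf.reduce_sum(tf.square(x - 1.0))

base_opt = tf.train.GradientDescentOptimizer(0.1)
opt = AccumGradOptimizer(base_opt, niter=4)            # apply the update once every 4 steps
grads_and_vars = base_opt.compute_gradients(loss, var_list=[x])
train_op = opt.apply_gradients(grads_and_vars)         # global_step must be left as None

Each run of train_op adds the current gradients to the slots and increments the counter; the actual weight update only executes under the tf.cond when counter % niter == 0.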
tensorpack/tfutils/sessinit.py

@@ -4,7 +4,6 @@
 import os
 from abc import abstractmethod, ABCMeta
-from collections import defaultdict
 import numpy as np
 import tensorflow as tf
 import six

@@ -57,7 +56,7 @@ class NewSession(SessionInit):
 class SaverRestore(SessionInit):
     """
-    Restore an old model saved by :class:`ModelSaver`.
+    Restore a tensorflow checkpoint saved by :class:`tf.train.Saver` or :class:`ModelSaver`.
     """
     def __init__(self, model_path, prefix=None):

@@ -73,28 +72,26 @@ class SaverRestore(SessionInit):
     def _init(self, sess):
         logger.info("Restoring checkpoint from {} ...".format(self.path))
-        chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
-        vars_map = self._get_vars_to_restore_multimap(chkpt_vars)
-        for dic in SaverRestore._produce_restore_dict(vars_map):
-            # multiple saver under same name scope would cause error:
-            # training/saver.py: assert restore_op.name.endswith("restore_all"), restore_op.name
-            saver = tf.train.Saver(var_list=dic, name=str(id(dic)), write_version=2)
-            saver.restore(sess, self.path)
+        reader, chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
+        graph_vars = tf.global_variables()
+        chkpt_vars_used = set()
-
-    @staticmethod
-    def _produce_restore_dict(vars_multimap):
-        """
-        Produce {var_name: var} dict that can be used by `tf.train.Saver`, from a {var_name: [vars]} dict.
-        """
-        while len(vars_multimap):
-            ret = {}
-            for k in list(vars_multimap.keys()):
-                v = vars_multimap[k]
-                ret[k] = v[-1]
-                del v[-1]
-                if not len(v):
-                    del vars_multimap[k]
-            yield ret
+        with sess.as_default():
+            for v in graph_vars:
+                name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
+                if name in chkpt_vars:
+                    val = reader.get_tensor(name)
+                    SessionUpdate.load_value_to_var(v, val)
+                    chkpt_vars_used.add(name)
+                else:
+                    vname = v.op.name
+                    if not is_training_name(vname):
+                        logger.warn("Variable {} in the graph not found in checkpoint!".format(vname))
+        if len(chkpt_vars_used) < len(chkpt_vars):
+            unused = chkpt_vars - chkpt_vars_used
+            for name in sorted(unused):
+                if not is_training_name(name):
+                    logger.warn("Variable {} in checkpoint not found in the graph!".format(name))

     @staticmethod
     def _read_checkpoint_vars(model_path):

@@ -105,37 +102,7 @@ class SaverRestore(SessionInit):
             if v.startswith(PREDICT_TOWER):
                 logger.error("Found {} in checkpoint. "
                              "But anything from prediction tower shouldn't be saved.".format(v.name))
-        return set(ckpt_vars)
-
-    def _get_vars_to_restore_multimap(self, vars_available):
-        """
-        :param vars_available: varaible names available in the checkpoint, for existence checking
-        :returns: a dict of {var_name: [var, var]} to restore
-        """
-        vars_to_restore = tf.global_variables()
-        var_dict = defaultdict(list)
-        chkpt_vars_used = set()
-        for v in vars_to_restore:
-            name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
-            # try to load both 'varname' and 'opname' from checkpoint
-            # because some old checkpoint might not have ':0'
-            if name in vars_available:
-                var_dict[name].append(v)
-                chkpt_vars_used.add(name)
-            elif name.endswith(':0'):
-                name = name[:-2]
-                if name in vars_available:
-                    var_dict[name].append(v)
-                    chkpt_vars_used.add(name)
-            else:
-                if not is_training_name(v.op.name):
-                    logger.warn("Variable {} in the graph not found in checkpoint!".format(v.op.name))
-        if len(chkpt_vars_used) < len(vars_available):
-            unused = vars_available - chkpt_vars_used
-            for name in sorted(unused):
-                if not is_training_name(name):
-                    logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
-        return var_dict
+        return reader, set(ckpt_vars)


 class ParamRestore(SessionInit):
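SaverRestore no longer builds tf.train.Saver objects from a restore multimap; _init now walks tf.global_variables(), reads each matching tensor from the checkpoint reader and hands it to SessionUpdate.load_value_to_var, warning about variables missing on either side. A usage sketch (the checkpoint path is hypothetical, and init(sess) is assumed to be the SessionInit entry point that dispatches to _init):

import tensorflow as tf
from tensorpack.tfutils.sessinit import SaverRestore

# ... build the model graph first ...
restorer = SaverRestore('/path/to/train_log/model-10000')  # hypothetical checkpoint
sess = tf.Session()
restorer.init(sess)  # loads matching variables via var.load(), logs any mismatch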
tensorpack/tfutils/varmanip.py

@@ -59,23 +59,64 @@ class SessionUpdate(object):
             savename = get_savename_from_varname(v.name)
             self.name_map[savename].append(v)

+    @staticmethod
+    def load_value_to_var(var, val, strict=False):
+        """
+        Call `var.load(val)` with the default session.
+
+        Args:
+            var (tf.Variable):
+            strict (bool): Behave less strict if set to False.
+        """
+        if strict:
+            var.load(val)
+            return
+        name = var.op.name
+
+        # check incompatible shape
+        varshape = tuple(var.get_shape().as_list())
+        if varshape != val.shape:
+            # TODO only allow reshape when shape different by empty axis
+            assert np.prod(varshape) == np.prod(val.shape), \
+                "{}: {}!={}".format(name, varshape, val.shape)
+            logger.warn("Variable {} is reshaped during assigning".format(name))
+            val = val.reshape(varshape)
+
+        # fix some common type incompatibility problem, but is certainly not enough
+        def upcast(vartype, valtype):
+            # allow up-casting
+            if vartype == tf.float64 and valtype == np.float32:
+                return np.float64
+            if vartype in [tf.int64, tf.int32] and valtype in [np.int32, np.int16, np.int8]:
+                return np.int64 if vartype == tf.int64 else np.int32
+            return None
+
+        if hasattr(val, 'dtype'):
+            vartype = var.value().dtype
+            if vartype != val.dtype:
+                msg = "Variable {} has dtype {} but was given a value of dtype {}.".format(
+                    name, vartype, val.dtype)
+                newtype = upcast(var.dtype, val.dtype)
+                if newtype is not None:
+                    val = newtype(val)
+                    logger.warn(msg + " Load it after casting!")
+                else:
+                    assert vartype == val.dtype, msg
+        try:
+            var.load(val)
+        except tf.errors.InvalidArgumentError:
+            logger.exc("Cannot load this value to the variable {}".format(name))
+
     def update(self, prms):
         """
         Args:
            prms(dict): dict of {variable name: value}
                Any name in prms must be in the graph and in vars_to_update.
         """
-        for name, value in six.iteritems(prms):
-            assert name in self.name_map
-            for v in self.name_map[name]:
-                varshape = tuple(v.get_shape().as_list())
-                if varshape != value.shape:
-                    # TODO only allow reshape when shape different by empty axis
-                    assert np.prod(varshape) == np.prod(value.shape), \
-                        "{}: {}!={}".format(name, varshape, value.shape)
-                    logger.warn("Param {} is reshaped during assigning".format(name))
-                    value = value.reshape(varshape)
-                v.load(value, session=self.sess)
+        with self.sess.as_default():
+            for name, value in six.iteritems(prms):
+                assert name in self.name_map
+                for v in self.name_map[name]:
+                    SessionUpdate.load_value_to_var(v, value)


 def dump_session_params(path):
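load_value_to_var is the non-strict loading path shared by SessionUpdate.update and the rewritten SaverRestore: size-compatible shape mismatches are reshaped with a warning, a few dtype mismatches are up-cast, and the value is then assigned with var.load() under the default session. A small sketch of the reshape path (variable and value are illustrative):

import numpy as np
import tensorflow as tf
from tensorpack.tfutils.varmanip import SessionUpdate

w = tf.get_variable('w', shape=[2, 3], dtype=tf.float32)
val = np.arange(6, dtype=np.float32)  # same number of elements, but shape (6,)

with tf.Session() as sess:
    # inside the block this session is the default one, so var.load() can find it;
    # val is reshaped to (2, 3) with a warning, then assigned to w
    SessionUpdate.load_value_to_var(w, val)
    print(sess.run(w))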