Shashank Suhas / seminar-breakout

Commit 6dc04278, authored Nov 27, 2016 by Yuxin Wu

fix BN again and mute some compatibility noise

Parent: 6e3e0115

Showing 4 changed files with 37 additions and 18 deletions (+37, -18):

  tensorpack/callbacks/common.py    +6   -1
  tensorpack/callbacks/param.py     +6   -1
  tensorpack/models/batch_norm.py   +20  -15
  tensorpack/train/base.py          +5   -1

tensorpack/callbacks/common.py

@@ -18,13 +18,18 @@ class ModelSaver(Callback):
     Save the model to logger directory.
     """
     def __init__(self, keep_recent=10, keep_freq=0.5,
-                 var_collections=tf.GraphKeys().VARIABLES):
+                 var_collections=None):
         """
         :param keep_recent: see `tf.train.Saver` documentation.
         :param keep_freq: see `tf.train.Saver` documentation.
         """
         self.keep_recent = keep_recent
         self.keep_freq = keep_freq
+        if var_collections is None:
+            try:
+                var_collections = tf.GraphKeys.GLOBAL_VARIABLES
+            except:
+                var_collections = tf.GraphKeys.VARIABLES
         if not isinstance(var_collections, list):
             var_collections = [var_collections]
         self.var_collections = var_collections
...
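
The old default argument evaluated tf.GraphKeys().VARIABLES when the module was imported, so the import itself could fail or emit deprecation warnings on TensorFlow builds where the collection key has been renamed to GLOBAL_VARIABLES (the "compatibility noise" the commit message refers to). Switching the default to None defers the lookup to __init__ and resolves it with a try/except. A minimal sketch of that fallback in isolation, assuming a TF 1.x-era API (the helper name is illustrative, not part of tensorpack); note that the diff itself uses a bare except:, which also swallows unrelated errors, whereas AttributeError is what a missing attribute actually raises:

    import tensorflow as tf

    def default_var_collections():
        # Prefer the newer collection key; fall back on TF releases that
        # still only provide the old VARIABLES key.
        try:
            return [tf.GraphKeys.GLOBAL_VARIABLES]
        except AttributeError:
            return [tf.GraphKeys.VARIABLES]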

tensorpack/callbacks/param.py

@@ -44,7 +44,12 @@ class GraphVarParam(HyperParam):
         self._readable_name, self.var_name = get_op_var_name(name)

     def setup_graph(self):
-        all_vars = tf.all_variables()
+        try:
+            all_vars = tf.global_variables()
+        except:
+            # TODO
+            all_vars = tf.all_variables()
         for v in all_vars:
             if v.name == self.var_name:
                 self.var = v
...
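
setup_graph scans every variable in the graph and keeps the one whose name matches self.var_name, using the same newer-name-first shim (tf.global_variables() where it exists, the deprecated tf.all_variables() otherwise). A standalone sketch of that lookup, assuming TF 1.x; find_variable_by_name is an illustrative helper, not a tensorpack function:

    import tensorflow as tf

    def find_variable_by_name(var_name):
        # Use the newer API when available, the deprecated one otherwise.
        try:
            all_vars = tf.global_variables()
        except AttributeError:
            all_vars = tf.all_variables()
        # Variable names carry an output index, e.g. 'scope/var:0'.
        for v in all_vars:
            if v.name == var_name:
                return v
        return None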

tensorpack/models/batch_norm.py

@@ -60,19 +60,22 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     if use_local_stat:
         # training tower
         if ctx.is_training:
-            reuse = tf.get_variable_scope().reuse
-            with tf.name_scope(None):   # https://github.com/tensorflow/tensorflow/issues/2740
-                # TODO if reuse=True, try to find and use the existing statistics
-                # how to use multiple tensors to update one EMA? seems impossbile
-                ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
-                ema_apply_op = ema.apply([batch_mean, batch_var])
-                ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
-                if ctx.is_main_training_tower:
-                    # inside main training tower
-                    add_model_variable(ema_mean)
-                    add_model_variable(ema_var)
+            #reuse = tf.get_variable_scope().reuse
+            with tf.variable_scope(tf.get_variable_scope(), reuse=False):
+                # BatchNorm in reuse scope can be tricky! Moving mean/variance are not reused
+                with tf.name_scope(None):   # https://github.com/tensorflow/tensorflow/issues/2740
+                    # TODO if reuse=True, try to find and use the existing statistics
+                    # how to use multiple tensors to update one EMA? seems impossbile
+                    ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
+                    ema_apply_op = ema.apply([batch_mean, batch_var])
+                    ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+                    if ctx.is_main_training_tower:
+                        # inside main training tower
+                        add_model_variable(ema_mean)
+                        add_model_variable(ema_var)
     else:
-        # no apply() is called here, no magic vars will get created
+        # no apply() is called here, no magic vars will get created,
+        # no reuse issue will happen
         assert not ctx.is_training
         with tf.name_scope(None):
             ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
...

@@ -81,14 +84,16 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
         sc = tf.get_variable_scope()
         if ctx.is_main_tower:
             # main tower, but needs to use global stat. global stat must be from outside
-            # TODO when reuse=True, the variable name could actually be different
+            # TODO when reuse=True, the desired variable name could
+            # actually be different, because a different var is created
+            # for different reuse tower
             ema_mean = tf.get_variable('mean/' + emaname, [n_out])
             ema_var = tf.get_variable('variance/' + emaname, [n_out])
         else:
             ## use statistics in another tower
             G = tf.get_default_graph()
-            ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name)
-            ema_var = ctx.find_tensor_in_main_tower(G, var_var_name)
+            ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name + ':0')
+            ema_var = ctx.find_tensor_in_main_tower(G, var_var_name + ':0')

     if use_local_stat:
         batch = tf.cast(tf.shape(x)[0], tf.float32)
...
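
The batch-norm change is the substance of the "fix BN again" part. Inside a reused variable scope (a training tower built with reuse=True), creating the exponential-moving-average shadow variables is problematic, and the new comment notes that the moving mean/variance are not meant to be reused anyway; the fix re-enters the current variable scope with reuse=False so the EMA variables are created as fresh, non-reused variables, while the tf.name_scope(None) wrapper keeps the EMA ops out of the tower's name scope (the linked TensorFlow issue #2740). A minimal sketch of the scope trick in isolation, assuming TF 1.x-era APIs and illustrative names:

    import tensorflow as tf

    def update_moving_stats(batch_mean, batch_var, decay=0.9):
        # Re-enter the current variable scope with reuse forced off, so the
        # EMA shadow variables are created even inside a reuse=True scope.
        with tf.variable_scope(tf.get_variable_scope(), reuse=False):
            # Clear the name scope so the EMA ops do not pick up the tower
            # prefix (see tensorflow/tensorflow#2740).
            with tf.name_scope(None):
                ema = tf.train.ExponentialMovingAverage(decay=decay, name='EMA')
                ema_apply_op = ema.apply([batch_mean, batch_var])
                ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
        return ema_mean, ema_var, ema_apply_op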

tensorpack/train/base.py

@@ -111,7 +111,11 @@ class Trainer(object):
         logger.info("Initializing graph variables ...")
         # TODO newsession + sessinit?
-        self.sess.run(tf.initialize_all_variables())
+        try:
+            initop = tf.global_variables_initializer()
+        except:
+            initop = tf.initialize_all_variables()
+        self.sess.run(initop)
         self.config.session_init.init(self.sess)
         tf.get_default_graph().finalize()
...
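
The same shim is applied to session initialization: tf.initialize_all_variables() was deprecated in favor of tf.global_variables_initializer() around TensorFlow 0.12, and calling the old name on newer builds prints deprecation warnings (more of the "compatibility noise" the commit message mentions). A self-contained sketch of the pattern outside the Trainer class, assuming TF 1.x:

    import tensorflow as tf

    def make_init_op():
        # Prefer the post-rename initializer; fall back on older TF releases.
        try:
            return tf.global_variables_initializer()
        except AttributeError:
            return tf.initialize_all_variables()

    x = tf.Variable(tf.zeros([3]), name='x')
    with tf.Session() as sess:
        sess.run(make_init_op())
        print(sess.run(x))   # -> [0. 0. 0.]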