Commit 7a110067 authored Aug 30, 2016 by Yuxin Wu
update bn to use towercontext
parent 7a0e8747
Showing 2 changed files with 30 additions and 21 deletions

tensorpack/models/batch_norm.py  +15 -21
tensorpack/models/model_desc.py  +15 -0
tensorpack/models/batch_norm.py

@@ -18,7 +18,7 @@ __all__ = ['BatchNorm']
 # decay: being too close to 1 leads to slow start-up. torch use 0.9.
 # eps: torch: 1e-5. Lasagne: 1e-4
 @layer_register(log_shape=False)
-def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
+def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     """
     Batch normalization layer as described in:
@@ -30,8 +30,9 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
     * Whole-population mean/variance is calculated by a running-average mean/variance.
     * Epsilon for variance is set to 1e-5, as is `torch/nn <https://github.com/torch/nn/blob/master/BatchNormalization.lua>`_.

-    :param input: a NHWC tensor or a NC vector
-    :param use_local_stat: bool. whether to use mean/var of this batch or the moving average. Set to True in training and False in testing
+    :param input: a NHWC or NC tensor
+    :param use_local_stat: bool. whether to use mean/var of this batch or the moving average.
+        Default to True in training and False in predicting.
     :param decay: decay rate. default to 0.999.
     :param epsilon: default to 1e-5.
     """
@@ -53,41 +54,34 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
     batch_mean = tf.identity(batch_mean, 'mean')
     batch_var = tf.identity(batch_var, 'variance')

     # XXX a hack to handle training tower & prediction tower together....
     emaname = 'EMA'
-    #ctx = get_current_model_context()
-    if not batch_mean.name.startswith('towerp'):
+    ctx = get_current_model_context()
+    if use_local_stat is None:
+        use_local_stat = ctx.is_training
+    assert use_local_stat == ctx.is_training
+    if ctx.is_training:
         # training tower
         with tf.name_scope(None):   # https://github.com/tensorflow/tensorflow/issues/2740
             ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
             ema_apply_op = ema.apply([batch_mean, batch_var])
             ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
-            if not batch_mean.name.startswith('tower') or \
-                    batch_mean.name.startswith('tower0'):
+            if ctx.is_main_training_tower:
                 # inside main training tower
                 tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_mean)
                 tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_var)
     else:
         # use training-statistics in prediction
         assert not use_local_stat
         with tf.name_scope(None):   # figure out the var name
             ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
             mean_var_name = ema.average_name(batch_mean) + ':0'
             var_var_name = ema.average_name(batch_var) + ':0'
-        # use statistics in another tower
         G = tf.get_default_graph()
         # find training statistics in training tower
-        try:
-            mean_name = re.sub('towerp[0-9]+/', '', mean_var_name)
-            var_name = re.sub('towerp[0-9]+/', '', var_var_name)
-            #var_name = batch_var.op.name[prefixlen:] + '/' + emaname + ':0'
-            ema_mean = G.get_tensor_by_name(mean_name)
-            ema_var = G.get_tensor_by_name(var_name)
-        except KeyError:
-            mean_name = re.sub('towerp[0-9]+/', 'tower0/', mean_var_name)
-            var_name = re.sub('towerp[0-9]+/', 'tower0/', var_var_name)
-            ema_mean = G.get_tensor_by_name(mean_name)
-            ema_var = G.get_tensor_by_name(var_name)
+        ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name)
+        ema_var = ctx.find_tensor_in_main_tower(G, var_var_name)
         #logger.info("In prediction, using {} instead of {} for {}".format(
         #    mean_name, ema_mean.name, batch_mean.name))
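The net effect in batch_norm.py: use_local_stat now defaults to None and is resolved from the enclosing TowerContext, instead of every call site hard-coding True or False. A minimal stdlib-only sketch of that defaulting, with a stand-in TowerContext class; resolve_use_local_stat is a hypothetical helper written for illustration, not tensorpack's API:

class TowerContext:
    # stand-in for tensorpack's TowerContext; prediction towers are
    # named 'towerp<i>', so any other name counts as training
    def __init__(self, name):
        self._name = name
        self._is_training = not name.startswith('towerp')

    @property
    def is_training(self):
        return self._is_training

def resolve_use_local_stat(use_local_stat, ctx):
    # mirrors the lines added to BatchNorm: None means "follow the context",
    # and an explicit value must agree with the context
    if use_local_stat is None:
        use_local_stat = ctx.is_training
    assert use_local_stat == ctx.is_training
    return use_local_stat

print(resolve_use_local_stat(None, TowerContext('tower0')))   # True
print(resolve_use_local_stat(None, TowerContext('towerp0')))  # False

Keeping the assert means a caller-supplied use_local_stat that disagrees with the tower fails fast, rather than silently mixing batch statistics with moving averages.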
tensorpack/models/model_desc.py

@@ -26,6 +26,10 @@ class TowerContext(object):
             is_training = not self._name.startswith('towerp')
         self._is_training = is_training

+    @property
+    def is_main_training_tower(self):
+        return self.is_training and (self._name == '' or self._name == 'tower0')
+
     @property
     def is_main_tower(self):
         return self._name == '' or self._name == 'tower0'
@@ -34,6 +38,17 @@ class TowerContext(object):
     def is_training(self):
         return self._is_training

+    def find_tensor_in_main_tower(self, graph, name):
+        if self.is_main_tower:
+            return graph.get_tensor_by_name(name)
+        if name.startswith('towerp'):
+            newname = re.sub('towerp[0-9]+/', '', name)
+            try:
+                return graph.get_tensor_by_name(newname)
+            except KeyError:
+                newname = re.sub('towerp[0-9]+/', 'tower0/', name)
+                return graph.get_tensor_by_name(newname)
+
     def __enter__(self):
         global _CurrentTowerContext
         assert _CurrentTowerContext is None, \
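To make the fallback in find_tensor_in_main_tower concrete: a prediction tower (name starting with 'towerp') first tries the tensor name with the tower prefix stripped, which matches the single-GPU layout where training variables carry no tower scope, and only then falls back to the 'tower0/' prefix of the main training tower in a multi-GPU graph. A stdlib-only sketch, with a plain dict standing in for tf.Graph.get_tensor_by_name (both raise KeyError on a missing name); the regexes are copied from the diff, and the tensor names are made up for illustration:

import re

def find_in_main_tower(graph, name):
    # hypothetical free-function version of TowerContext.find_tensor_in_main_tower
    newname = re.sub('towerp[0-9]+/', '', name)             # try: no prefix (single GPU)
    try:
        return graph[newname]
    except KeyError:
        newname = re.sub('towerp[0-9]+/', 'tower0/', name)  # fall back: main training tower
        return graph[newname]

single_gpu = {'bn/mean/EMA:0': 'ema_mean'}
multi_gpu  = {'tower0/bn/mean/EMA:0': 'ema_mean'}
print(find_in_main_tower(single_gpu, 'towerp0/bn/mean/EMA:0'))  # -> ema_mean
print(find_in_main_tower(multi_gpu,  'towerp0/bn/mean/EMA:0'))  # -> ema_mean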