Shashank Suhas / seminar-breakout · Commits

Commit 6e3e0115
authored Nov 27, 2016 by Yuxin Wu
parent 5cccf2b8

    fix bn under latest TF

Showing 3 changed files with 35 additions and 30 deletions (+35 / -30):

    tensorpack/dataflow/dataset/ptb.py     +0   -1
    tensorpack/models/batch_norm.py        +34  -29
    tensorpack/train/base.py               +1   -0
tensorpack/dataflow/dataset/ptb.py

@@ -45,7 +45,6 @@ class PennTreeBank(RNGDataFlow):
         super(PennTreeBank, self).__init__()
         if data_dir is None:
             data_dir = get_dataset_path('ptb_data')
-        assert os.path.isdir(data_dir)
         data3, word_to_id = get_raw_data(data_dir)
         self.word_to_id = word_to_id
         self.data = np.asarray(
         ...
tensorpack/models/batch_norm.py

@@ -40,7 +40,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     beta = tf.get_variable('beta', [n_out],
             initializer=tf.zeros_initializer)
     gamma = tf.get_variable('gamma', [n_out],
-            initializer=tf.ones_initializer)
+            initializer=tf.constant_initializer(1.0))
     if len(shape) == 2:
         batch_mean, batch_var = tf.nn.moments(x, [0], keep_dims=False)
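Note on the gamma initializer change above: this appears to be the "latest TF" fix named in the commit message. In newer TensorFlow releases tf.ones_initializer was turned into a class that must be instantiated before being passed to tf.get_variable, whereas tf.constant_initializer(1.0) keeps working on both old and new APIs. A minimal sketch of the portable pattern, assuming a TF 1.x-style graph API (n_out here is an illustrative channel count, not taken from the diff):

import tensorflow as tf

n_out = 64  # illustrative; BatchNorm derives the channel count from the input shape

# Old form -- breaks once tf.ones_initializer becomes a class that must be called:
# gamma = tf.get_variable('gamma', [n_out], initializer=tf.ones_initializer)

# Form used by this commit: an explicit constant initializer, valid on both APIs.
gamma = tf.get_variable('gamma', [n_out],
                        initializer=tf.constant_initializer(1.0))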
tensorpack/models/batch_norm.py (continued)

@@ -59,39 +59,44 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     if use_local_stat:
         # training tower
-        with tf.name_scope(None):   # https://github.com/tensorflow/tensorflow/issues/2740
-            ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
-            ema_apply_op = ema.apply([batch_mean, batch_var])
-            ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
-            if ctx.is_main_training_tower:
-                # inside main training tower
-                add_model_variable(ema_mean)
-                add_model_variable(ema_var)
-    else:
-        if ctx.is_main_tower:
-            # not training, but main tower. need to create the vars
-            with tf.name_scope(None):
-                ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
-                ema_apply_op = ema.apply([batch_mean, batch_var])
-                ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
-        else:
-            # use statistics in another tower
-            G = tf.get_default_graph()
-            # figure out the var name
-            with tf.name_scope(None):
-                ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
-                mean_var_name = ema.average_name(batch_mean) + ':0'
-                var_var_name = ema.average_name(batch_var) + ':0'
-            ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name)
-            ema_var = ctx.find_tensor_in_main_tower(G, var_var_name)
-            #logger.info("In prediction, using {} instead of {} for {}".format(
-                #mean_name, ema_mean.name, batch_mean.name))
+        if ctx.is_training:
+            reuse = tf.get_variable_scope().reuse
+            with tf.name_scope(None):   # https://github.com/tensorflow/tensorflow/issues/2740
+                # TODO if reuse=True, try to find and use the existing statistics
+                # how to use multiple tensors to update one EMA? seems impossbile
+                ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
+                ema_apply_op = ema.apply([batch_mean, batch_var])
+                ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+                if ctx.is_main_training_tower:
+                    # inside main training tower
+                    add_model_variable(ema_mean)
+                    add_model_variable(ema_var)
+    else:
+        # no apply() is called here, no magic vars will get created
+        assert not ctx.is_training
+        with tf.name_scope(None):
+            ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
+            mean_var_name = ema.average_name(batch_mean)
+            var_var_name = ema.average_name(batch_var)
+        sc = tf.get_variable_scope()
+        if ctx.is_main_tower:
+            # main tower, but needs to use global stat. global stat must be from outside
+            # TODO when reuse=True, the variable name could actually be different
+            ema_mean = tf.get_variable('mean/' + emaname, [n_out])
+            ema_var = tf.get_variable('variance/' + emaname, [n_out])
+        else:
+            ## use statistics in another tower
+            G = tf.get_default_graph()
+            ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name)
+            ema_var = ctx.find_tensor_in_main_tower(G, var_var_name)

     if use_local_stat:
-        with tf.control_dependencies([ema_apply_op]):
-            batch = tf.cast(tf.shape(x)[0], tf.float32)
-            mul = tf.select(tf.equal(batch, 1.0), 1.0, batch / (batch - 1))
-            batch_var = batch_var * mul  # use unbiased variance estimator in training
+        batch = tf.cast(tf.shape(x)[0], tf.float32)
+        mul = tf.select(tf.equal(batch, 1.0), 1.0, batch / (batch - 1))
+        batch_var = batch_var * mul  # use unbiased variance estimator in training
+        with tf.control_dependencies([ema_apply_op] if ctx.is_training else []):  # only apply EMA op if is_training
             return tf.nn.batch_normalization(
                 x, batch_mean, batch_var, beta, gamma, epsilon, 'output')
     else:
     ...
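The larger hunk above restructures when the EMA update is built and applied: the ExponentialMovingAverage.apply op is now created only under ctx.is_training, and at normalization time it is attached as a control dependency only while training. A minimal, self-contained sketch of that control-dependency pattern (not tensorpack's actual API; the function and argument names below are illustrative):

import tensorflow as tf

def normalized_output(x, batch_mean, batch_var, beta, gamma, epsilon,
                      ema_apply_op, is_training):
    # Only force the moving-average update into the graph while training;
    # a pure prediction graph gets an empty dependency list instead.
    deps = [ema_apply_op] if is_training else []
    with tf.control_dependencies(deps):
        return tf.nn.batch_normalization(
            x, batch_mean, batch_var, beta, gamma, epsilon, name='output')

This mirrors the intent of the new `with tf.control_dependencies([ema_apply_op] if ctx.is_training else []):` line: an inference graph no longer depends on, or accidentally runs, the EMA update op.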
tensorpack/train/base.py

@@ -110,6 +110,7 @@ class Trainer(object):
         self.stat_holder = StatHolder(logger.LOG_DIR)
         logger.info("Initializing graph variables ...")
+        # TODO newsession + sessinit?
         self.sess.run(tf.initialize_all_variables())
         self.config.session_init.init(self.sess)
         ...