seminar-breakout: Commit e2cc7058
authored Jun 01, 2017 by Yuxin Wu
remove deprecated impl in BN. catch more exception in saver
parent 41122718
Showing 2 changed files with 4 additions and 82 deletions (+4, -82)
tensorpack/callbacks/saver.py    +3 -2
tensorpack/models/batch_norm.py  +1 -80
tensorpack/callbacks/saver.py
@@ -73,8 +73,9 @@ class ModelSaver(Callback):
                 global_step=tf.train.get_global_step(),
                 write_meta_graph=False)
             logger.info("Model saved to %s." % tf.train.get_checkpoint_state(self.checkpoint_dir).model_checkpoint_path)
-        except (OSError, IOError):   # disk error sometimes.. just ignore it
-            logger.exception("Exception in ModelSaver.trigger_epoch!")
+        except (OSError, IOError, tf.errors.PermissionDeniedError,
+                tf.errors.ResourceExhaustedError):   # disk error sometimes.. just ignore it
+            logger.exception("Exception in ModelSaver!")
 
 
 class MinSaver(Callback):
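The saver change widens the caught exceptions: TensorFlow surfaces filesystem failures from inside its ops as its own tf.errors classes rather than plain OSError/IOError, so a full disk or an unwritable directory during checkpointing would previously escape the handler and kill training. A rough, hedged sketch of the pattern (standalone names like save_checkpoint_best_effort are illustrative, not tensorpack's actual class):

import logging

import tensorflow as tf

logger = logging.getLogger(__name__)

def save_checkpoint_best_effort(saver, sess, path):
    """Save a checkpoint, treating disk-level failures as non-fatal."""
    try:
        saver.save(sess, path,
                   global_step=tf.train.get_global_step(),
                   write_meta_graph=False)
    except (OSError, IOError,
            tf.errors.PermissionDeniedError,    # e.g. unwritable directory
            tf.errors.ResourceExhaustedError):  # e.g. disk full
        # Log the full traceback but keep training alive; the next
        # trigger will simply try to save again.
        logger.exception("Exception while saving checkpoint, ignored.")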
tensorpack/models/batch_norm.py
@@ -17,85 +17,6 @@ __all__ = ['BatchNorm', 'BatchRenorm']
 # eps: torch: 1e-5. Lasagne: 1e-4
 
-# XXX This is deprecated. Only kept for future reference.
-@layer_register(log_shape=False)
-def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
-    shape = x.get_shape().as_list()
-    assert len(shape) in [2, 4]
-
-    n_out = shape[-1]  # channel
-    assert n_out is not None
-    beta = tf.get_variable('beta', [n_out],
-                           initializer=tf.constant_initializer())
-    gamma = tf.get_variable('gamma', [n_out],
-                            initializer=tf.constant_initializer(1.0))
-
-    if len(shape) == 2:
-        batch_mean, batch_var = tf.nn.moments(x, [0], keep_dims=False)
-    else:
-        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], keep_dims=False)
-    # just to make a clear name.
-    batch_mean = tf.identity(batch_mean, 'mean')
-    batch_var = tf.identity(batch_var, 'variance')
-
-    emaname = 'EMA'
-    ctx = get_current_tower_context()
-    if use_local_stat is None:
-        use_local_stat = ctx.is_training
-    if use_local_stat != ctx.is_training:
-        logger.warn("[BatchNorm] use_local_stat != is_training")
-
-    if use_local_stat:
-        # training tower
-        if ctx.is_training:
-            # reuse = tf.get_variable_scope().reuse
-            with tf.variable_scope(tf.get_variable_scope(), reuse=False):
-                # BatchNorm in reuse scope can be tricky! Moving mean/variance are not reused
-                with tf.name_scope(None):   # https://github.com/tensorflow/tensorflow/issues/2740
-                    # if reuse=True, try to find and use the existing statistics
-                    # how to use multiple tensors to update one EMA? seems impossbile
-                    ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
-                    ema_apply_op = ema.apply([batch_mean, batch_var])
-                    ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
-                    if ctx.is_main_training_tower:
-                        # inside main training tower
-                        add_model_variable(ema_mean)
-                        add_model_variable(ema_var)
-    else:
-        # no apply() is called here, no magic vars will get created,
-        # no reuse issue will happen
-        assert not ctx.is_training
-        with tf.name_scope(None):
-            ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
-            mean_var_name = ema.average_name(batch_mean)
-            var_var_name = ema.average_name(batch_var)
-            if ctx.is_main_tower:
-                # main tower, but needs to use global stat. global stat must be from outside
-                # when reuse=True, the desired variable name could
-                # actually be different, because a different var is created
-                # for different reuse tower
-                ema_mean = tf.get_variable('mean/' + emaname, [n_out])
-                ema_var = tf.get_variable('variance/' + emaname, [n_out])
-            else:
-                # use statistics in another tower
-                G = tf.get_default_graph()
-                ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name + ':0')
-                ema_var = ctx.find_tensor_in_main_tower(G, var_var_name + ':0')
-
-    if use_local_stat:
-        batch = tf.cast(tf.shape(x)[0], tf.float32)
-        mul = tf.where(tf.equal(batch, 1.0), 1.0, batch / (batch - 1))
-        batch_var = batch_var * mul  # use unbiased variance estimator in training
-
-        with tf.control_dependencies([ema_apply_op] if ctx.is_training else []):
-            # only apply EMA op if is_training
-            return tf.nn.batch_normalization(
-                x, batch_mean, batch_var, beta, gamma, epsilon, 'output')
-    else:
-        return tf.nn.batch_normalization(
-            x, ema_mean, ema_var, beta, gamma, epsilon, 'output')
-
-
 def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
     if use_bias:
         beta = tf.get_variable('beta', [n_out],
                                initializer=tf.constant_initializer())
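The deleted BatchNormV1 hand-rolled the classic EMA scheme for batch normalization: normalize with per-batch moments during training (scaling the variance by N/(N-1) for an unbiased estimate), while an ExponentialMovingAverage of those moments is maintained for use at inference. A minimal sketch of that core idea, stripped of tensorpack's tower and reuse logic (simple_batch_norm and the is_training flag are illustrative names, not tensorpack API):

import tensorflow as tf

def simple_batch_norm(x, is_training, decay=0.9, epsilon=1e-5):
    # NHWC layout assumed; moments are taken over all axes but channels.
    n_out = x.get_shape().as_list()[-1]
    beta = tf.get_variable('beta', [n_out],
                           initializer=tf.constant_initializer())
    gamma = tf.get_variable('gamma', [n_out],
                            initializer=tf.constant_initializer(1.0))
    batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])

    # Shadow variables tracking an exponential moving average of the moments.
    ema = tf.train.ExponentialMovingAverage(decay=decay)
    ema_apply_op = ema.apply([batch_mean, batch_var])

    if is_training:
        # Refresh the EMA as a side effect of computing the output,
        # then normalize with the current batch statistics.
        with tf.control_dependencies([ema_apply_op]):
            return tf.nn.batch_normalization(
                x, batch_mean, batch_var, beta, gamma, epsilon)
    # At inference, normalize with the accumulated moving statistics.
    return tf.nn.batch_normalization(
        x, ema.average(batch_mean), ema.average(batch_var),
        beta, gamma, epsilon)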
@@ -310,7 +231,7 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
         xn = tf.nn.batch_normalization(
             x, moving_mean, moving_var, beta, gamma, epsilon)
 
-    if ctx.is_main_training_tower:
+    if ctx.is_main_training_tower or ctx.has_own_variables:
         ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
     else:
         ret = tf.identity(xn, name='output')
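The one-line BatchRenorm change relaxes the EMA update condition: previously only the main training tower refreshed the moving statistics, while with `or ctx.has_own_variables` a tower that holds its own copy of the variables (as in replicated multi-GPU training) updates its own EMAs too. The update_bn_ema helper it calls follows the usual pattern of folding moving-average assignments into the output's control dependencies; roughly like the sketch below (a hedged approximation whose details may differ from tensorpack's real helper):

import tensorflow as tf
from tensorflow.python.training.moving_averages import assign_moving_average

def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
    # Assign ops that blend the batch moments into the moving statistics:
    #   moving <- decay * moving + (1 - decay) * batch
    update_op1 = assign_moving_average(
        moving_mean, batch_mean, decay, zero_debias=False, name='mean_ema_op')
    update_op2 = assign_moving_average(
        moving_var, batch_var, decay, zero_debias=False, name='var_ema_op')
    # Anything that consumes the layer output also triggers the EMA updates.
    with tf.control_dependencies([update_op1, update_op2]):
        return tf.identity(xn, name='output')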