Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e072d909
Commit
e072d909
authored
Jun 08, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[WIP] reorganize trainer. fix batch_norm
parent
335d6c28
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
204 additions
and
137 deletions
+204
-137
examples/ResNet/cifar10-resnet.py
examples/ResNet/cifar10-resnet.py
+7
-14
tensorpack/models/batch_norm.py
tensorpack/models/batch_norm.py
+25
-3
tensorpack/tfutils/common.py
tensorpack/tfutils/common.py
+29
-2
tensorpack/train/base.py
tensorpack/train/base.py
+0
-2
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+143
-116
No files found.
examples/ResNet/cifar10-resnet.py
View file @
e072d909
...
@@ -8,14 +8,9 @@ import tensorflow as tf
...
@@ -8,14 +8,9 @@ import tensorflow as tf
import
argparse
import
argparse
import
os
import
os
from
tensorpack.train
import
TrainConfig
,
QueueInputTrainer
from
tensorpack
import
*
from
tensorpack.models
import
*
from
tensorpack.callbacks
import
*
from
tensorpack.utils
import
*
from
tensorpack.tfutils
import
*
from
tensorpack.tfutils.symbolic_functions
import
*
from
tensorpack.tfutils.symbolic_functions
import
*
from
tensorpack.tfutils.summary
import
*
from
tensorpack.tfutils.summary
import
*
from
tensorpack.dataflow
import
*
"""
"""
CIFAR10-resnet example.
CIFAR10-resnet example.
...
@@ -186,11 +181,9 @@ if __name__ == '__main__':
...
@@ -186,11 +181,9 @@ if __name__ == '__main__':
if
args
.
gpu
:
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
with
tf
.
Graph
()
.
as_default
():
config
=
get_config
()
with
tf
.
device
(
'/cpu:0'
):
if
args
.
load
:
config
=
get_config
()
config
.
session_init
=
SaverRestore
(
args
.
load
)
if
args
.
load
:
if
args
.
gpu
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
config
.
nr_tower
=
len
(
args
.
gpu
.
split
(
','
))
if
args
.
gpu
:
SyncMultiGPUTrainer
(
config
)
.
train
()
config
.
nr_tower
=
len
(
args
.
gpu
.
split
(
','
))
QueueInputTrainer
(
config
)
.
train
()
tensorpack/models/batch_norm.py
View file @
e072d909
...
@@ -5,7 +5,9 @@
...
@@ -5,7 +5,9 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
copy
import
copy
from
copy
import
copy
import
re
from
..utils
import
logger
from
._common
import
layer_register
from
._common
import
layer_register
__all__
=
[
'BatchNorm'
]
__all__
=
[
'BatchNorm'
]
...
@@ -48,9 +50,28 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
...
@@ -48,9 +50,28 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
else
:
else
:
batch_mean
,
batch_var
=
tf
.
nn
.
moments
(
x
,
[
0
,
1
,
2
],
keep_dims
=
False
)
batch_mean
,
batch_var
=
tf
.
nn
.
moments
(
x
,
[
0
,
1
,
2
],
keep_dims
=
False
)
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
decay
)
emaname
=
'EMA'
ema_apply_op
=
ema
.
apply
([
batch_mean
,
batch_var
])
if
not
batch_mean
.
name
.
startswith
(
'towerp'
):
ema_mean
,
ema_var
=
ema
.
average
(
batch_mean
),
ema
.
average
(
batch_var
)
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
decay
,
name
=
emaname
)
ema_apply_op
=
ema
.
apply
([
batch_mean
,
batch_var
])
ema_mean
,
ema_var
=
ema
.
average
(
batch_mean
),
ema
.
average
(
batch_var
)
else
:
assert
not
use_local_stat
# have to do this again to get actual name. see issue:
# https://github.com/tensorflow/tensorflow/issues/2740
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
decay
,
name
=
emaname
)
ema_apply_op
=
ema
.
apply
([
batch_mean
,
batch_var
])
ema_mean
,
ema_var
=
ema
.
average
(
batch_mean
),
ema
.
average
(
batch_var
)
mean_name
=
re
.
sub
(
'towerp[0-9]+/'
,
''
,
ema_mean
.
name
)
var_name
=
re
.
sub
(
'towerp[0-9]+/'
,
''
,
ema_var
.
name
)
#var_name = batch_var.op.name[prefixlen:] + '/' + emaname + ':0'
#logger.info("In prediction, using {} instead of {} for {}".format(
#mean_name, ema_mean.name, batch_mean.name))
G
=
tf
.
get_default_graph
()
ema_mean
=
G
.
get_tensor_by_name
(
mean_name
)
ema_var
=
G
.
get_tensor_by_name
(
var_name
)
if
use_local_stat
:
if
use_local_stat
:
with
tf
.
control_dependencies
([
ema_apply_op
]):
with
tf
.
control_dependencies
([
ema_apply_op
]):
...
@@ -58,6 +79,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
...
@@ -58,6 +79,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
x
,
batch_mean
,
batch_var
,
beta
,
gamma
,
epsilon
,
'bn'
)
x
,
batch_mean
,
batch_var
,
beta
,
gamma
,
epsilon
,
'bn'
)
else
:
else
:
batch
=
tf
.
cast
(
tf
.
shape
(
x
)[
0
],
tf
.
float32
)
batch
=
tf
.
cast
(
tf
.
shape
(
x
)[
0
],
tf
.
float32
)
# XXX TODO batch==1?
mean
,
var
=
ema_mean
,
ema_var
*
batch
/
(
batch
-
1
)
# unbiased variance estimator
mean
,
var
=
ema_mean
,
ema_var
*
batch
/
(
batch
-
1
)
# unbiased variance estimator
return
tf
.
nn
.
batch_normalization
(
return
tf
.
nn
.
batch_normalization
(
x
,
mean
,
var
,
beta
,
gamma
,
epsilon
,
'bn'
)
x
,
mean
,
var
,
beta
,
gamma
,
epsilon
,
'bn'
)
tensorpack/tfutils/common.py
View file @
e072d909
...
@@ -5,13 +5,19 @@
...
@@ -5,13 +5,19 @@
from
..utils.naming
import
*
from
..utils.naming
import
*
import
tensorflow
as
tf
import
tensorflow
as
tf
from
copy
import
copy
import
six
from
contextlib
import
contextmanager
__all__
=
[
'get_default_sess_config'
,
__all__
=
[
'get_default_sess_config'
,
'get_global_step'
,
'get_global_step'
,
'get_global_step_var'
,
'get_global_step_var'
,
'get_op_var_name'
,
'get_op_var_name'
,
'get_vars_by_names'
'get_vars_by_names'
,
]
'backup_collection'
,
'restore_collection'
,
'clear_collection'
,
'freeze_collection'
]
def
get_default_sess_config
(
mem_fraction
=
0.9
):
def
get_default_sess_config
(
mem_fraction
=
0.9
):
"""
"""
...
@@ -66,3 +72,24 @@ def get_vars_by_names(names):
...
@@ -66,3 +72,24 @@ def get_vars_by_names(names):
opn
,
varn
=
get_op_var_name
(
n
)
opn
,
varn
=
get_op_var_name
(
n
)
ret
.
append
(
G
.
get_tensor_by_name
(
varn
))
ret
.
append
(
G
.
get_tensor_by_name
(
varn
))
return
ret
return
ret
def
backup_collection
(
keys
):
ret
=
{}
for
k
in
keys
:
ret
[
k
]
=
copy
(
tf
.
get_collection
(
k
))
return
ret
def
restore_collection
(
backup
):
for
k
,
v
in
six
.
iteritems
(
backup
):
del
tf
.
get_collection_ref
(
k
)[:]
tf
.
get_collection_ref
(
k
)
.
extend
(
v
)
def
clear_collection
(
keys
):
for
k
in
keys
:
del
tf
.
get_collection_ref
(
k
)[:]
@
contextmanager
def
freeze_collection
(
keys
):
backup
=
backup_collection
(
keys
)
yield
restore_collection
(
backup
)
tensorpack/train/base.py
View file @
e072d909
...
@@ -16,7 +16,6 @@ from ..utils.concurrency import start_proc_mask_signal
...
@@ -16,7 +16,6 @@ from ..utils.concurrency import start_proc_mask_signal
from
..callbacks
import
StatHolder
from
..callbacks
import
StatHolder
from
..tfutils
import
*
from
..tfutils
import
*
from
..tfutils.summary
import
create_summary
from
..tfutils.summary
import
create_summary
from
..tfutils.modelutils
import
describe_model
__all__
=
[
'Trainer'
]
__all__
=
[
'Trainer'
]
...
@@ -141,7 +140,6 @@ class Trainer(object):
...
@@ -141,7 +140,6 @@ class Trainer(object):
self
.
sess
.
close
()
self
.
sess
.
close
()
def
init_session_and_coord
(
self
):
def
init_session_and_coord
(
self
):
describe_model
()
self
.
sess
=
tf
.
Session
(
config
=
self
.
config
.
session_config
)
self
.
sess
=
tf
.
Session
(
config
=
self
.
config
.
session_config
)
self
.
coord
=
tf
.
train
.
Coordinator
()
self
.
coord
=
tf
.
train
.
Coordinator
()
...
...
tensorpack/train/trainer.py
View file @
e072d909
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment