Commit e072d909 authored Jun 08, 2016 by Yuxin Wu
[WIP] reorganize trainer. fix batch_norm
parent 335d6c28

Showing 5 changed files with 204 additions and 137 deletions
examples/ResNet/cifar10-resnet.py     +7   -14
tensorpack/models/batch_norm.py       +25  -3
tensorpack/tfutils/common.py          +29  -2
tensorpack/train/base.py              +0   -2
tensorpack/train/trainer.py           +143 -116
examples/ResNet/cifar10-resnet.py
@@ -8,14 +8,9 @@ import tensorflow as tf
 import argparse
 import os
-from tensorpack.train import TrainConfig, QueueInputTrainer
-from tensorpack.models import *
-from tensorpack.callbacks import *
-from tensorpack.utils import *
-from tensorpack.tfutils import *
+from tensorpack import *
 from tensorpack.tfutils.symbolic_functions import *
 from tensorpack.tfutils.summary import *
 from tensorpack.dataflow import *

 """
 CIFAR10-resnet example.
...
@@ -186,11 +181,9 @@ if __name__ == '__main__':
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

     with tf.Graph().as_default():
         with tf.device('/cpu:0'):
             config = get_config()
             if args.load:
                 config.session_init = SaverRestore(args.load)
             if args.gpu:
                 config.nr_tower = len(args.gpu.split(','))
-            QueueInputTrainer(config).train()
+            SyncMultiGPUTrainer(config).train()
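Why the example switches trainers: after this commit QueueInputTrainer asserts config.nr_tower == 1 (see tensorpack/train/trainer.py below), while the example derives nr_tower from its --gpu argument, so the multi-tower path has to go through SyncMultiGPUTrainer. A minimal sketch of the resulting dispatch, reusing only names that appear in the diff and assuming the example's argparse setup and get_config():

config = get_config()
if args.gpu:
    # one training tower per visible GPU, e.g. '--gpu 0,1' gives 2 towers
    config.nr_tower = len(args.gpu.split(','))
if config.nr_tower == 1:
    QueueInputTrainer(config).train()     # single-tower path still works
else:
    SyncMultiGPUTrainer(config).train()   # synchronous multi-GPU path introduced here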
tensorpack/models/batch_norm.py
@@ -5,7 +5,9 @@
 import tensorflow as tf
+from copy import copy
+import re

 from ..utils import logger
 from ._common import layer_register

 __all__ = ['BatchNorm']
...
@@ -48,9 +50,28 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
     else:
         batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], keep_dims=False)

-    ema = tf.train.ExponentialMovingAverage(decay=decay)
-    ema_apply_op = ema.apply([batch_mean, batch_var])
-    ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+    emaname = 'EMA'
+    if not batch_mean.name.startswith('towerp'):
+        ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
+        ema_apply_op = ema.apply([batch_mean, batch_var])
+        ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+    else:
+        assert not use_local_stat
+        # have to do this again to get actual name. see issue:
+        # https://github.com/tensorflow/tensorflow/issues/2740
+        ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
+        ema_apply_op = ema.apply([batch_mean, batch_var])
+        ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+        mean_name = re.sub('towerp[0-9]+/', '', ema_mean.name)
+        var_name = re.sub('towerp[0-9]+/', '', ema_var.name)
+        #var_name = batch_var.op.name[prefixlen:] + '/' + emaname + ':0'
+        #logger.info("In prediction, using {} instead of {} for {}".format(
+            #mean_name, ema_mean.name, batch_mean.name))
+        G = tf.get_default_graph()
+        ema_mean = G.get_tensor_by_name(mean_name)
+        ema_var = G.get_tensor_by_name(var_name)

     if use_local_stat:
         with tf.control_dependencies([ema_apply_op]):
...
@@ -58,6 +79,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
             x, batch_mean, batch_var, beta, gamma, epsilon, 'bn')
     else:
         batch = tf.cast(tf.shape(x)[0], tf.float32)
+        # XXX TODO batch==1?
         mean, var = ema_mean, ema_var * batch / (batch - 1)  # unbiased variance estimator
         return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, 'bn')
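The core of the batch_norm fix is a name remapping: prediction towers are built under a 'towerp<k>' name scope, so any EMA tensors created there carry a 'towerpN/' prefix, while the moving statistics that should actually be used live under the training tower's unprefixed names. Stripping the prefix with re.sub recovers the shared name, which is then fetched with get_tensor_by_name. A small, runnable illustration of just that string step; the tensor name below is hypothetical:

import re

# Hypothetical EMA tensor name as it would appear inside prediction tower 0.
ema_mean_name = 'towerp0/conv1/bn/mean/EMA:0'
# Strip the prediction-tower prefix to get the name of the EMA tensor
# that the training tower already created.
shared_name = re.sub('towerp[0-9]+/', '', ema_mean_name)
print(shared_name)   # conv1/bn/mean/EMA:0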
tensorpack/tfutils/common.py
@@ -5,13 +5,19 @@
 from ..utils.naming import *
 import tensorflow as tf
+from copy import copy
+import six
+from contextlib import contextmanager

 __all__ = ['get_default_sess_config',
            'get_global_step', 'get_global_step_var',
            'get_op_var_name',
-           'get_vars_by_names']
+           'get_vars_by_names',
+           'backup_collection', 'restore_collection',
+           'clear_collection', 'freeze_collection']

 def get_default_sess_config(mem_fraction=0.9):
     """
...
@@ -66,3 +72,24 @@ def get_vars_by_names(names):
         opn, varn = get_op_var_name(n)
         ret.append(G.get_tensor_by_name(varn))
     return ret

+def backup_collection(keys):
+    ret = {}
+    for k in keys:
+        ret[k] = copy(tf.get_collection(k))
+    return ret
+
+def restore_collection(backup):
+    for k, v in six.iteritems(backup):
+        del tf.get_collection_ref(k)[:]
+        tf.get_collection_ref(k).extend(v)
+
+def clear_collection(keys):
+    for k in keys:
+        del tf.get_collection_ref(k)[:]
+
+@contextmanager
+def freeze_collection(keys):
+    backup = backup_collection(keys)
+    yield
+    restore_collection(backup)
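These helpers snapshot and restore TensorFlow graph collections. freeze_collection is the piece the reorganized trainer relies on: anything added to the listed collections inside the with-block is discarded on exit, because the block is entered after a backup and left with a restore. A usage sketch under the graph-collection API of the TensorFlow versions tensorpack targeted at the time (not valid under TF2 eager mode); build_extra_graph() is a hypothetical stand-in for code that would pollute the collections, such as building a prediction tower:

import tensorflow as tf
from tensorpack.tfutils.common import freeze_collection

with freeze_collection([tf.GraphKeys.SUMMARIES]):
    build_extra_graph()   # hypothetical; any summaries it adds are dropped on exit

# The summaries recorded before the block are intact again here.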
tensorpack/train/base.py
@@ -16,7 +16,6 @@ from ..utils.concurrency import start_proc_mask_signal
 from ..callbacks import StatHolder
 from ..tfutils import *
 from ..tfutils.summary import create_summary
-from ..tfutils.modelutils import describe_model

 __all__ = ['Trainer']
...
@@ -141,7 +140,6 @@ class Trainer(object):
         self.sess.close()

     def init_session_and_coord(self):
-        describe_model()
         self.sess = tf.Session(config=self.config.session_config)
         self.coord = tf.train.Coordinator()
...
tensorpack/train/trainer.py
@@ -5,7 +5,6 @@
 import tensorflow as tf
 import threading
 import time
-import copy
 import re
 import functools
 from six.moves import zip
...
@@ -15,6 +14,7 @@ from ..dataflow.common import RepeatedData
 from ..utils import *
 from ..utils.concurrency import LoopThread
 from ..tfutils.summary import summary_moving_average
+from ..tfutils.modelutils import describe_model
 from ..tfutils import *

 __all__ = ['SimpleTrainer', 'QueueInputTrainer',
...
@@ -42,6 +42,7 @@ class SimpleTrainer(Trainer):
             avg_maintain_op)

         self.init_session_and_coord()
+        describe_model()
         # create an infinte data producer
         self.data_producer = RepeatedData(self.config.dataset, -1).get_data()
         self.main_loop()
...
@@ -100,14 +101,11 @@ class EnqueueThread(threading.Thread):
         logger.info("Enqueue Thread Exited.")

 class QueueInputTrainer(Trainer):
-    """
-    Trainer which builds a FIFO queue for input.
-    Support multi GPU.
-    """
+    """ Single GPU Trainer, takes input from a queue"""

     SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]

-    def __init__(self, config, input_queue=None, async=False, predict_tower=None):
+    def __init__(self, config, input_queue=None, predict_tower=None):
         """
         :param config: a `TrainConfig` instance
         :param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
...
@@ -120,27 +118,11 @@ class QueueInputTrainer(Trainer):
                 100, [x.dtype for x in self.input_vars], name='input_queue')
         else:
             self.input_queue = input_queue
-        self.async = async
-        if self.async:
-            assert self.config.nr_tower > 1
-        self.dequed_inputs = []

         if predict_tower is None:
-            # by default, only use first training tower for prediction
+            # by default, use first training tower for prediction
             predict_tower = [0]
         self.predict_tower = predict_tower

-    @staticmethod
-    def _average_grads(tower_grads):
-        ret = []
-        for grad_and_vars in zip(*tower_grads):
-            v = grad_and_vars[0][1]
-            try:
-                grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
-            except AssertionError:
-                logger.error("Error while processing gradients of {}".format(v.name))
-                raise
-            ret.append((grad, v))
-        return ret
+        self.dequed_inputs = None

     def _get_model_inputs(self):
         """ Dequeue a datapoint from input_queue and return"""
...
@@ -150,42 +132,111 @@ class QueueInputTrainer(Trainer):
         assert len(ret) == len(self.input_vars)
         for qv, v in zip(ret, self.input_vars):
             qv.set_shape(v.get_shape())
-        self.dequed_inputs.append(ret)
         return ret

     def _build_predict_tower(self):
         inputs = self.model.get_input_vars()
         tf.get_variable_scope().reuse_variables()
         for k in self.predict_tower:
-            logger.info("Building graph for predict tower0{}...".format(k))
+            logger.info("Building graph for predict towerp{}...".format(k))
             with tf.device('/gpu:{}'.format(k)), \
-                    tf.name_scope('tower0{}'.format(k)):
+                    tf.name_scope('towerp{}'.format(k)):
                 self.model.build_graph(inputs, False)
                 tf.get_variable_scope().reuse_variables()

     def _single_tower_grad(self):
         """ Get grad and cost for single-tower case"""
-        model_inputs = self._get_model_inputs()
+        self.dequed_inputs = model_inputs = self._get_model_inputs()
         self.model.build_graph(model_inputs, True)
         cost_var = self.model.get_cost()
         grads = self.config.optimizer.compute_gradients(cost_var)
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
         return grads

+    def _build_enque_thread(self):
+        # create a thread that keeps filling the queue
+        enqueue_op = self.input_queue.enqueue(self.input_vars)
+        self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars)
+        self.extra_threads_procs.append(self.input_th)

     def train(self):
+        assert self.config.nr_tower == 1, "QueueInputTrainer only supports 1 tower!"
+        self.init_session_and_coord()
+        self._build_enque_thread()
+
         grads = self._single_tower_grad()
         grads = self.process_grads(grads)
+        describe_model()
+
+        with freeze_collection(self.SUMMARY_BACKUP_KEYS):
+            self._build_predict_tower()
+
         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             summary_moving_average())
         self.main_loop()

     def run_step(self):
         """ just run self.train_op"""
         self.sess.run([self.train_op])

     def _trigger_epoch(self):
         # need to run summary_op every epoch
         # note that summary_op will take a data from the queue
         if self.summary_op is not None:
             summary_str = self.summary_op.eval()
             self._process_summary(summary_str)

     def get_predict_func(self, input_names, output_names, tower=0):
         """
         :param tower: return the kth predict_func
         :returns: a predictor function
         """
         tower = self.predict_tower[tower % len(self.predict_tower)]
         raw_input_vars = get_vars_by_names(input_names)
         output_names = ['towerp{}/'.format(tower) + n for n in output_names]
         output_vars = get_vars_by_names(output_names)
         def func(inputs):
             assert len(inputs) == len(raw_input_vars)
             feed = dict(zip(raw_input_vars, inputs))
             return self.sess.run(output_vars, feed_dict=feed)
         return func

     def get_predict_funcs(self, input_names, output_names, n):
         """ return n predicts functions evenly on each predict_tower"""
         return [self.get_predict_func(input_names, output_names, k) for k in range(n)]

 class MultiGPUTrainer(QueueInputTrainer):
     """ Base class for multi-gpu training"""
     def __init__(self, config, input_queue=None, predict_tower=None):
         super(MultiGPUTrainer, self).__init__(config, input_queue, predict_tower)
         assert config.nr_tower > 1
         self.dequed_inputs = []

     @staticmethod
     def _average_grads(tower_grads):
         ret = []
         for grad_and_vars in zip(*tower_grads):
             v = grad_and_vars[0][1]
             try:
                 grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
             except AssertionError:
                 logger.error("Error while processing gradients of {}".format(v.name))
                 raise
             ret.append((grad, v))
         return ret

     def _multi_tower_grads(self):
         logger.info("Training a model of {} tower".format(self.config.nr_tower))

-        # to avoid repeated summary from each device
-        collect_dedup = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
-        kept_summaries = {}
-        for k in collect_dedup:
-            del tf.get_collection_ref(k)[:]

         grad_list = []
         for i in range(self.config.nr_tower):
             with tf.device('/gpu:{}'.format(i)), \
                     tf.name_scope('tower{}'.format(i)) as scope:
                 logger.info("Building graph for training tower {}...".format(i))
                 model_inputs = self._get_model_inputs()   # each tower dequeue from input queue
                 self.dequed_inputs.append(model_inputs)
                 self.model.build_graph(model_inputs, True)
                 cost_var = self.model.get_cost()   # build tower
...
@@ -196,23 +247,38 @@ class QueueInputTrainer(Trainer):
                 if i == 0:
                     tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
                 tf.get_variable_scope().reuse_variables()
-        for k in collect_dedup:
-            kept_summaries[k] = copy.copy(tf.get_collection(k))
-        for k in collect_dedup:
-            del tf.get_collection_ref(k)[:]
-            tf.get_collection_ref(k).extend(kept_summaries[k])
+        # avoid repeated summary from each device
+        backup = backup_collection(self.SUMMARY_BACKUP_KEYS)
+        restore_collection(backup)
         return grad_list

 class SyncMultiGPUTrainer(MultiGPUTrainer):
     def train(self):
-        enqueue_op = self.input_queue.enqueue(self.input_vars)
         self.init_session_and_coord()
+        self._build_enque_thread()
+        self._build_predict_tower()

-        if self.config.nr_tower > 1:
-            grad_list = self._multi_tower_grads()
-            if not self.async:
-                grads = QueueInputTrainer._average_grads(grad_list)
+        grad_list = self._multi_tower_grads()
+        grads = MultiGPUTrainer._average_grads(grad_list)
         grads = self.process_grads(grads)
-        else:

         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             summary_moving_average())
         describe_model()
-        self._build_predict_tower()

         # [debug]: do nothing in training
         #self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
         self.main_loop()

 class AsyncMultiGPUTrainer(MultiGPUTrainer):
     def train(self):
         self.init_session_and_coord()
         self._build_enque_thread()
         grad_list = self._multi_tower_grads()

         # pretend to average the grads, in order to make async and
         # sync have consistent effective learning rate
         def scale(grads):
...
@@ -220,15 +286,12 @@ class QueueInputTrainer(Trainer):
         grad_list = map(scale, grad_list)
         grad_list = [self.process_grads(g) for g in grad_list]
         grads = grad_list[0]  # use grad from the first tower for the main iteration
-        else:
-            grads = self._single_tower_grad()
-            grads = self.process_grads(grads)

         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             summary_moving_average())
         describe_model()
-        if self.async:
         # prepare train_op for the rest of the towers
         self.threads = []
         for k in range(1, self.config.nr_tower):
...
@@ -240,17 +303,13 @@ class QueueInputTrainer(Trainer):
             self.threads.append(th)
         self.async_running = False

-        self._build_predict_tower()
-        self.init_session_and_coord()
-
-        # create a thread that keeps filling the queue
-        self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars)
-        self.extra_threads_procs.append(self.input_th)
-        # do nothing in training
+        # [debug]: do nothing in training
         #self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
         self.main_loop()

     def run_step(self):
-        if self.async:
         if not self.async_running:
             self.async_running = True
             for th in self.threads:
                 # resume all threads
...
@@ -258,41 +317,9 @@ class QueueInputTrainer(Trainer):
         self.sess.run([self.train_op])   # faster since train_op return None

     def _trigger_epoch(self):
         # note that summary_op will take a data from the queue
-        if self.async:
-            self.async_running = False
-            for th in self.threads:
-                th.pause()
+        self.async_running = False
+        for th in self.threads:
+            th.pause()
         if self.summary_op is not None:
             summary_str = self.summary_op.eval()
             self._process_summary(summary_str)

-    def get_predict_func(self, input_names, output_names, tower=0):
-        """
-        :param tower: return the kth predict_func
-        """
-        tower = self.predict_tower[tower % len(self.predict_tower)]
-        if self.config.nr_tower > 1:
-            logger.info("Prepare a predictor function for tower0{} ...".format(tower))
-        raw_input_vars = get_vars_by_names(input_names)
-        if self.config.nr_tower > 1:
-            output_names = ['tower0{}/'.format(tower) + n for n in output_names]
-        output_vars = get_vars_by_names(output_names)
-        def func(inputs):
-            assert len(inputs) == len(raw_input_vars)
-            feed = dict(zip(raw_input_vars, inputs))
-            return self.sess.run(output_vars, feed_dict=feed)
-        return func
-
-    def get_predict_funcs(self, input_names, output_names, n):
-        return [self.get_predict_func(input_names, output_names, k) for k in range(n)]
-
-def AsyncMultiGPUTrainer(config):
-    return QueueInputTrainer(config, async=True)
-
-def SyncMultiGPUTrainer(config):
-    return QueueInputTrainer(config)
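For reference, the relocated MultiGPUTrainer._average_grads simply arithmetic-means the k-th gradient of every tower for each shared variable. A pure-Python rendering of that behavior, with tf.add_n replaced by sum and variables stood in by strings, just to make the shape of the data concrete:

def average_grads(tower_grads):
    ret = []
    for grad_and_vars in zip(*tower_grads):    # k-th (grad, var) pair from every tower
        v = grad_and_vars[0][1]                # the variable is shared across towers
        grad = sum(g for g, _ in grad_and_vars) / float(len(tower_grads))
        ret.append((grad, v))
    return ret

tower0 = [(1.0, 'W'), (0.5, 'b')]
tower1 = [(3.0, 'W'), (1.5, 'b')]
print(average_grads([tower0, tower1]))   # [(2.0, 'W'), (1.0, 'b')]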