Shashank Suhas / seminar-breakout · Commits

Commit c723c5a4
authored Jun 09, 2017 by Yuxin Wu

improve AsyncMultiGPUTrainer

parent 5cfbff39

Showing 1 changed file with 22 additions and 67 deletions.

tensorpack/train/multigpu.py (+22, -67)
@@ -4,14 +4,11 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 
 import tensorflow as tf
-import itertools
 import operator
-import re
 from six.moves import zip, range
 
 from ..utils import logger
 from ..utils.naming import TOWER_FREEZE_KEYS
-from ..utils.concurrency import LoopThread
 from ..tfutils.common import get_tf_version_number
 from ..tfutils.tower import TowerContext
 from ..tfutils.collection import backup_collection, restore_collection
@@ -152,7 +149,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfree
     def __init__(self, config, ps_device='gpu'):
         """
         Args:
-            config: same as in :class:`QueueInputTrainer`.
+            config(TrainConfig):
             ps_device: either 'gpu' or 'cpu', where variables are stored.
         """
         apply_prefetch_policy(config)
@@ -293,85 +290,43 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
         return tf.group(*post_init_ops, name='sync_variables_from_tower0')
 
 
 class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
                            SingleCostFeedfreeTrainer):
     """
     A multi-tower multi-GPU trainer where each tower independently
-    asynchronously updates the model without locking.
+    asynchronously updates the model without averaging the gradient.
     """
 
     def __init__(self, config, scale_gradient=True):
         """
         Args:
-            config: same as in :class:`QueueInputTrainer`.
-            scale_gradient (bool): if True, will scale each gradient by
-                ``1.0/nr_tower``, to make Async and Sync Trainer have the same
-                effective learning rate.
+            config(TrainConfig):
+            scale_gradient (bool): if True, will scale each gradient by ``1.0/nr_gpu``.
         """
-        apply_prefetch_policy(config, use_stage=False)
-        logger.warn("Async training hasn't been well optimized. Sync training is even faster")
+        apply_prefetch_policy(config)
         self._input_source = config.data
-        super(AsyncMultiGPUTrainer, self).__init__(config)
 
         self._scale_gradient = scale_gradient
-
-        if len(config.tower) > 1:
-            assert tf.test.is_gpu_available()
+        super(AsyncMultiGPUTrainer, self).__init__(config)
 
     def _setup(self):
        super(AsyncMultiGPUTrainer, self)._setup()
+        raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
+        devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
-            self.config.tower, lambda: self._get_cost_and_grad()[1])
+            self.config.tower, lambda: self._get_cost_and_grad()[1], devices)
         grad_list = [FilterNoneGrad().process(gv) for gv in grad_list]
         if self._scale_gradient and self.config.nr_tower > 1:
             # pretend to average the grads, in order to make async and
             # sync have consistent effective learning rate
-            gradproc = ScaleGradient(('.*', 1.0 / self.config.nr_tower), log=False)
+            gradproc = ScaleGradient(('.*', 1.0 / self.config.nr_tower), verbose=False)
             grad_list = [gradproc.process(gv) for gv in grad_list]
         # Ngpu x Nvar x 2
 
-        # use grad from the first tower for iteration in main thread
-        self._opt = self.model.get_optimizer()
-        self.train_op = self._opt.apply_gradients(grad_list[0], name='min_op')
-        self._start_async_threads(grad_list)
-
-    def _start_async_threads(self, grad_list):
-        # prepare train_op for the rest of the towers
-        # itertools.count is atomic w.r.t. python threads
-        self.async_step_counter = itertools.count()
-        self.training_threads = []
-        for k in range(1, self.config.nr_tower):
-            train_op = self._opt.apply_gradients(grad_list[k])
-
-            def f(op=train_op):  # avoid late-binding
-                self.sess.run([op])  # TODO this won't work with StageInput
-                next(self.async_step_counter)  # atomic due to GIL
-            th = LoopThread(f)
-            th.name = "AsyncLoopThread-{}".format(k)
-            th.pause()
-            th.start()
-            self.training_threads.append(th)
-        self.async_running = False
-
-    def run_step(self):
-        if not self.async_running:
-            self.async_running = True
-            for th in self.training_threads:  # resume all threads
-                th.resume()
-        next(self.async_step_counter)
-        return super(AsyncMultiGPUTrainer, self).run_step()
-
-    def _trigger_epoch(self):
-        self.async_running = False
-        for th in self.training_threads:
-            th.pause()
-        try:
-            if self.config.nr_tower > 1:
-                async_step_total_cnt = int(re.findall(
-                    '[0-9]+', self.async_step_counter.__str__())[0])
-                self.monitors.put(
-                    'async_global_step', async_step_total_cnt)
-        except:
-            logger.exception("Cannot log async_global_step")
-        super(AsyncMultiGPUTrainer, self)._trigger_epoch()
+        train_ops = []
+        opt = self.model.get_optimizer()
+        for i in range(self.config.nr_tower):
+            with tf.device(raw_devices[i]):
+                grad_and_vars = grad_list[i]
+                train_ops.append(opt.apply_gradients(
+                    grad_and_vars, name='apply_grad_{}'.format(i)))
+        self.train_op = tf.group(*train_ops, name='train_op')
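
The net effect of the commit is to drop the LoopThread-based background updates and instead build one train_op that applies each tower's unaveraged gradients on that tower's GPU. Below is a minimal, self-contained sketch of that update scheme; it is not tensorpack code, and the toy model, device list, and optimizer are illustrative assumptions, but the structure (one apply_gradients per tower, grouped with tf.group) mirrors the added code.

# Sketch only (assumed toy setup, not tensorpack): per-tower apply_gradients, no averaging.
import tensorflow as tf

NR_TOWER = 2                                              # assumed number of GPUs
raw_devices = ['/gpu:{}'.format(k) for k in range(NR_TOWER)]

x = tf.placeholder(tf.float32, [None, 1], name='x')
y = tf.placeholder(tf.float32, [None, 1], name='y')
w = tf.get_variable('w', [1, 1])                          # shared variable updated by every tower

opt = tf.train.GradientDescentOptimizer(0.1)
train_ops = []
for i, dev in enumerate(raw_devices):
    with tf.device(dev):
        # in tensorpack each tower reads its own input; one placeholder is reused here for brevity
        cost = tf.reduce_mean(tf.squared_difference(tf.matmul(x, w), y))
        grads_and_vars = opt.compute_gradients(cost)
        train_ops.append(opt.apply_gradients(
            grads_and_vars, name='apply_grad_{}'.format(i)))

# one session.run of this op performs every tower's update; gradients are never averaged
train_op = tf.group(*train_ops, name='train_op')

Running the sketch requires the listed GPUs to exist, or a session created with allow_soft_placement=True so the ops can fall back to available devices.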