Project: seminar-breakout (Shashank Suhas)

Commit 5ad33556: "Hierarchical AllReduce"
Authored Mar 14, 2018 by Yuxin Wu
Parent: d4a432ad

Showing 6 changed files with 165 additions and 55 deletions (+165 / -55)
    docs/conf.py                            +14   -8
    tensorpack/graph_builder/training.py     +8   -2
    tensorpack/graph_builder/utils.py      +127  -37
    tensorpack/tfutils/scope_utils.py        +8   -5
    tensorpack/train/tower.py                +1   -1
    tensorpack/train/trainers.py             +7   -2
docs/conf.py
...
@@ -364,23 +364,29 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
         return False
     # hide deprecated stuff
     if name in [
         'get_predictors',
         # deprecated stuff:
         'GaussianDeform',
-        'set_tower_func',
-        'TryResumeTraining',
         # renamed stuff:
         'dump_chkpt_vars',
         'DumpTensor',
         'DumpParamAsImage',
         'StagingInputWrapper',
+        'set_tower_func',
+        'TryResumeTraining',
-        'LeakyReLU',
-        'PrefetchOnGPUs',
         'PeriodicRunHooks',
-        'apply_default_prefetch',
-        'average_grads',
-        'Deconv2D',
+
+        # deprecated or renamed symbolic code
+        'Deconv2D',
+        'LeakyReLU',
         'saliency_map',
         'get_scalar_var',
         'psnr',
         'prediction_incorrect',
         'huber_loss',
         'SoftMax'
+
+        # internal only
+        'apply_default_prefetch',
+        'average_grads',
+        'aggregate_grads',
+        'allreduce_grads',
+        'PrefetchOnGPUs',
     ]:
         return True
     if name in ['get_data', 'size', 'reset_state']:
...
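For context, autodoc_skip_member is a handler for Sphinx's autodoc-skip-member event: returning True hides a name from the generated API docs. A minimal sketch of how such a handler is registered (illustrative only; the skip list here is shortened and the setup() wiring is assumed, not shown in this diff):

    # Minimal sketch of a Sphinx autodoc-skip-member hook (illustrative;
    # the real conf.py keeps the full skip list shown in the diff above).
    def autodoc_skip_member(app, what, name, obj, skip, options):
        if name in ['allreduce_grads', 'aggregate_grads', 'average_grads']:
            return True        # hide these internal-only names from the docs
        return None            # fall back to Sphinx's default decision


    def setup(app):
        # register the handler for the autodoc-skip-member event
        app.connect('autodoc-skip-member', autodoc_skip_member)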
tensorpack/graph_builder/training.py
...
@@ -16,7 +16,8 @@ from ..tfutils.gradproc import ScaleGradient
 from .utils import (
     LeastLoadedDeviceSetter, override_to_local_variable,
-    allreduce_grads, aggregate_grads)
+    allreduce_grads, aggregate_grads,
+    allreduce_hierarchical, split_grad_list, merge_grad_list)

 __all__ = ['GraphBuilder',
...
@@ -213,7 +214,12 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
         DataParallelBuilder._check_grad_list(grad_list)

         if self._mode == 'nccl':
-            self.grads = allreduce_grads(grad_list, average=self._average)  # #gpu x #param x 2
+            all_grads, all_vars = split_grad_list(grad_list)
+            if True:
+                all_grads = allreduce_grads(all_grads, average=self._average)  # #gpu x #param x 2
+            else:
+                all_grads = allreduce_hierarchical(all_grads, raw_devices, average=self._average)
+            self.grads = merge_grad_list(all_grads, all_vars)
         elif self._mode == 'cpu':
             agg_grad_and_vars = aggregate_grads(
                 grad_list, colocation=False,
...
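The `if True:` above hard-wires the flat NCCL path, so allreduce_hierarchical is wired in but not yet reachable. A hedged sketch of how the choice could be parameterized (the `hierarchical` flag and the helper below are hypothetical, not part of this commit; `raw_devices` is the builder's list of '/gpu:k' device strings):

    # Hypothetical helper replacing the hard-coded `if True:` above.
    # Assumes this commit's tree, where both functions live in
    # tensorpack/graph_builder/utils.py.
    from tensorpack.graph_builder.utils import allreduce_grads, allreduce_hierarchical


    def _reduce_all_grads(all_grads, raw_devices, average, hierarchical=False):
        if not hierarchical:
            # flat NCCL all-reduce across all K GPUs
            return allreduce_grads(all_grads, average=average)
        # two-level scheme for an 8-GPU (DGX-1 style) box
        return allreduce_hierarchical(all_grads, raw_devices, average=average)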
tensorpack/graph_builder/utils.py
...
@@ -8,6 +8,7 @@ import operator
 import tensorflow as tf

 from ..tfutils.varreplace import custom_getter_scope
+from ..tfutils.scope_utils import under_name_scope

 __all__ = ['LeastLoadedDeviceSetter',
...
@@ -15,7 +16,8 @@ __all__ = ['LeastLoadedDeviceSetter',
            'override_to_local_variable',
            'allreduce_grads',
            'average_grads',
-           'aggregate_grads']
+           'aggregate_grads'
+           ]

 """
...
@@ -83,43 +85,132 @@ class LeastLoadedDeviceSetter(object):
         return "LeastLoadedDeviceSetter-{}".format(self.worker_device)


+def split_grad_list(grad_list):
+    """
+    Args:
+        grad_list: K x N x 2
+
+    Returns:
+        K x N: gradients
+        K x N: variables
+    """
+    g = []
+    v = []
+    for tower in grad_list:
+        g.append([x[0] for x in tower])
+        v.append([x[1] for x in tower])
+    return g, v
+
+
+def merge_grad_list(all_grads, all_vars):
+    """
+    Args:
+        all_grads (K x N): gradients
+        all_vars(K x N): variables
+
+    Return:
+        K x N x 2: list of list of (grad, var) pairs
+    """
+    return [list(zip(gs, vs)) for gs, vs in zip(all_grads, all_vars)]
+
+
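A quick plain-Python illustration of the two helpers' data layout (strings stand in for tensors and variables, no TensorFlow needed), showing that merge_grad_list undoes split_grad_list:

    # K = 2 towers, N = 2 variables: grad_list is K x N x 2 (grad, var) pairs.
    grad_list = [
        [('g0_w', 'w'), ('g0_b', 'b')],   # tower 0
        [('g1_w', 'w'), ('g1_b', 'b')],   # tower 1
    ]

    # split_grad_list: K x N x 2 -> (K x N grads, K x N vars)
    all_grads = [[g for g, _ in tower] for tower in grad_list]   # [['g0_w', 'g0_b'], ['g1_w', 'g1_b']]
    all_vars = [[v for _, v in tower] for tower in grad_list]    # [['w', 'b'], ['w', 'b']]

    # merge_grad_list: zip them back into K x N x 2
    merged = [list(zip(gs, vs)) for gs, vs in zip(all_grads, all_vars)]
    assert merged == grad_list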
+@under_name_scope('AllReduceGrads')
 def allreduce_grads(all_grads, average):
     """
-    All-reduce average the gradients among devices. Results are broadcasted to all devices.
+    All-reduce average the gradients among K devices. Results are broadcasted to all devices.

     Args:
-        all_grads (K x N x 2): A list of K lists. Each of the list is a list of N (grad, var) tuples.
-            The variables have to be the same across the K lists.
+        all_grads (K x N): List of list of gradients. N is the number of variables.
         average (bool): average gradients or not.

     Returns:
-        (K x N x 2): same as input, but each grad is replaced by the average over K lists.
+        K x N: same as input, but each grad is replaced by the average over K devices.
     """
     from tensorflow.contrib import nccl
     nr_tower = len(all_grads)
     if nr_tower == 1:
         return all_grads
-    new_all_grads = []  # NVar * NGPU * 2
-    with tf.name_scope('AvgGrad'):
-        for grad_and_vars in zip(*all_grads):
-            v = grad_and_vars[0][1]
-            grads = [g for g, _ in grad_and_vars]
-            summed = nccl.all_sum(grads)
-
-            grads_for_a_var = []
-            for (_, v), g in zip(grad_and_vars, summed):
-                with tf.device(g.device):
-                    # tensorflow/benchmarks didn't average gradients
-                    if average:
-                        g = tf.multiply(g, 1.0 / nr_tower)
-                    grads_for_a_var.append((g, v))
-            new_all_grads.append(grads_for_a_var)
-
-    # transpose
-    ret = [k for k in zip(*new_all_grads)]
+    new_all_grads = []  # N x K
+    for grads in zip(*all_grads):
+        summed = nccl.all_sum(grads)
+
+        grads_for_devices = []  # K
+        for g in summed:
+            with tf.device(g.device):
+                # tensorflow/benchmarks didn't average gradients
+                if average:
+                    g = tf.multiply(g, 1.0 / nr_tower)
+                grads_for_devices.append(g)
+        new_all_grads.append(grads_for_devices)
+
+    # transpose to K x N
+    ret = list(zip(*new_all_grads))
     return ret
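To make the K x N bookkeeping concrete, here is a plain-Python emulation of what the new allreduce_grads computes (numbers stand in for gradient tensors; sum and division stand in for nccl.all_sum plus the optional tf.multiply, so no GPUs are required):

    # K = 2 towers, N = 2 variables; each tower holds one gradient per variable.
    all_grads = [
        [1.0, 10.0],   # tower 0
        [3.0, 30.0],   # tower 1
    ]
    nr_tower = len(all_grads)
    new_all_grads = []                            # N x K
    for grads in zip(*all_grads):                 # iterate over variables
        avg = sum(grads) / nr_tower               # the all-reduced (averaged) value
        new_all_grads.append([avg] * nr_tower)    # every device receives the same result
    ret = list(zip(*new_all_grads))               # transpose back to K x N
    assert ret == [(2.0, 20.0), (2.0, 20.0)]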
+@under_name_scope('AllReduceGradsHierachical')
+def allreduce_hierarchical(all_grads, devices, average=False):
+    """
+    Hierarchical allreduce for DGX-1 system.
+
+    Args:
+        all_grads (K x N): List of list of gradients. N is the number of variables.
+        devices ([str]): K str for the K devices.
+        average (bool): average gradients or not.
+
+    Returns:
+        (K x N): same as input, but each grad is replaced by the average over K lists.
+    """
+    num_gpu = len(devices)
+    assert num_gpu == 8, num_gpu
+    assert len(all_grads) == num_gpu, len(all_grads)
+    group_size = num_gpu // 2
+
+    agg_all_grads = []  # N x K
+    for varid, grads in enumerate(zip(*all_grads)):
+        # grads: K gradients
+        g0_main_gpu = varid % num_gpu
+        g1_main_gpu = (g0_main_gpu + group_size) % num_gpu
+        g0_start = 0 if g0_main_gpu < group_size else group_size
+        g1_start = 0 if g1_main_gpu < group_size else group_size
+        assert g0_start != g1_start
+        g0_grads = grads[g0_start: g0_start + group_size]
+        g1_grads = grads[g1_start: g1_start + group_size]
+
+        with tf.device(devices[g0_main_gpu]):
+            g0_agg = tf.add_n(g0_grads, name='group0_agg')
+
+        with tf.device(devices[g1_main_gpu]):
+            g1_agg = tf.add_n(g1_grads, name='group1_agg')
+            g1_total_agg = tf.add(g0_agg, g1_agg, name='group1_total_agg')
+
+        with tf.device(devices[g0_main_gpu]):
+            g0_total_agg = tf.identity(g1_total_agg, name='group0_total_agg')
+
+        agg_grads = []  # K aggregated grads
+        for k in range(num_gpu):
+            if (k < group_size) == (g0_main_gpu < group_size):
+                main_gpu = g0_total_agg
+            else:
+                main_gpu = g1_total_agg
+            with tf.device(devices[k]):
+                if not average:
+                    device_total_agg = tf.identity(
+                        main_gpu, name='device{}_total_agg'.format(k))
+                else:
+                    # TODO where to put average?
+                    device_total_agg = tf.multiply(
+                        main_gpu, 1.0 / num_gpu, name='device{}_total_agg'.format(k))
+                agg_grads.append(device_total_agg)
+        agg_all_grads.append(agg_grads)
+
+    # transpose
+    agg_all_grads = list(zip(*agg_all_grads))    # K x Nvar
+    return agg_all_grads
+
+
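The scheme splits the 8 GPUs into two groups of 4, reduces each group onto a per-variable "main" GPU, exchanges the two partial sums once across the groups, and then broadcasts within each group; the main GPUs rotate with varid to spread the work. A plain-Python walk-through for a single variable with varid == 0 (numbers stand in for gradient tensors):

    grads = [1, 2, 3, 4, 5, 6, 7, 8]      # one gradient per GPU
    group_size = 4
    g0_main_gpu, g1_main_gpu = 0, 4       # varid % 8 and (varid + 4) % 8 for varid == 0

    g0_agg = sum(grads[0:4])              # intra-group reduce on GPU 0 -> 10
    g1_agg = sum(grads[4:8])              # intra-group reduce on GPU 4 -> 26
    g1_total_agg = g0_agg + g1_agg        # one cross-group add on GPU 4 -> 36
    g0_total_agg = g1_total_agg           # copied back to GPU 0

    # every GPU reads the total from the main GPU of its own group
    per_device = [g0_total_agg if k < group_size else g1_total_agg for k in range(8)]
    assert per_device == [36] * 8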
+@under_name_scope('AggregateGrads')
 def aggregate_grads(all_grads,
                     colocation=False,
                     devices=None,
...
@@ -153,7 +244,6 @@ def aggregate_grads(all_grads,
             return tf.add_n(grads)

     ret = []
-    with tf.name_scope('AggregateGrad'):
-        for idx, grad_and_vars in enumerate(zip(*all_grads)):
-            # Ngpu * 2
-            v = grad_and_vars[0][1]
+    for idx, grad_and_vars in enumerate(zip(*all_grads)):
+        # Ngpu * 2
+        v = grad_and_vars[0][1]
...
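For comparison with allreduce_grads: aggregate_grads keeps the (grad, var) pairing and collapses the K towers into a single list of N pairs, as far as can be read from the surrounding code. A plain-Python emulation of that output shape (average=True behaviour; strings stand in for variables):

    grad_list = [
        [(1.0, 'w'), (10.0, 'b')],   # tower 0
        [(3.0, 'w'), (30.0, 'b')],   # tower 1
    ]
    aggregated = []                              # N x 2, one entry per variable
    for grad_and_vars in zip(*grad_list):        # Ngpu * 2
        v = grad_and_vars[0][1]
        grads = [g for g, _ in grad_and_vars]
        aggregated.append((sum(grads) / len(grads), v))
    assert aggregated == [(2.0, 'w'), (20.0, 'b')]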
tensorpack/tfutils/scope_utils.py
...
@@ -55,11 +55,11 @@ def auto_reuse_variable_scope(func):
     return wrapper


-def under_name_scope():
+def under_name_scope(name=None):
     """
     Returns:
-        A decorator which makes the function happen under a name scope,
-        which is named by the function itself.
+        A decorator which makes the function happen under a name scope.
+        The default name is the function itself.

     Examples:
...
@@ -77,8 +77,11 @@ def under_name_scope():
     def _impl(func):
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
-            name = func.__name__
-            with tf.name_scope(name):
+            if name is None:
+                scopename = func.__name__
+            else:
+                scopename = name
+            with tf.name_scope(scopename):
                 return func(*args, **kwargs)
         return wrapper
     return _impl
...
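A usage sketch of the updated decorator (assuming TensorFlow 1.x graph mode and this commit's tensorpack tree):

    import tensorflow as tf
    from tensorpack.tfutils.scope_utils import under_name_scope


    @under_name_scope()                 # default: scope named after the function
    def my_op(x):
        return tf.identity(x, name='out')


    @under_name_scope('CustomScope')    # explicit scope name, new in this commit
    def my_other_op(x):
        return tf.identity(x, name='out')


    x = tf.placeholder(tf.float32, [None])
    print(my_op(x).name)        # my_op/out:0
    print(my_other_op(x).name)  # CustomScope/out:0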
tensorpack/train/tower.py
...
@@ -38,7 +38,7 @@ class TowerTrainer(Trainer):
         assert isinstance(tower_func, TowerFuncWrapper), tower_func
         self._tower_func = tower_func

-    @deprecated("Just use tower_func = xxx instead!")
+    @deprecated("Just use tower_func = xxx instead!", "2018-06-01")
     def set_tower_func(self, tower_func):
         self._set_tower_func(tower_func)
...
tensorpack/train/trainers.py
...
@@ -140,17 +140,22 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
     """

     @map_arg(gpus=_int_to_range)
-    def __init__(self, gpus, average=True, mode='nccl', use_nccl=None):
+    def __init__(self, gpus, average=True, mode=None, use_nccl=None):
         """
         Args:
             gpus (int or [int]): list of GPU ids.
             average (bool): whether to average or sum gradients.
-            mode (str): Gradient aggregation mode. Supported values: ['nccl', 'cpu']
+            mode (str or None): Gradient aggregation mode.
+                These methods may have slight differences in speed.
+                Supported values: ['nccl', 'cpu']. Default to pick
+                automatically by heuristics.
         """
         self.devices = gpus
         if use_nccl is not None:
             mode = 'nccl' if use_nccl else 'cpu'
             logger.warn("use_nccl option was deprecated! Use the `mode` option instead!")
+        if mode is None:
+            mode = 'nccl'
         mode = mode.lower()

         self._builder = SyncMultiGPUReplicatedBuilder(gpus, average, mode)
         super(SyncMultiGPUTrainerReplicated, self).__init__()
...
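A usage sketch of the new argument (hypothetical training script; `config` would be a normal TrainConfig defined elsewhere):

    from tensorpack import SyncMultiGPUTrainerReplicated

    # Request NCCL all-reduce explicitly on 8 GPUs. Passing mode=None (the new
    # default) lets the trainer pick a mode itself; mode='cpu' aggregates on CPU.
    trainer = SyncMultiGPUTrainerReplicated(8, average=True, mode='nccl')
    # launch_train_with_config(config, trainer)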