Commit ce709fa3
authored Oct 17, 2017 by Yuxin Wu
parent 694e404b

fix multigpu training

Showing 4 changed files with 30 additions and 17 deletions (+30 -17):
    tensorpack/graph_builder/distributed.py    +2  -3
    tensorpack/graph_builder/training.py       +15 -11
    tensorpack/train/distributed.py            +1  -1
    tensorpack/train/multigpu.py               +12 -2
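In summary: the instance method DataParallelBuilder.build_on_multi_tower(self, func, devices=None, use_vs=None) is replaced by the staticmethod DataParallelBuilder.build_on_towers(towers, func, devices=None, use_vs=None), and every builder now passes its tower list explicitly; AsyncMultiGPUBuilder reads the renamed self._scale_gradient attribute; DistributedTrainerReplicated constructs its builder from config.tower instead of self.config.tower; and tensorpack/train/multigpu.py re-exports a MultiGPUTrainerBase shim so the old entry point keeps working.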
tensorpack/graph_builder/distributed.py

@@ -28,7 +28,6 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
         self.task_index = server_def.task_index
         # TODO XXX ps does't need to build!
         assert self.job_name in ['ps', 'worker'], self.job_name
-        assert tf.test.is_gpu_available()

         logger.info("Distributed training on cluster:\n" + str(server_def.cluster))
         logger.info("My role in the cluster: job={}, task={}".format(self.job_name, self.task_index))
@@ -176,8 +175,8 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
                 return grads

         # Ngpu * Nvar * 2
-        grad_list = self.build_on_multi_tower(
-            get_grads,
+        grad_list = DataParallelBuilder.build_on_towers(
+            self.towers, get_grads,
             devices=self.raw_devices,
             use_vs=[True] * len(self.towers))  # open vs at each tower
         DataParallelBuilder._check_grad_list(grad_list)
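The "# Ngpu * Nvar * 2" comment describes the nesting that _check_grad_list validates: one entry per tower, each a list of (gradient, variable) pairs. An illustrative sketch of that structure, using made-up placeholder names rather than tensorpack objects:

    # Illustrative only: the nesting convention behind "# Ngpu * Nvar * 2".
    # grad_list[i] holds the (gradient, variable) pairs computed on tower i;
    # the strings below are placeholders, not tensorpack tensors.
    grad_list = [
        [("dW_gpu0", "W"), ("db_gpu0", "b")],   # tower 0
        [("dW_gpu1", "W"), ("db_gpu1", "b")],   # tower 1
    ]
    # _check_grad_list asserts every tower produced the same number of variables:
    nvars = [len(k) for k in grad_list]
    assert len(set(nvars)) == 1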
tensorpack/graph_builder/training.py

@@ -71,7 +71,7 @@ class DataParallelBuilder(GraphBuilder):
         self.towers = towers

-    def _check_tf_version(self):
+    @staticmethod
+    def _check_tf_version():
         assert get_tf_version_number() >= 1.1, \
             "TF version {} is too old to run multi GPU training!".format(tf.VERSION)
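The guard compares tensorpack's get_tf_version_number() against 1.1 before any multi-GPU graph is built. A standalone sketch of an equivalent major.minor comparison, assuming TF 1.x where tf.VERSION is available (this is not tensorpack's actual implementation):

    import tensorflow as tf

    # Equivalent standalone check (not tensorpack's implementation):
    # keep only "major.minor" of tf.VERSION, e.g. "1.3.0" -> 1.3.
    version = float('.'.join(tf.VERSION.split('.')[:2]))
    assert version >= 1.1, \
        "TF version {} is too old to run multi GPU training!".format(tf.VERSION)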
@@ -84,9 +84,12 @@ class DataParallelBuilder(GraphBuilder):
         nvars = [len(k) for k in grad_list]
         assert len(set(nvars)) == 1, \
             "Number of gradients from each tower is different! " + str(nvars)

-    def build_on_multi_tower(self, func, devices=None, use_vs=None):
+    @staticmethod
+    def build_on_towers(towers, func, devices=None, use_vs=None):
         """
         Run `func` on all towers.

         Args:
             func: a lambda to be called inside each tower
             devices: a list of devices to be used. By default will use GPUs in ``towers``.
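Following that docstring, the new staticmethod can be driven with any per-tower function; it presumably collects one result per tower. A hedged usage sketch (the GPU ids and the per-tower function are examples, not anything prescribed by the commit; building the graph does not run it):

    # Hedged usage sketch of the new staticmethod, assuming tensorpack at this commit.
    import tensorflow as tf
    from tensorpack.graph_builder.training import DataParallelBuilder

    def per_tower_fn():
        # called once inside each tower's device scope / TowerContext
        return tf.constant(0.0)

    outputs = DataParallelBuilder.build_on_towers(
        [0, 1],                 # GPU ids; devices defaults to '/gpu:0' and '/gpu:1'
        per_tower_fn,
        use_vs=[False, True])   # open a fresh variable scope only on the second tower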
@@ -98,13 +101,13 @@ class DataParallelBuilder(GraphBuilder):
         ret = []
         if devices is not None:
-            assert len(devices) == len(self.towers)
+            assert len(devices) == len(towers)
         if use_vs is not None:
-            assert len(use_vs) == len(self.towers)
+            assert len(use_vs) == len(towers)

-        tower_names = ['tower{}'.format(idx) for idx in range(len(self.towers))]
+        tower_names = ['tower{}'.format(idx) for idx in range(len(towers))]

-        for idx, t in enumerate(self.towers):
+        for idx, t in enumerate(towers):
             device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
             usevs = use_vs[idx] if use_vs is not None else False
             with tf.device(device), TowerContext(
@@ -177,7 +180,7 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
             grads = FilterNoneGrad().process(grads)
             return grads

-        grad_list = self.build_on_multi_tower(get_grads, devices)
+        grad_list = DataParallelBuilder.build_on_towers(self.towers, get_grads, devices)
         DataParallelBuilder._check_grad_list(grad_list)

         # debug tower performance (without update):
@@ -237,7 +240,8 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
             grads = FilterNoneGrad().process(grads)
             return grads

-        grad_list = self.build_on_multi_tower(
+        grad_list = DataParallelBuilder.build_on_towers(
+            self.towers,
             get_grads,
             # use no variable scope for the first tower
             use_vs=[False] + [True] * (len(self.towers) - 1))
         grads = SyncMultiGPUReplicatedBuilder._allreduce_grads(grad_list)
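The use_vs argument controls whether each tower opens its own variable scope: judging from the in-code comments, the replicated builder keeps the first tower in the default scope and gives every later tower its own copy, while the distributed builder in the first file opens a scope on every tower. A small illustration of what the two flag lists expand to:

    # Illustration of the use_vs flag lists built by the two callers (n = number of towers).
    n = 4
    replicated = [False] + [True] * (n - 1)   # [False, True, True, True]: tower 0 stays in the default scope
    distributed = [True] * n                  # [True, True, True, True]: every tower gets its own scope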
@@ -316,10 +320,10 @@ class AsyncMultiGPUBuilder(DataParallelBuilder):
             grads = FilterNoneGrad().process(grads)
             return grads

-        grad_list = self.build_on_multi_tower(get_grads, devices)
+        grad_list = DataParallelBuilder.build_on_towers(self.towers, get_grads, devices)
         DataParallelBuilder._check_grad_list(grad_list)

-        if self.scale_gradient and len(self.towers) > 1:
+        if self._scale_gradient and len(self.towers) > 1:
             # pretend to average the grads, in order to make async and
             # sync have consistent effective learning rate
             gradproc = ScaleGradient(('.*', 1.0 / len(self.towers)), verbose=False)
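The comment explains the 1/N scaling: synchronous training averages N tower gradients into one update, while asynchronous training applies each tower's gradient separately, so each one is shrunk by 1/N to keep the effective learning rate comparable. A back-of-the-envelope sketch of that intent (plain Python, not tensorpack code):

    # Not tensorpack code: back-of-the-envelope view of the 1/N scaling above.
    n_towers = 4
    tower_grads = [0.8, 1.2, 1.0, 1.0]                     # stand-in gradient values for one variable

    sync_update = sum(tower_grads) / n_towers              # one averaged update per step
    async_updates = [g / n_towers for g in tower_grads]    # N smaller updates applied separately
    assert abs(sum(async_updates) - sync_update) < 1e-9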
tensorpack/train/distributed.py

@@ -63,7 +63,7 @@ class DistributedTrainerReplicated(Trainer):
         assert config.data is not None and config.model is not None

         self.server = server
-        self._builder = DistributedReplicatedBuilder(self.config.tower, server)
+        self._builder = DistributedReplicatedBuilder(config.tower, server)

         self._input_source = config.data
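The fix reads the config argument directly instead of self.config; presumably self.config is only assigned later (e.g. by the base Trainer's constructor), so touching it this early in __init__ would raise AttributeError. A generic sketch of that failure pattern, with hypothetical classes unrelated to tensorpack:

    # Generic illustration of the bug pattern (hypothetical classes, not tensorpack's).
    class Base(object):
        def __init__(self, config):
            self.config = config

    class Broken(Base):
        def __init__(self, config):
            tower = self.config.tower        # AttributeError at construction: self.config not set yet
            super(Broken, self).__init__(config)

    class Fixed(Base):
        def __init__(self, config):
            tower = config.tower             # use the constructor argument directly
            super(Fixed, self).__init__(config)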
tensorpack/train/multigpu.py

@@ -11,15 +11,25 @@ from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyC
 from ..graph_builder.training import (
     SyncMultiGPUParameterServerBuilder,
     SyncMultiGPUReplicatedBuilder,
-    AsyncMultiGPUBuilder)
+    AsyncMultiGPUBuilder,
+    DataParallelBuilder)
 from .base import Trainer

-__all__ = ['SyncMultiGPUTrainerReplicated',
+__all__ = ['MultiGPUTrainerBase',
+           'SyncMultiGPUTrainerReplicated',
            'SyncMultiGPUTrainerParameterServer',
            'AsyncMultiGPUTrainer',
            'SyncMultiGPUTrainer']


+class MultiGPUTrainerBase(Trainer):
+    """
+    For backward compatibility only
+    """
+    def build_on_multi_tower(towers, func, devices=None, use_vs=None):
+        DataParallelBuilder.build_on_towers(towers, func, devices, use_vs)
+
+
 def apply_prefetch_policy(config, gpu_prefetch=True):
     assert (config.data is not None or config.dataflow is not None) and config.model is not None
     if config.data is None and config.dataflow is not None:
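MultiGPUTrainerBase is brought back only so that external code written against the old entry point keeps working; its build_on_multi_tower simply forwards to the new staticmethod and discards the result. A hedged sketch of a legacy call site kept alive by the shim (assumes Python 3, where calling the undecorated method on the class behaves like a plain function; the per-tower function is hypothetical):

    from tensorpack.train.multigpu import MultiGPUTrainerBase

    def per_tower_fn():
        pass   # build part of the graph for one tower

    # Forwards to DataParallelBuilder.build_on_towers([0, 1], per_tower_fn);
    # note the shim does not return the builder's result.
    MultiGPUTrainerBase.build_on_multi_tower([0, 1], per_tower_fn)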