Commit 206e1a67, authored May 10, 2017 by Yuxin Wu

    some docs change

parent 843ab15c
Changes: 2 changed files, with 29 additions and 13 deletions (+29 -13)

  tensorpack/train/input_source.py   +14 -1
  tensorpack/train/multigpu.py       +15 -12
tensorpack/train/input_source.py
@@ -153,7 +153,7 @@ class QueueInput(FeedfreeInput):
     def size(self):
         return self.ds.size()
 
-    # TODO XXX use input data mapping. not all placeholders are needed
+    # TODO use input data mapping. not all placeholders are needed
     def setup(self, model):
         self.input_placehdrs = model.get_reused_placehdrs()
         assert len(self.input_placehdrs) > 0, \
@@ -335,7 +335,14 @@ class ZMQInput(FeedfreeInput):
 
 
 class StagingInputWrapper(FeedfreeInput):
+    """
+    A wrapper around a feedfree input, to prefetch it in StagingArea (usually on GPUs).
+    """
     class StagingCallback(Callback):
+        """
+        A callback registered by this input source, to make sure stage/unstage
+        is run at each step.
+        """
         def __init__(self, stage_op, unstage_op, nr_stage):
             self.nr_stage = nr_stage
             self.stage_op = stage_op
@@ -351,6 +358,12 @@ class StagingInputWrapper(FeedfreeInput):
             return self.fetches
 
     def __init__(self, input, devices, nr_stage=5):
+        """
+        Args:
+            input: a :class:`FeedfreeInput`
+            devices: list of devices to be used for each training tower
+            nr_stage: number of elements to prefetch
+        """
        self._input = input
        assert isinstance(input, FeedfreeInput)
        self._devices = devices
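Note: taken together, the new docstrings describe the intended wiring: a FeedfreeInput (e.g. QueueInput) is wrapped so that each training tower pulls prefetched elements from a per-GPU StagingArea. A minimal usage sketch, not part of this commit; the import paths follow the file paths in this diff, and FakeData is assumed available in tensorpack.dataflow:

from tensorpack.dataflow import FakeData                       # assumed helper
from tensorpack.train.input_source import QueueInput, StagingInputWrapper

df = FakeData([[64, 224, 224, 3], [64]], size=1000)   # stand-in DataFlow
devices = ['/gpu:0', '/gpu:1']                        # one per training tower

# QueueInput is a FeedfreeInput, so it passes the isinstance assertion in
# StagingInputWrapper.__init__; nr_stage=5 (the documented default) keeps
# up to 5 prefetched elements staged ahead of the consumers.
input_source = StagingInputWrapper(QueueInput(df), devices, nr_stage=5)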
tensorpack/train/multigpu.py
@@ -20,10 +20,11 @@ from .base import Trainer
 from .feedfree import SingleCostFeedfreeTrainer
 from .input_source import QueueInput, StagingInputWrapper
 
-__all__ = ['SyncMultiGPUTrainer', 'AsyncMultiGPUTrainer']
+__all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
+           'AsyncMultiGPUTrainer', 'LeastLoadedDeviceSetter']
 
 
-class MultiGPUTrainer(Trainer):
+class MultiGPUTrainerBase(Trainer):
     """ Base class for multi-gpu training"""
     @staticmethod
     def build_on_multi_tower(towers, func, devices=None):
@@ -32,6 +33,9 @@ class MultiGPUTrainer(Trainer):
             towers: list of gpu relative ids
             func: a lambda to be called inside each tower
             devices: a list of devices to be used. By default will use GPUs in towers.
+
+        Returns:
+            List of outputs of ``func``, evaluated on each tower.
         """
         logger.info("Training a model of {} tower".format(len(towers)))
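Note: the added ``Returns:`` section pins down the helper's contract: ``func`` runs once per tower and the results come back as a list, one entry per tower. A hedged sketch of that contract (TF 1.x graph mode assumed; untested against this exact revision):

import tensorflow as tf
from tensorpack.train.multigpu import MultiGPUTrainerBase  # path per this diff

# Build one op per tower; the returned list preserves tower order.
outputs = MultiGPUTrainerBase.build_on_multi_tower(
    towers=[0, 1],                  # GPU relative ids
    func=lambda: tf.constant(1.0),  # called once inside each tower's scope
    devices=None)                   # default: the GPUs named in `towers`
assert len(outputs) == 2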
@@ -58,14 +62,13 @@ class MultiGPUTrainer(Trainer):
 
 
 # Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
-class ParamServerDeviceSetter(object):
-    """Helper class to assign variables on the least loaded ps-device."""
+class LeastLoadedDeviceSetter(object):
+    """ Helper class to assign variables on the least loaded ps-device."""
     def __init__(self, worker_device, ps_devices):
         """
         Args:
-            worker_device: the device to use for computer ops.
-            ps_devices: a list of device to use for Variable ops. Each variable is
-                assigned to the least loaded device.
+            worker_device: the device to use for compute ops.
+            ps_devices: a list of device to use for Variable ops.
         """
         self.ps_devices = ps_devices
         self.worker_device = worker_device
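Note: as in the tf_cnn_benchmarks file it was copied from, the setter is used as a device function: the ``return device_name`` context in the next hunk suggests it implements ``__call__(op)``, so it can be handed straight to ``tf.device``. A hypothetical sketch (TF 1.x assumed):

import tensorflow as tf
from tensorpack.train.multigpu import LeastLoadedDeviceSetter  # now exported

raw_devices = ['/gpu:0', '/gpu:1']
for d in raw_devices:
    setter = LeastLoadedDeviceSetter(worker_device=d, ps_devices=raw_devices)
    with tf.device(setter):  # tf.device accepts an op -> device-name function
        # Variable ops go to whichever ps-device holds the fewest variable
        # bytes so far; compute ops stay on worker_device `d`.
        w = tf.get_variable('w_' + d[-1], shape=[1024])
        y = w * 2.0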
@@ -86,7 +89,7 @@ class ParamServerDeviceSetter(object):
         return device_name
 
 
-class SyncMultiGPUTrainerParameterServer(MultiGPUTrainer, SingleCostFeedfreeTrainer):
+class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
     """
     A multi-tower multi-GPU trainer which synchronoizes the gradients computed
     from each tower, averages them and update to variables stored on PS.
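Note: the "averages them" in the docstring above is the classic multi-tower gradient reduction. A hedged sketch of that step (not this trainer's actual code), given one ``[(grad, var), ...]`` list per tower with the same variable order in each:

import tensorflow as tf

def average_grads(tower_grads):
    # tower_grads: one [(grad, var), ...] list per tower, same var order.
    averaged = []
    for gvs in zip(*tower_grads):          # the same var across all towers
        grads = [g for g, _ in gvs]
        averaged.append((tf.add_n(grads) / float(len(grads)), gvs[0][1]))
    return averaged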
@@ -148,12 +151,12 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainer, SingleCostFeedfreeTrai
         raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
         if self._ps_device == 'gpu':
-            devices = [ParamServerDeviceSetter(d, raw_devices) for d in raw_devices]
+            devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
         else:
             devices = [tf.train.replica_device_setter(
                 worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices]
 
-        grad_list = MultiGPUTrainer.build_on_multi_tower(
+        grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             self.config.tower,
             lambda: self._get_cost_and_grad()[1],
             devices)
         # debug tower performance (without update):
@@ -175,7 +178,7 @@ def SyncMultiGPUTrainer(config):
     return SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')
 
 
-class AsyncMultiGPUTrainer(MultiGPUTrainer,
+class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
                            SingleCostFeedfreeTrainer):
     """
     A multi-tower multi-GPU trainer where each tower independently
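Note: per the hunk above, ``SyncMultiGPUTrainer`` is a thin wrapper over the parameter-server trainer. A hypothetical usage sketch, not part of this commit; ``config`` is a tensorpack TrainConfig whose construction is elided because its fields are not in this diff:

from tensorpack.train.multigpu import (
    SyncMultiGPUTrainer, SyncMultiGPUTrainerParameterServer)

# `config`: a tensorpack TrainConfig (model/dataflow/tower setup, assumed)
trainer = SyncMultiGPUTrainer(config)
# Per the hunk above, this is equivalent to:
trainer = SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')
trainer.train()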
@@ -204,7 +207,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
     def _setup(self):
         super(AsyncMultiGPUTrainer, self)._setup()
 
-        grad_list = MultiGPUTrainer.build_on_multi_tower(
+        grad_list = MultiGPUTrainerBase.build_on_multi_tower(
             self.config.tower,
             lambda: self._get_cost_and_grad()[1])
         grad_list = [FilterNoneGrad().process(gv) for gv in grad_list]
 
         if self._scale_gradient and self.config.nr_tower > 1:
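Note: the ``FilterNoneGrad().process(gv)`` call above runs per tower; it drops ``(grad, var)`` pairs whose gradient is ``None``, i.e. variables the tower's cost does not depend on. A hedged standalone sketch (module path assumed from the tensorpack of this era):

import tensorflow as tf
from tensorpack.tfutils.gradproc import FilterNoneGrad  # path assumed

x = tf.get_variable('x', shape=[3])
y = tf.get_variable('y', shape=[3])     # unused by cost -> gradient None
cost = tf.reduce_sum(x * x)
grads = tf.gradients(cost, [x, y])      # [Tensor, None]
gv = FilterNoneGrad().process(list(zip(grads, [x, y])))
assert len(gv) == 1                     # only (grad_x, x) survives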