seminar-breakout (Shashank Suhas) · Commit 22b91be9
Authored Oct 25, 2017 by Yuxin Wu
map_arg for gpus

Parent: dc709e94
Showing 5 changed files with 30 additions and 20 deletions.
examples/DoReFa-Net/alexnet-dorefa.py             +1 -2
examples/DynamicFilterNetwork/steering-filter.py  +1 -1
examples/FasterRCNN/train.py                      +1 -1
examples/HED/hed.py                               +1 -1
tensorpack/train/trainers.py                      +26 -15
examples/DoReFa-Net/alexnet-dorefa.py

@@ -322,5 +322,4 @@ if __name__ == '__main__':
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
-    config.nr_tower = nr_tower
-    launch_train_with_config(config, SyncMultiGPUTrainer(list(range(nr_tower))))
+    launch_train_with_config(config, SyncMultiGPUTrainer(nr_tower))
examples/DynamicFilterNetwork/steering-filter.py

@@ -264,4 +264,4 @@ if __name__ == '__main__':
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
-    launch_train_with_config(config, SyncMultiGPUTrainer(list(range(NR_GPU))))
+    launch_train_with_config(config, SyncMultiGPUTrainer(NR_GPU))
examples/FasterRCNN/train.py

@@ -302,5 +302,5 @@ if __name__ == '__main__':
         max_epoch=205000 * factor // stepnum,
         session_init=get_model_loader(args.load) if args.load else None,
     )
-    trainer = SyncMultiGPUTrainerReplicated(range(len(get_nr_gpu())))
+    trainer = SyncMultiGPUTrainerReplicated(get_nr_gpu())
     launch_train_with_config(cfg, trainer)
examples/HED/hed.py

@@ -234,4 +234,4 @@ if __name__ == '__main__':
         config.session_init = get_model_loader(args.load)
     launch_train_with_config(
         config,
-        SyncMultiGPUTrainer(range(max(get_nr_gpu(), 1))))
+        SyncMultiGPUTrainer(max(get_nr_gpu(), 1)))
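The four example scripts above all make the same call-site change: instead of materializing the list of GPU ids, they now hand the trainer a GPU count. A minimal sketch of the pattern, assuming the tensorpack imports the examples of this era use (`nr_gpu` here simply stands in for whatever count each script computes):

```python
from tensorpack import SyncMultiGPUTrainer
from tensorpack.utils.gpu import get_nr_gpu

nr_gpu = max(get_nr_gpu(), 1)   # int: number of visible GPUs, at least 1

# Before this commit the call sites spelled out the GPU ids themselves:
#     SyncMultiGPUTrainer(list(range(nr_gpu)))
# After it, passing the count is enough; it is expanded to [0, ..., nr_gpu-1]:
trainer = SyncMultiGPUTrainer(nr_gpu)
```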
tensorpack/train/trainers.py

@@ -8,6 +8,7 @@ from ..callbacks.graph import RunOp
 from ..tfutils.sesscreate import NewSessionCreator
 from ..utils import logger
+from ..utils.argtools import map_arg
 from ..tfutils import get_global_step_var
 from ..tfutils.distributed import get_distributed_session_creator
 from ..tfutils.tower import TowerContext
@@ -31,6 +32,12 @@ __all__ = ['SimpleTrainer',
            'DistributedTrainerReplicated']


+def _int_to_range(x):
+    if isinstance(x, int):
+        assert x > 0, x
+        return list(range(x))
+
+
 class SimpleTrainer(SingleCostTrainer):
     """
     Single-GPU single-cost single-tower trainer.
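The `@map_arg(gpus=_int_to_range)` decorator applied below comes from `tensorpack.utils.argtools` (imported in the first hunk); it rewrites the named argument before the wrapped `__init__` runs. The following is a self-contained sketch of that mechanism, not tensorpack's actual implementation, and the `return x` fall-through for non-int input is my addition for illustration:

```python
import functools
import inspect


def map_arg_sketch(**maps):
    """Stand-in for tensorpack.utils.argtools.map_arg: apply a mapping
    function to selected arguments before calling the wrapped function."""
    def deco(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Resolve positional and keyword arguments into one dict.
            argmap = inspect.getcallargs(func, *args, **kwargs)
            for name, fn in maps.items():
                if name in argmap:
                    argmap[name] = fn(argmap[name])
            return func(**argmap)
        return wrapper
    return deco


def _int_to_range(x):
    # An int GPU count becomes the id list [0, ..., x-1]; lists pass through.
    if isinstance(x, int):
        assert x > 0, x
        return list(range(x))
    return x


@map_arg_sketch(gpus=_int_to_range)
def show(gpus, ps_device='gpu'):
    return gpus, ps_device


print(show(4))        # ([0, 1, 2, 3], 'gpu')
print(show([0, 2]))   # ([0, 2], 'gpu')
```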
@@ -54,13 +61,14 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
     __doc__ = SyncMultiGPUParameterServerBuilder.__doc__

-    def __init__(self, towers, ps_device='gpu'):
+    @map_arg(gpus=_int_to_range)
+    def __init__(self, gpus, ps_device='gpu'):
         """
         Args:
-            towers ([int]): list of GPU ids.
+            gpus ([int]): list of GPU ids.
             ps_device: either 'gpu' or 'cpu', where variables are stored. Setting to 'cpu' might help when #gpu>=4
         """
-        self._builder = SyncMultiGPUParameterServerBuilder(towers, ps_device)
+        self._builder = SyncMultiGPUParameterServerBuilder(gpus, ps_device)
         super(SyncMultiGPUTrainerParameterServer, self).__init__()

     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
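With the decorator in place, the parameter-server trainer accepts either spelling, and `ps_device` remains an ordinary keyword. A small hedged usage sketch, assuming the trainer is importable from the top-level `tensorpack` package as the examples do via `from tensorpack import *`:

```python
from tensorpack import SyncMultiGPUTrainerParameterServer

# Equivalent after this commit: an int count is expanded by _int_to_range.
t1 = SyncMultiGPUTrainerParameterServer(4)
t2 = SyncMultiGPUTrainerParameterServer([0, 1, 2, 3])

# Per the docstring above, storing variables on the CPU can help with >= 4 GPUs.
t3 = SyncMultiGPUTrainerParameterServer(4, ps_device='cpu')
```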
@@ -69,28 +77,29 @@
         return []


-def SyncMultiGPUTrainer(towers):
+def SyncMultiGPUTrainer(gpus):
     """
     Return a default multi-GPU trainer, if you don't care about the details.
     It may not be the most efficient one for your task.

     Args:
-        towers (list[int]): list of GPU ids.
+        gpus (list[int]): list of GPU ids.
     """
-    return SyncMultiGPUTrainerParameterServer(towers, ps_device='gpu')
+    return SyncMultiGPUTrainerParameterServer(gpus, ps_device='gpu')


 class AsyncMultiGPUTrainer(SingleCostTrainer):

     __doc__ = AsyncMultiGPUBuilder.__doc__

-    def __init__(self, towers, scale_gradient=True):
+    @map_arg(gpus=_int_to_range)
+    def __init__(self, gpus, scale_gradient=True):
         """
         Args:
-            towers ([int]): list of GPU ids.
+            gpus ([int]): list of GPU ids.
             scale_gradient (bool): if True, will scale each gradient by ``1.0/nr_gpu``.
         """
-        self._builder = AsyncMultiGPUBuilder(towers, scale_gradient)
+        self._builder = AsyncMultiGPUBuilder(gpus, scale_gradient)
         super(AsyncMultiGPUTrainer, self).__init__()

     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
@@ -103,12 +112,13 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
     __doc__ = SyncMultiGPUReplicatedBuilder.__doc__

-    def __init__(self, towers):
+    @map_arg(gpus=_int_to_range)
+    def __init__(self, gpus):
         """
         Args:
-            towers ([int]): list of GPU ids.
+            gpus ([int]): list of GPU ids.
         """
-        self._builder = SyncMultiGPUReplicatedBuilder(towers)
+        self._builder = SyncMultiGPUReplicatedBuilder(gpus)
         super(SyncMultiGPUTrainerReplicated, self).__init__()

     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
@@ -125,10 +135,11 @@ class DistributedTrainerReplicated(SingleCostTrainer):
     __doc__ = DistributedReplicatedBuilder.__doc__

-    def __init__(self, towers, server):
+    @map_arg(gpus=_int_to_range)
+    def __init__(self, gpus, server):
         """
         Args:
-            towers (list[int]): list of GPU ids.
+            gpus (list[int]): list of GPU ids.
             server (tf.train.Server): the server with ps and workers.
                 The job_name must be 'worker' because 'ps' job doesn't need to
                 build any graph.
@@ -139,7 +150,7 @@ class DistributedTrainerReplicated(SingleCostTrainer):
if
self
.
job_name
==
'worker'
:
# ps doesn't build any graph
self
.
_builder
=
DistributedReplicatedBuilder
(
tower
s
,
server
)
self
.
_builder
=
DistributedReplicatedBuilder
(
gpu
s
,
server
)
self
.
is_chief
=
self
.
_builder
.
is_chief
else
:
self
.
is_chief
=
False
...
...
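In short, every multi-GPU trainer constructor touched here (and the `SyncMultiGPUTrainer` convenience function) now normalizes its GPU argument through `_int_to_range`. A hedged sketch of the resulting equivalences, assuming the usual top-level exports; `DistributedTrainerReplicated` is omitted because it additionally needs a `tf.train.Server`:

```python
from tensorpack import (AsyncMultiGPUTrainer, SyncMultiGPUTrainer,
                        SyncMultiGPUTrainerReplicated)

# Each pair constructs the same trainer after this commit:
SyncMultiGPUTrainer(2)                        # == SyncMultiGPUTrainer([0, 1])
SyncMultiGPUTrainerReplicated(2)              # == SyncMultiGPUTrainerReplicated([0, 1])
AsyncMultiGPUTrainer(2, scale_gradient=True)  # == AsyncMultiGPUTrainer([0, 1])

# Passing 0 (e.g. when no GPU is visible) trips the assert in _int_to_range;
# hed.py, for example, guards against this with max(get_nr_gpu(), 1).
```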