Shashank Suhas / seminar-breakout / Commits

Commit 1a262e8c
Authored Oct 31, 2017 by Yuxin Wu

Initial HorovodTrainer (#422)

Parent: 7b0782d6

Showing 2 changed files with 30 additions and 3 deletions:
  tensorpack/train/base.py      +1  -1
  tensorpack/train/trainers.py  +29 -2
tensorpack/train/base.py
@@ -174,7 +174,7 @@ class Trainer(object):
             logger.info("Initializing the session ...")
             session_init._run_init(self.sess)
         else:
-            if not isinstance(self._config.session_init, JustCurrentSession):
+            if not isinstance(session_init, JustCurrentSession):
                 logger.warn("This is not a chief worker, 'session_init' was ignored!")

         self.sess.graph.finalize()
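The base.py change makes a non-chief worker skip session_init (warning unless it is JustCurrentSession); the chief's weights then reach the other workers through the broadcast callback added in trainers.py below. The following is a minimal sketch of that pattern in plain TensorFlow 1.x and Horovod, not part of the commit; the variable `w` is a stand-in for real model weights.

# sync_sketch.py -- chief-only init followed by a broadcast; run with one
# process per GPU, e.g.: mpirun -np 2 python sync_sketch.py
import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

# Stand-in for the model's weights.
w = tf.get_variable('w', shape=[10], initializer=tf.random_normal_initializer())

# The same op the new RunOp callback wraps: overwrite every rank's variables
# with the values held by rank 0.
bcast = hvd.broadcast_global_variables(0)

with tf.Session() as sess:
    # Every worker initializes its own variables (tensorpack's session creator
    # does this); at this point each rank holds different random values.
    sess.run(tf.global_variables_initializer())
    # Only the chief would additionally run session_init here (e.g. restore a
    # checkpoint); on the other ranks it is ignored, as the base.py change warns.
    sess.run(bcast)
    # After the broadcast every rank holds the chief's values.
    print("rank", hvd.rank(), sess.run(w)[:3])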
tensorpack/train/trainers.py
@@ -4,7 +4,7 @@
 import os

-from ..callbacks.graph import RunOp
+from ..callbacks import RunOp
 from ..tfutils.sesscreate import NewSessionCreator
 from ..utils import logger
@@ -29,7 +29,8 @@ __all__ = ['SimpleTrainer',
            'SyncMultiGPUTrainerReplicated',
            'SyncMultiGPUTrainerParameterServer',
            'AsyncMultiGPUTrainer',
-           'DistributedTrainerReplicated']
+           'DistributedTrainerReplicated',
+           'HorovodTrainer']


 def _int_to_range(x):
@@ -206,3 +207,29 @@ class DistributedTrainerReplicated(SingleCostTrainer):
     @property
     def _main_tower_vs_name(self):
         return "tower0"
+
+
+class HorovodTrainer(SingleCostTrainer):
+    def __init__(self):
+        hvd.init()
+        self.is_chief = hvd.rank() == 0
+        logger.info("Horovod local rank: {}".format(hvd.local_rank()))
+        super(HorovodTrainer, self).__init__()
+
+    def _setup_graph(self, input, get_cost_fn, get_opt_fn):
+        with TowerContext('', is_training=True):
+            grads = self._make_get_grad_fn(
+                input, get_cost_fn, get_opt_fn)()
+            opt = get_opt_fn()
+            opt = hvd.DistributedOptimizer(opt)
+            self.train_op = opt.apply_gradients(grads, name='min_op')
+        return [RunOp(
+            hvd.broadcast_global_variables(0),
+            run_before=True, run_as_trigger=False, verbose=True)]
+
+
+from ..utils.develop import create_dummy_class   # noqa
+try:
+    import horovod.tensorflow as hvd
+except ImportError:
+    HorovodTrainer = create_dummy_class('HovorodTrainer', 'horovod')    # noqa
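A hedged usage sketch of the new trainer, not part of the commit. It assumes tensorpack's new-style TrainConfig / launch_train_with_config interface from around this release; MyModel and my_dataflow are hypothetical placeholders for a user's ModelDesc and DataFlow.

# train_sketch.py -- hypothetical training script; the names marked below are
# assumptions for illustration, not part of this commit.
from tensorpack import TrainConfig, launch_train_with_config
from tensorpack.train.trainers import HorovodTrainer

config = TrainConfig(
    model=MyModel(),        # hypothetical ModelDesc defining inputs, cost and optimizer
    dataflow=my_dataflow,   # hypothetical per-worker DataFlow
    max_epoch=100,
)

# HorovodTrainer.__init__ calls hvd.init(); rank 0 becomes the chief.
launch_train_with_config(config, HorovodTrainer())

# Launch one process per GPU with MPI, e.g.:
#   mpirun -np 4 python train_sketch.py

Each MPI process trains one replica: gradients are averaged by hvd.DistributedOptimizer, and the RunOp callback returned from _setup_graph broadcasts the chief's variables to the other workers before training starts.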