Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
4eb2be09
Commit
4eb2be09
authored
Jan 29, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix horovod training; import pyarrow without torch
parent
0086e156
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
34 additions
and
7 deletions
+34
-7
tensorpack/callbacks/saver.py
tensorpack/callbacks/saver.py
+5
-3
tensorpack/train/trainers.py
tensorpack/train/trainers.py
+18
-1
tensorpack/utils/serialize.py
tensorpack/utils/serialize.py
+11
-3
No files found.
tensorpack/callbacks/saver.py
View file @
4eb2be09
...
...
@@ -37,12 +37,14 @@ class ModelSaver(Callback):
self
.
var_collections
=
var_collections
if
checkpoint_dir
is
None
:
checkpoint_dir
=
logger
.
get_logger_dir
()
assert
checkpoint_dir
is
not
None
if
not
tf
.
gfile
.
IsDirectory
(
checkpoint_dir
):
tf
.
gfile
.
MakeDirs
(
checkpoint_dir
)
if
checkpoint_dir
is
not
None
:
if
not
tf
.
gfile
.
IsDirectory
(
checkpoint_dir
):
tf
.
gfile
.
MakeDirs
(
checkpoint_dir
)
self
.
checkpoint_dir
=
checkpoint_dir
def
_setup_graph
(
self
):
assert
self
.
checkpoint_dir
is
not
None
,
\
"ModelSaver() doesn't have a valid checkpoint directory."
vars
=
[]
for
key
in
self
.
var_collections
:
vars
.
extend
(
tf
.
get_collection
(
key
))
...
...
tensorpack/train/trainers.py
View file @
4eb2be09
...
...
@@ -292,11 +292,26 @@ class HorovodTrainer(SingleCostTrainer):
logger
.
info
(
"Horovod local rank={}"
.
format
(
self
.
_local_rank
))
super
(
HorovodTrainer
,
self
)
.
__init__
()
def
allreduce
(
self
,
grads
):
if
hvd
.
size
()
==
1
:
return
grads
# copied from https://github.com/uber/horovod/blob/master/horovod/tensorflow/__init__.py
averaged_gradients
=
[]
with
tf
.
name_scope
(
"HVDAllReduce"
):
for
grad
,
var
in
grads
:
if
grad
is
not
None
:
avg_grad
=
hvd
.
allreduce
(
grad
,
average
=
True
)
averaged_gradients
.
append
((
avg_grad
,
var
))
else
:
averaged_gradients
.
append
((
None
,
var
))
return
averaged_gradients
def
_setup_graph
(
self
,
input
,
get_cost_fn
,
get_opt_fn
):
with
TowerContext
(
''
,
is_training
=
True
):
grads
=
self
.
_make_get_grad_fn
(
input
,
get_cost_fn
,
get_opt_fn
)()
grads
=
self
.
allreduce
(
grads
)
opt
=
get_opt_fn
()
opt
=
hvd
.
DistributedOptimizer
(
opt
)
self
.
train_op
=
opt
.
apply_gradients
(
grads
,
name
=
'min_op'
)
with
tf
.
name_scope
(
'horovod_broadcast'
):
op
=
hvd
.
broadcast_global_variables
(
0
)
...
...
@@ -311,6 +326,8 @@ class HorovodTrainer(SingleCostTrainer):
if
not
isinstance
(
session_creator
,
NewSessionCreator
):
raise
ValueError
(
"session_creator has to be `NewSessionCreator` for horovod training! "
)
# NOTE It will fail if GPU was already detected before initializing the session
# https://github.com/tensorflow/tensorflow/issues/8136
session_creator
.
config
.
gpu_options
.
visible_device_list
=
str
(
self
.
_local_rank
)
super
(
HorovodTrainer
,
self
)
.
initialize
(
session_creator
,
session_init
)
...
...
tensorpack/utils/serialize.py
View file @
4eb2be09
...
...
@@ -8,9 +8,13 @@ import msgpack_numpy
msgpack_numpy
.
patch
()
try
:
import
sys
sys
.
modules
[
'torch'
]
=
None
# https://github.com/apache/arrow/pull/1223#issuecomment-359895666
import
pyarrow
as
pa
del
sys
.
modules
[
'torch'
]
except
ImportError
:
pa
ss
pa
=
None
__all__
=
[
'loads'
,
'dumps'
]
...
...
@@ -51,5 +55,9 @@ def loads_pyarrow(buf):
return
pa
.
deserialize
(
buf
)
loads
=
loads_msgpack
dumps
=
dumps_msgpack
if
pa
is
None
:
loads
=
loads_msgpack
dumps
=
dumps_msgpack
else
:
loads
=
loads_pyarrow
dumps
=
dumps_pyarrow
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment