Commit 46991853
authored Dec 17, 2017 by Yuxin Wu

simplify code for distributed

parent 2d6d7ad4
Showing 3 changed files with 60 additions and 96 deletions (+60 -96):

  tensorpack/graph_builder/utils.py   +21 -29
  tensorpack/train/trainers.py        +39 -66
  tensorpack/trainv1/utility.py        +0  -1

tensorpack/graph_builder/utils.py

@@ -10,7 +10,7 @@ import tensorflow as tf
 
 __all__ = ['LeastLoadedDeviceSetter', 'OverrideCachingDevice',
-           'OverrideToLocalVariable', 'override_to_local_variable',
+           'override_to_local_variable',
            'allreduce_grads', 'average_grads']
@@ -20,35 +20,34 @@ Some utilities for building the graph.
 """
 
 
+def _replace_global_by_local(kwargs):
+    if 'collections' in kwargs:
+        collections = kwargs['collections']
+    if not collections:
+        collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
+    else:
+        collections = set(collections.copy())
+    collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
+    collections.add(tf.GraphKeys.LOCAL_VARIABLES)
+    kwargs['collections'] = list(collections)
+
+
 @contextmanager
 def override_to_local_variable(enable=True):
     if enable:
+        def custom_getter(getter, name, *args, **kwargs):
+            _replace_global_by_local(kwargs)
+            return getter(name, *args, **kwargs)
+
         with tf.variable_scope(
                 tf.get_variable_scope(),
-                custom_getter=OverrideToLocalVariable()):
+                custom_getter=custom_getter):
             yield
     else:
         yield
 
 
-class OverrideToLocalVariable(object):
-    """
-    Ensures the created variable
-    is in LOCAL_VARIABLES and not GLOBAL_VARIBLES collection.
-    """
-    def __call__(self, getter, name, *args, **kwargs):
-        if 'collections' in kwargs:
-            collections = kwargs['collections']
-        if not collections:
-            collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
-        else:
-            collections = set(collections.copy())
-        collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
-        collections.add(tf.GraphKeys.LOCAL_VARIABLES)
-        kwargs['collections'] = list(collections)
-        return getter(name, *args, **kwargs)
-
-
 # https://github.com/tensorflow/benchmarks/blob/48cbef14a592e02a14beee8e9aef3ad22cadaed1/scripts/tf_cnn_benchmarks/variable_mgr_util.py#L192-L218
 class LeastLoadedDeviceSetter(object):
     """ Helper class to assign variables on the least loaded ps-device."""
@@ -170,15 +169,8 @@ class OverrideCachingDevice(object):
     def __call__(self, getter, *args, **kwargs):
         size = tf.TensorShape(kwargs['shape']).num_elements()
         if size is None or not kwargs.get('trainable', True):
-            # TODO
-            collections = kwargs['collections']
-            if not collections:
-                collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
-            else:
-                collections = set(collections.copy())
-            collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
-            collections.add(tf.GraphKeys.LOCAL_VARIABLES)
-            kwargs['collections'] = list(collections)
+            # TODO a lot of vars won't be saved then
+            _replace_global_by_local(kwargs)
             return getter(*args, **kwargs)
 
         if size < self.small_variable_size_threshold:
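
The net effect of this file is that the collection-rewriting logic now lives in a plain function plus an inline custom getter instead of the removed OverrideToLocalVariable class. A minimal usage sketch of the resulting context manager, assuming TF 1.x graph mode (the variable name here is made up for illustration):

    import tensorflow as tf
    from tensorpack.graph_builder.utils import override_to_local_variable

    with override_to_local_variable():
        # The custom getter swaps GLOBAL_VARIABLES for LOCAL_VARIABLES in the
        # `collections` kwarg, so this variable ends up as a local variable only.
        counter = tf.get_variable('counter', shape=[], dtype=tf.int64,
                                  initializer=tf.zeros_initializer())

    assert counter in tf.local_variables()
    assert counter not in tf.global_variables()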

tensorpack/train/trainers.py
@@ -157,46 +157,23 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
         return [cb]
 
 
-class DistributedTrainerParameterServer(SingleCostTrainer):
-
-    __doc__ = DistributedParameterServerBuilder.__doc__
+class DistributedTrainerBase(SingleCostTrainer):
 
     devices = None
     """
     List of GPU ids.
     """
 
-    # TODO use full device name instead of id
-    @map_arg(gpus=_int_to_range)
-    def __init__(self, gpus, server, caching_device='cpu'):
-        """
-        Args:
-            gpus ([int]): list of GPU ids.
-        """
+    def __init__(self, gpus, server):
+        super(DistributedTrainerBase, self).__init__()
         self.devices = gpus
         self.server = server
         self.job_name = server.server_def.job_name
-        assert self.job_name in ['ps', 'worker'], self.job_name
-
-        if self.job_name == 'worker':
-            # ps doesn't build any graph
-            self._builder = DistributedParameterServerBuilder(gpus, server, caching_device)
-            self.is_chief = self._builder.is_chief
-        else:
-            self.is_chief = False
         logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))
 
-        super(DistributedTrainerParameterServer, self).__init__()
-
-        if self.job_name == 'ps':
-            # ps shouldn't setup input either
-            logger.info("Running ps {}".format(self.server.server_def.task_index))
-            logger.info("Kill me with 'kill {}'".format(os.getpid()))
-            self.server.join()  # this function will never return tensorflow#4713
-            raise RuntimeError("This is a bug. Server.join() for ps should never return!")
-
-    def _setup_graph(self, input, get_cost_fn, get_opt_fn):
-        self.train_op = self._builder.build(
-            self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
-        return []
+    def join(self):
+        logger.info("Calling server.join() on {}:{}".format(self.job_name, self.server.server_def.task_index))
+        logger.info("Kill me with 'kill {}'".format(os.getpid()))
+        self.server.join()  # this function will never return tensorflow#4713
+        raise RuntimeError("This is a bug. Server.join() for should never return!")
 
     @HIDE_DOC
     def initialize(self, session_creator, session_init):
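
The new DistributedTrainerBase only reads job_name, task_index and cluster from server.server_def, so any tf.train.Server built from a ClusterSpec will do. A hedged sketch of such a server (hosts, ports and task indices are made up for illustration):

    import tensorflow as tf

    cluster = tf.train.ClusterSpec({
        'ps':     ['host0:2222'],
        'worker': ['host1:2222', 'host2:2222'],
    })
    # Each process passes its own job_name / task_index.
    server = tf.train.Server(cluster, job_name='worker', task_index=0)

    # Fields consumed by DistributedTrainerBase above:
    assert server.server_def.job_name == 'worker'
    assert server.server_def.task_index == 0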
@@ -205,19 +182,38 @@ class DistributedTrainerParameterServer(SingleCostTrainer):
             raise ValueError(
                 "You are not allowed to set session_creator or session_config for distributed training! "
                 "To use a custom session config, pass it to tf.train.Server.")
-        super(DistributedTrainerParameterServer, self).initialize(
+        super(DistributedTrainerBase, self).initialize(
             get_distributed_session_creator(self.server), session_init)
 
 
+class DistributedTrainerParameterServer(DistributedTrainerBase):
+
+    __doc__ = DistributedParameterServerBuilder.__doc__
+
+    @map_arg(gpus=_int_to_range)
+    def __init__(self, gpus, server, caching_device='cpu'):
+        """
+        Args:
+            gpus ([int]): list of GPU ids.
+        """
+        super(DistributedTrainerParameterServer, self).__init__(gpus, server)
+        assert self.job_name in ['ps', 'worker'], self.job_name
+        if self.job_name == 'ps':
+            self.join()
+
+        self._builder = DistributedParameterServerBuilder(gpus, server, caching_device)
+        self.is_chief = self._builder.is_chief
+
+    def _setup_graph(self, input, get_cost_fn, get_opt_fn):
+        self.train_op = self._builder.build(
+            self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
+        return []
+
+
 class DistributedTrainerReplicated(SingleCostTrainer):
 
     __doc__ = DistributedReplicatedBuilder.__doc__
 
-    devices = None
-    """
-    List of GPU ids.
-    """
-
     @map_arg(gpus=_int_to_range)
     def __init__(self, gpus, server):
         """
@@ -225,26 +221,13 @@ class DistributedTrainerReplicated(SingleCostTrainer):
             gpus (list[int]): list of GPU ids.
             server (tf.train.Server): the server with ps and workers.
         """
-        self.devices = gpus
-        self.server = server
-        self.job_name = server.server_def.job_name
+        super(DistributedTrainerReplicated, self).__init__(gpus, server)
         assert self.job_name in ['ps', 'worker'], self.job_name
 
-        if self.job_name == 'worker':
-            # ps doesn't build any graph
-            self._builder = DistributedReplicatedBuilder(gpus, server)
-            self.is_chief = self._builder.is_chief
-        else:
-            self.is_chief = False
-        logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))
-
-        super(DistributedTrainerReplicated, self).__init__()
-
         if self.job_name == 'ps':
-            # ps shouldn't setup input either
-            logger.info("Running ps {}".format(self.server.server_def.task_index))
-            logger.info("Kill me with 'kill {}'".format(os.getpid()))
-            self.server.join()  # this function will never return tensorflow#4713
-            raise RuntimeError("This is a bug. Server.join() for ps should never return!")
+            self.join()
+
+        self._builder = DistributedReplicatedBuilder(gpus, server)
+        self.is_chief = self._builder.is_chief
 
     def _setup_input(self, inputs_desc, input):
         with override_to_local_variable():
@@ -276,16 +259,6 @@ class DistributedTrainerReplicated(SingleCostTrainer):
             callbacks.append(cb)
         return callbacks
 
-    @HIDE_DOC
-    def initialize(self, session_creator, session_init):
-        if not isinstance(session_creator, NewSessionCreator) or \
-                session_creator.user_provided_config:
-            raise ValueError(
-                "You are not allowed to set session_creator or session_config for distributed training! "
-                "To use a custom session config, pass it to tf.train.Server.")
-        super(DistributedTrainerReplicated, self).initialize(
-            get_distributed_session_creator(self.server), session_init)
-
     @property
     def _main_tower_vs_name(self):
         return "tower0"
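
Putting the pieces together: a worker process hands the server from the sketch above to the trainer, while a ps process never gets past the constructor, because __init__ now routes it through self.join(). A hedged end-to-end sketch (the GPU list is arbitrary; the training config and launch call are the usual tensorpack ones and are not part of this diff):

    from tensorpack.train import DistributedTrainerReplicated

    # On a 'ps' process this constructor blocks in server.join() and never
    # returns; only 'worker' processes continue to build the graph and train.
    trainer = DistributedTrainerReplicated(gpus=[0, 1], server=server)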

tensorpack/trainv1/utility.py

@@ -4,5 +4,4 @@
 # for backwards-compatibility
 from ..graph_builder.utils import (  # noqa
-    OverrideToLocalVariable,
     override_to_local_variable,
     LeastLoadedDeviceSetter)
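
Since OverrideToLocalVariable is dropped from this backwards-compatibility shim, downstream code that imported it has to switch to the context manager. A hedged before/after sketch (the variable is only for illustration):

    # before (no longer available):
    #   from tensorpack.trainv1.utility import OverrideToLocalVariable
    #   with tf.variable_scope(tf.get_variable_scope(),
    #                          custom_getter=OverrideToLocalVariable()):
    #       v = tf.get_variable('v', shape=[1])

    # after:
    from tensorpack.trainv1.utility import override_to_local_variable
    with override_to_local_variable():
        v = tf.get_variable('v', shape=[1])  # created as a local variable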