seminar-breakout (Shashank Suhas) / Commits

Commit ee1af311, authored Jun 01, 2017 by Yuxin Wu
comments & fix lint

Parent: f1e3b3ae
Showing 1 changed file with 13 additions and 26 deletions.

tensorpack/train/distributed.py (+13, -26)
@@ -6,7 +6,6 @@ import tensorflow as tf
from six.moves import range
from ..utils import logger
-from .input_source import StagingInputWrapper
from .feedfree import SingleCostFeedfreeTrainer
from .multigpu import MultiGPUTrainerBase
from ..callbacks import RunOp
@@ -19,18 +18,14 @@ __all__ = ['DistributedReplicatedTrainer']
PS_SHADOW_VAR_PREFIX = 'ps_var'

-# To be used with custom_getter on tf.get_variable. Ensures the created variable
-# is in LOCAL_VARIABLES and not GLOBAL_VARIBLES collection.
class OverrideToLocalVariableIfNotPsVar(object):
-    # args and kwargs come from the custom_getter interface for Tensorflow
-    # variables, and matches tf.get_variable's signature, with the addition of
-    # 'getter' at the beginning.
+    """
+    Ensures the created variable
+    is in LOCAL_VARIABLES and not GLOBAL_VARIBLES collection.
+    """
    def __call__(self, getter, name, *args, **kwargs):
        if name.startswith(PS_SHADOW_VAR_PREFIX):
            return getter(*args, **kwargs)
        logger.info("CustomGetter-{}".format(name))
        if 'collections' in kwargs:
            collections = kwargs['collections']
        if not collections:
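For context, a custom getter like the one above is normally installed on a variable scope. The following is a minimal TF 1.x sketch of the mechanism, not the commit's code: the standalone to_local_getter function, the scope name and the variable shape are illustrative assumptions.

import tensorflow as tf

def to_local_getter(getter, name, *args, **kwargs):
    # Anything that is not a parameter-server shadow variable goes into
    # the LOCAL_VARIABLES collection instead of the default GLOBAL_VARIABLES.
    if not name.startswith('ps_var'):
        kwargs['collections'] = [tf.GraphKeys.LOCAL_VARIABLES]
    return getter(name, *args, **kwargs)

with tf.variable_scope('tower0', custom_getter=to_local_getter):
    w = tf.get_variable('W', shape=[3, 3], initializer=tf.zeros_initializer())

print([v.name for v in tf.local_variables()])    # ['tower0/W:0']
print([v.name for v in tf.global_variables()])   # []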
@@ -50,7 +45,8 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
        self.cluster = tf.train.ClusterSpec(server_def.cluster)
        self.job_name = server_def.job_name
        self.task_index = server_def.task_index
-        assert self.job_name in ['ps', 'worker'], job_name
+        assert self.job_name in ['ps', 'worker'], self.job_name
+        assert tf.test.is_gpu_available
        self._input_source = config.data
        self.is_chief = (self.task_index == 0 and self.job_name == 'worker')
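The constructor above derives everything from a server_def. A minimal sketch of how such a server might be built by the caller (TF 1.x API; the localhost addresses and cluster layout are hypothetical):

import tensorflow as tf

# Hypothetical one-PS, two-worker cluster; addresses are placeholders.
cluster = tf.train.ClusterSpec({
    'ps': ['localhost:2222'],
    'worker': ['localhost:2223', 'localhost:2224'],
})
server = tf.train.Server(cluster, job_name='worker', task_index=0)
server_def = server.server_def

print('%s %d' % (server_def.job_name, server_def.task_index))   # worker 0
print(tf.train.ClusterSpec(server_def.cluster).jobs)            # e.g. ['ps', 'worker']

# The trainer above treats worker 0 as the chief:
is_chief = (server_def.task_index == 0 and server_def.job_name == 'worker')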
@@ -71,14 +67,6 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
        self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(self.num_ps)]
        self.sync_queue_counter = 0
-        if self.nr_gpu > 1:
-            assert tf.test.is_gpu_available()
-            # TODO staging doesn't work with dummy (require context)
-            # seem to only improve on >1 GPUs
-            #if not isinstance(self._input_source, StagingInputWrapper):
-            #self._input_source = StagingInputWrapper(self._input_source, self.raw_devices)

    @staticmethod
    def _average_grads(tower_grads, devices):
        """
@@ -134,7 +122,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
            list of copy ops
        """
        # TODO do this for each variable separately?
-        opt = self.model.get_optimizer()    # TODO ensure it in global scope, not local
+        opt = self.model.get_optimizer()
        var_update_ops = []
        for vid, (g, v) in enumerate(ps_var_grads):
            apply_gradient_op = opt.apply_gradients([(g, v)])
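The loop applies gradients one variable at a time instead of one opt.apply_gradients over the whole list, so that later ops (such as copying the updated value back to the towers) can depend on exactly one variable's update. A self-contained sketch of that per-variable pattern, using hypothetical toy variables rather than the trainer's graph:

import tensorflow as tf

x = tf.get_variable('x', shape=[], initializer=tf.ones_initializer())
y = tf.get_variable('y', shape=[], initializer=tf.ones_initializer())
loss = tf.square(x) + 2.0 * tf.square(y)
opt = tf.train.GradientDescentOptimizer(0.1)

var_update_ops = []
for g, v in opt.compute_gradients(loss):
    # One apply op per variable: anything that must run after *this*
    # variable is updated can depend on apply_op alone.
    apply_op = opt.apply_gradients([(g, v)])
    with tf.control_dependencies([apply_op]):
        var_update_ops.append(tf.identity(v))

main_fetch = tf.group(*var_update_ops, name='main_fetches')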
@@ -175,7 +163,8 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
        var_update_ops = self._apply_gradients_and_copy(grad_list, ps_var_grads)
        main_fetch = tf.group(*var_update_ops, name='main_fetches')
-        self.train_op = self.add_sync_queues_and_barrier('sync_queues_step_end', [main_fetch])
+        self.train_op = self.add_sync_queues_and_barrier('post_copy_barrier', [main_fetch])
        self.register_callback(RunOp(self.get_post_init_ops, run_before=True, run_as_trigger=False))
@@ -189,6 +178,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
                "Cannot set session_creator or session_config for distributed training! "
                "To use a custom session config, pass it to the tf.train.Server constructor.")
+        # TODO use scaffold
        class SupervisedSessionCreator(tf.train.SessionCreator):
            def __init__(self, is_chief, target):
                self.is_chief = is_chief
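tf.train.SessionCreator is the hook that tf.train.MonitoredSession uses to build its underlying session. The sketch below illustrates a chief/worker-aware creator with the same __init__ signature as the class in this hunk; it is an illustration of the interface under that assumption, not the trainer's full implementation.

import tensorflow as tf

class SupervisedSessionCreator(tf.train.SessionCreator):
    def __init__(self, is_chief, target):
        self.is_chief = is_chief
        self.target = target        # the tf.train.Server's target string

    def create_session(self):
        # The chief initializes variables; workers wait until they are ready.
        if self.is_chief:
            return tf.train.ChiefSessionCreator(master=self.target).create_session()
        return tf.train.WorkerSessionCreator(master=self.target).create_session()

# Hypothetical usage:
# sess = tf.train.MonitoredSession(
#     session_creator=SupervisedSessionCreator(is_chief=True, target=server.target))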
@@ -224,14 +214,11 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
                         shared_name='%s%s' % (name_prefix, i))
            for i in range(self.num_worker)]
        queue_ops = []
-        # For each other worker, add an entry in a queue, signaling that it can
-        # finish this step.
+        # For each other worker, add an entry in a queue, signaling that it can finish this step.
        token = tf.constant(False)
        with tf.control_dependencies(enqueue_after_list):
            for i, q in enumerate(sync_queues):
-                if i == self.task_index:
-                    queue_ops.append(tf.no_op())
-                else:
+                if i != self.task_index:
                    queue_ops.append(q.enqueue(token))
        # Drain tokens off queue for this worker, one for each other worker.
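The barrier built here gives every worker its own shared FIFO queue: after the dependencies in enqueue_after_list have run, each worker pushes one token into every other worker's queue and then drains (num_worker - 1) tokens from its own, which blocks until all peers have arrived. A standalone sketch of the pattern (TF 1.x; the cluster size, queue prefix and placement are illustrative assumptions):

import tensorflow as tf

num_worker, task_index = 3, 0          # hypothetical cluster size, this worker's id
name_prefix = 'sync_queue'

# One queue per worker; shared_name makes all workers address the same queues
# (in the real trainer these live on the parameter-server devices).
sync_queues = [
    tf.FIFOQueue(num_worker, [tf.bool], shapes=[[]],
                 shared_name='%s%s' % (name_prefix, i))
    for i in range(num_worker)]

token = tf.constant(False)
queue_ops = []
for i, q in enumerate(sync_queues):
    if i != task_index:
        queue_ops.append(q.enqueue(token))       # signal every other worker

# Drain one token per other worker from our own queue; this blocks until
# every peer has enqueued, i.e. reached the barrier.
queue_ops.append(sync_queues[task_index].dequeue_many(num_worker - 1))
barrier = tf.group(*queue_ops, name='sync_queues_barrier')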
@@ -256,7 +243,7 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
                v.name[len(PS_SHADOW_VAR_PREFIX + '/'):])
            for i in range(self.nr_gpu):
                if i == 0:
-                    name = prefix
+                    name = prefix   # no prefix for tower0
                else:
                    name = 'tower%s/%s' % (i, prefix)
                if name in local_var_by_name:
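This loop matches each PS shadow variable against the towers' local copies by name: strip the 'ps_var/' prefix, keep the bare name for tower 0, and prepend 'tower%d/' for every other tower. A tiny plain-Python sketch of that mapping, with a hypothetical variable name:

PS_SHADOW_VAR_PREFIX = 'ps_var'

def tower_copy_names(shadow_name, nr_gpu):
    prefix = shadow_name[len(PS_SHADOW_VAR_PREFIX + '/'):]
    names = []
    for i in range(nr_gpu):
        if i == 0:
            names.append(prefix)                     # no prefix for tower0
        else:
            names.append('tower%s/%s' % (i, prefix))
    return names

print(tower_copy_names('ps_var/conv0/W:0', 3))
# ['conv0/W:0', 'tower1/conv0/W:0', 'tower2/conv0/W:0']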