Shashank Suhas / seminar-breakout / Commits / 9fd5cb9f

Commit 9fd5cb9f authored Jun 01, 2017 by Yuxin Wu

fix batchrenorm, simplify dist-trainer code

Parent: 86d1b2e5

Showing 1 changed file with 49 additions and 35 deletions:
tensorpack/train/distributed.py  (+49 / -35)
@@ -15,7 +15,7 @@ from .multigpu import MultiGPUTrainerBase
 from ..tfutils.model_utils import describe_model
 from ..callbacks import Callbacks, ProgressBar
 from ..tfutils.sesscreate import ReuseSessionCreator
-from ..tfutils.common import get_default_sess_config, get_global_step_var
+from ..tfutils.common import get_default_sess_config, get_global_step_var, get_op_tensor_name
 from ..callbacks.monitor import Monitors

 __all__ = ['DistributedReplicatedTrainer']
@@ -50,7 +50,6 @@ class OverrideToLocalVariableIfNotPsVar(object):
 class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
     def __init__(self, config, job_name, task_index, cluster):
         assert job_name in ['ps', 'worker'], job_name
         self.config = config
         self.job_name = job_name
         self.task_index = task_index
         self.cluster = cluster
@@ -61,14 +60,16 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
         worker_prefix = '/job:worker/task:%s' % self.task_index
         self.param_server_device = tf.train.replica_device_setter(
             worker_device=worker_prefix + '/cpu:0', cluster=self.cluster)
-        # This device on which the queues for managing synchronization between
-        # servers should be stored.
-        num_ps = self.cluster.num_tasks('ps')
+        self.num_ps = self.cluster.num_tasks('ps')
+        self.num_worker = self.cluster.num_tasks('worker')
-        self.cpu_device = '%s/cpu:0' % worker_prefix
         self.nr_gpu = config.nr_tower
+        self.cpu_device = '%s/cpu:0' % worker_prefix
         self.raw_devices = ['%s/%s:%i' % (worker_prefix, 'gpu', i) for i in range(self.nr_gpu)]
-        self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(num_ps)]
+        # This device on which the queues for managing synchronization between
+        # servers should be stored.
+        self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(self.num_ps)]
         self.sync_queue_counter = 0

         if self.nr_gpu > 1:
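For illustration only: the constructor above just builds TensorFlow device strings from the cluster spec. A minimal standalone sketch (plain Python; the task index, PS count, and GPU count below are assumed example values standing in for self.cluster.num_tasks(...) and config.nr_tower) of what those strings look like:

# Assumed example values; in the trainer they come from the cluster spec and config.
task_index = 0   # this worker's task index
num_ps = 2       # self.cluster.num_tasks('ps')
nr_gpu = 2       # config.nr_tower

worker_prefix = '/job:worker/task:%s' % task_index
cpu_device = '%s/cpu:0' % worker_prefix
raw_devices = ['%s/%s:%i' % (worker_prefix, 'gpu', i) for i in range(nr_gpu)]
sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(num_ps)]

print(cpu_device)          # /job:worker/task:0/cpu:0
print(raw_devices)         # ['/job:worker/task:0/gpu:0', '/job:worker/task:0/gpu:1']
print(sync_queue_devices)  # ['/job:ps/task:0/cpu:0', '/job:ps/task:1/cpu:0']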
@@ -78,6 +79,31 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
             if not isinstance(self._input_source, StagingInputWrapper):
                 self._input_source = StagingInputWrapper(self._input_source, self.raw_devices)

+    @staticmethod
+    def _average_grads(tower_grads, devices):
+        """
+        Average grad with round-robin device selection.
+        Args:
+            tower_grads: Ngpu x Nvar x 2
+        """
+        nr_device = len(devices)
+        if nr_device == 1:
+            return tower_grads[0]
+        new_tower_grads = []
+        with tf.name_scope('AvgGrad'):
+            for i, grad_and_vars in enumerate(zip(*grad_list)):
+                # Ngpu * 2
+                with tf.device(devices[i % nr_device]):
+                    v = grad_and_vars[0][1]
+                    # average gradient
+                    all_grads = [g for (g, _) in grad_and_vars]
+                    if not MultiGPUTrainerBase.check_none_grads(v.op.name, all_grads):
+                        continue
+                    grad = tf.multiply(tf.add_n(all_grads), 1.0 / nr_device)
+                    new_tower_grads.append((grad, v))
+        return new_tower_grads
+
     def _setup(self):
         conf = get_default_sess_config()
         self.server = tf.train.Server(
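The new _average_grads helper averages each variable's gradients across GPUs and rotates the device on which each average is computed. As an aside, the arithmetic can be illustrated without a TF graph; the following is a minimal NumPy sketch of the same round-robin scheme, with made-up data and variable names, and without the device placement or the None-gradient check:

import numpy as np

def average_grads(tower_grads, devices):
    # tower_grads: Ngpu x Nvar list of (grad, var_name) pairs
    nr_device = len(devices)
    if nr_device == 1:
        return tower_grads[0]
    averaged = []
    # zip(*tower_grads) regroups the gradients of the same variable across GPUs
    for i, grad_and_vars in enumerate(zip(*tower_grads)):
        # the real helper would run this average under tf.device(devices[i % nr_device])
        var_name = grad_and_vars[0][1]
        all_grads = [g for (g, _) in grad_and_vars]
        averaged.append((np.add.reduce(all_grads) / nr_device, var_name))
    return averaged

# Two GPUs, two variables 'W' and 'b':
tower_grads = [
    [(np.array([2., 4.]), 'W'), (np.array([1.]), 'b')],   # gradients from gpu 0
    [(np.array([4., 6.]), 'W'), (np.array([3.]), 'b')],   # gradients from gpu 1
]
print(average_grads(tower_grads, ['gpu0', 'gpu1']))
# [(array([3., 5.]), 'W'), (array([2.]), 'b')]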
@@ -101,35 +127,23 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
             devices=self.raw_devices,
             var_strategy='replicated')

-        # (g, v) to be applied, where v is global (ps vars)
-        new_tower_grads = []
-        for i, grad_and_vars in enumerate(zip(*grad_list)):
-            # Ngpu * 2
-            with tf.device(self.raw_devices[i % self.nr_gpu]):
-                v = grad_and_vars[0][1]
-                if self.nr_gpu > 1:
-                    # average gradient
-                    all_grads = [g for (g, _) in grad_and_vars]
-                    if not MultiGPUTrainerBase.check_none_grads(v.op.name, all_grads):
-                        continue
-                    grad = tf.multiply(tf.add_n(all_grads), 1.0 / self.nr_gpu)
-                else:
-                    grad = grad_and_vars[0][0]
+        avg_grads = DistributedReplicatedTrainer._average_grads(grad_list, self.raw_devices)
+        # Nvar * 2
+        ps_var_grads = []
+        for i, (grad, var) in enumerate(avg_grads):
             with tf.device(self.param_server_device):
-                my_name = PS_SHADOW_VAR_PREFIX + '/' + v.name
-                if my_name.endswith(':0'):
-                    my_name = my_name[:-2]
-                new_v = tf.get_variable(my_name, dtype=v.dtype.base_dtype,
-                                        initializer=v.initial_value,
+                my_name = PS_SHADOW_VAR_PREFIX + '/' + var.name
+                my_name = get_op_tensor_name(my_name)[0]
+                new_v = tf.get_variable(my_name, dtype=var.dtype.base_dtype,
+                                        initializer=var.initial_value,
                                         trainable=True)
-                new_tower_grads.append((grad, new_v))
+                # (g, v) to be applied, where v is global (ps vars)
+                ps_var_grads.append((grad, new_v))

         # apply gradients TODO do this for each variable separately?
         var_update_ops = []
         with tf.device(self.param_server_device):
-            for vid, (g, v) in enumerate(new_tower_grads):
+            for vid, (g, v) in enumerate(ps_var_grads):
                 apply_gradient_op = opt.apply_gradients([(g, v)])
                 barrier = self.add_sync_queues_and_barrier(
                     'param_update_barrier_{}'.format(vid), [apply_gradient_op])
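The shadow-variable naming above now goes through get_op_tensor_name (imported from ..tfutils.common) instead of manually stripping a trailing ':0'. Purely as illustration of the behaviour relied on here, and not tensorpack's actual implementation, a rough sketch of a helper that returns an (op name, tensor name) pair with the output-index suffix removed from the op name:

def get_op_tensor_name_sketch(name):
    # Illustration only: handles the common single-digit ':0'-style suffix
    # that the old code stripped by hand.
    if len(name) >= 2 and name[-2] == ':' and name[-1].isdigit():
        return name[:-2], name
    return name, name + ':0'

# The shadow-variable code above only uses the op name (index [0]):
print(get_op_tensor_name_sketch('tower0/conv1/W:0')[0])  # tower0/conv1/W
print(get_op_tensor_name_sketch('global_step')[0])       # global_step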
@@ -141,8 +155,8 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
                     var_update_ops.append(grad_list[towerid][vid][1].assign(updated_value))

         self.main_fetch = tf.group(*var_update_ops, name='main_fetches')
-        self.train_op = self.main_fetch
-        # self.train_op = self.add_sync_queues_and_barrier('sync_queues_step_end', [self.main_fetch])
+        # self.train_op = self.main_fetch
+        self.train_op = self.add_sync_queues_and_barrier('sync_queues_step_end', [self.main_fetch])
         self.post_init_op = self.get_post_init_ops()

     def setup(self):
@@ -199,12 +213,12 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
            an op that should be used as control dependency before starting next step.
        """
        self.sync_queue_counter += 1
-        num_workers = self.cluster.num_tasks('worker')
+        self.num_worker = self.cluster.num_tasks('worker')
        with tf.device(self.sync_queue_devices[self.sync_queue_counter % len(self.sync_queue_devices)]):
            sync_queues = [
-                tf.FIFOQueue(num_workers, [tf.bool], shapes=[[]],
+                tf.FIFOQueue(self.num_worker, [tf.bool], shapes=[[]],
                             shared_name='%s%s' % (name_prefix, i))
-                for i in range(num_workers)]
+                for i in range(self.num_worker)]
            queue_ops = []
            # For each other worker, add an entry in a queue, signaling that it can
            # finish this step.
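The FIFOQueues created here (one queue per worker, placed on a parameter-server device chosen round-robin per barrier) implement a barrier across workers: roughly, each worker enqueues a token into every other worker's queue and then dequeues one token per peer from its own queue, so no worker proceeds until all have arrived. A single-process analogy using Python threads and queue.Queue, illustrative only and not the TensorFlow implementation:

import threading
import queue

NUM_WORKER = 3
sync_queues = [queue.Queue() for _ in range(NUM_WORKER)]  # one queue per worker

def barrier(worker_id):
    # Tell every other worker that this worker has reached the barrier ...
    for i, q in enumerate(sync_queues):
        if i != worker_id:
            q.put(True)
    # ... then wait for a token from each of the other workers.
    for _ in range(NUM_WORKER - 1):
        sync_queues[worker_id].get()
    print('worker %d passed the barrier' % worker_id)

threads = [threading.Thread(target=barrier, args=(w,)) for w in range(NUM_WORKER)]
for t in threads:
    t.start()
for t in threads:
    t.join()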