Shashank Suhas / seminar-breakout / Commits

Commit 2cba9434, authored Aug 10, 2017 by Yuxin Wu

    fix distributed trainer

Parent: 7780c64b
Changes: 3 changed files, with 11 additions and 10 deletions (+11 -10)

tensorpack/train/distributed.py   +8 -7
tensorpack/train/feedfree.py      +1 -1
tensorpack/train/multigpu.py      +2 -2
tensorpack/train/distributed.py

@@ -79,7 +79,7 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
         self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(self.num_ps)]
         self.sync_queue_counter = 0
 
-        super(DistributedReplicatedTrainer, self).__init__(config)
+        super(DistributedTrainerReplicated, self).__init__(config)
 
     @staticmethod
     def _average_grads(tower_grads, devices):
@@ -187,8 +187,9 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
         with tf.device(self.param_server_device):
             gs = get_global_step_var()
             assert gs.device, gs.device
-        # do this before super.setup because input_source my need global step
-        super(DistributedReplicatedTrainer, self)._setup()
+        # do this before inputsource.setup because input_source my need global step
+        cbs = self._input_source.setup(self.model.get_inputs_desc())
+        self.config.callbacks.extend(cbs)
 
         with tf.variable_scope(
                 tf.get_variable_scope(),
@@ -199,15 +200,15 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
                 lambda: MultiGPUTrainerBase._build_graph_get_grads(
                     self.model, self._input_source),
                 devices=self.raw_devices,
-                vs_names=[True] * self.config.nr_tower)  # open vs at each tower
+                use_vs=[True] * self.config.nr_tower)  # open vs at each tower
         MultiGPUTrainerBase._check_grad_list(grad_list)
 
-        avg_grads = DistributedReplicatedTrainer._average_grads(grad_list, self.raw_devices)
+        avg_grads = DistributedTrainerReplicated._average_grads(grad_list, self.raw_devices)
         with tf.device(self.param_server_device):
-            ps_var_grads = DistributedReplicatedTrainer._apply_shadow_vars(avg_grads)
+            ps_var_grads = DistributedTrainerReplicated._apply_shadow_vars(avg_grads)
             var_update_ops = self._apply_gradients_and_copy(grad_list, ps_var_grads)
             self._shadow_vars = [v for (_, v) in ps_var_grads]
-            self._shadow_model_vars = DistributedReplicatedTrainer._shadow_model_variables(self._shadow_vars)
+            self._shadow_model_vars = DistributedTrainerReplicated._shadow_model_variables(self._shadow_vars)
 
             # TODO add options to synchronize less
             main_fetch = tf.group(*var_update_ops, name='main_fetches')
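The substance of this file's change is twofold: every internal reference to the old spelling DistributedReplicatedTrainer is corrected to the class's actual name, DistributedTrainerReplicated (the super() call in __init__ plus the references to the static methods _average_grads, _apply_shadow_vars and _shadow_model_variables), and _setup now sets up the input source directly instead of delegating to super()._setup(). A minimal, self-contained toy of the first failure mode, assuming nothing named DistributedReplicatedTrainer is left in the module (a stand-in, not tensorpack source):

# Toy reproduction of the renamed-reference bug: a super() call that still uses the
# old class name raises NameError the first time the method runs, assuming no alias
# with the old name remains in scope.
class MultiGPUTrainerBase(object):            # stand-in for the real base class
    def __init__(self, config):
        self.config = config

class DistributedTrainerReplicated(MultiGPUTrainerBase):
    def __init__(self, config):
        # Broken (removed) form:
        #   super(DistributedReplicatedTrainer, self).__init__(config)   # NameError: old class name
        # Fixed (added) form:
        super(DistributedTrainerReplicated, self).__init__(config)

t = DistributedTrainerReplicated(config={'nr_tower': 2})
print(t.config)   # {'nr_tower': 2}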
tensorpack/train/feedfree.py

@@ -28,7 +28,7 @@ class FeedfreeTrainerBase(Trainer):
     def _setup(self):
         assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
-        cbs = self._setup_input_source.setup(self.model.get_inputs_desc())
+        cbs = self._input_source.setup(self.model.get_inputs_desc())
         self.config.callbacks.extend(cbs)
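The removed line called setup() on self._setup_input_source, an attribute name that does not match the self._input_source checked by the assert directly above it; the fix calls setup() on self._input_source itself. Below is a minimal stub sketch of the pattern the fixed line follows (stand-in classes, not tensorpack source): setup() receives the model's input descriptions and returns a list of callbacks, which the trainer appends to its config.

# Stub sketch of the input-source setup pattern; every name here is a stand-in.
class StubInputSource(object):
    def setup(self, inputs_desc):
        print('building input pipeline for %s' % inputs_desc)
        return ['prefetch-callback']          # stand-in for the returned callbacks

class StubModel(object):
    def get_inputs_desc(self):
        return ['image', 'label']

class StubConfig(object):
    def __init__(self):
        self.callbacks = []

class StubTrainer(object):
    def __init__(self):
        self._input_source = StubInputSource()
        self.model = StubModel()
        self.config = StubConfig()

    def _setup(self):
        # mirrors the added line in the hunk above
        cbs = self._input_source.setup(self.model.get_inputs_desc())
        self.config.callbacks.extend(cbs)

StubTrainer()._setup()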
tensorpack/train/multigpu.py

@@ -16,7 +16,7 @@ from ..tfutils.gradproc import ScaleGradient
 from ..callbacks.graph import RunOp
 from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyConstantInput
-from .feedfree import FeedfreeTrainerBase
+from .base import Trainer
 
 __all__ = ['MultiGPUTrainerBase', 'LeastLoadedDeviceSetter', 'SyncMultiGPUTrainerReplicated',
@@ -45,7 +45,7 @@ def apply_prefetch_policy(config, gpu_prefetch=True):
         config.data = StagingInputWrapper(config.data, devices)
 
 
-class MultiGPUTrainerBase(FeedfreeTrainerBase):
+class MultiGPUTrainerBase(Trainer):
     """ Base class for multi-gpu training"""
     @staticmethod
     def build_on_multi_tower(
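With MultiGPUTrainerBase now deriving from Trainer rather than FeedfreeTrainerBase, a super()._setup() call made by DistributedTrainerReplicated no longer reaches FeedfreeTrainerBase._setup, which is consistent with the distributed.py hunk above replacing that call with an explicit self._input_source.setup(...). A self-contained sketch of that method-resolution effect, inferred from the hunks and using stub bodies rather than tensorpack internals:

# Stub hierarchy mirroring the class declarations touched by this commit; the bodies
# are illustrative only and are not the real implementations.
class Trainer(object):
    def _setup(self):
        print('Trainer._setup (assumed: no input-source handling here)')

class FeedfreeTrainerBase(Trainer):
    def _setup(self):
        print('FeedfreeTrainerBase._setup (sets up the input source)')

class MultiGPUTrainerBase(Trainer):           # was FeedfreeTrainerBase before this commit
    pass

class DistributedTrainerReplicated(MultiGPUTrainerBase):
    def _setup(self):
        super(DistributedTrainerReplicated, self)._setup()   # now resolves to Trainer._setup

DistributedTrainerReplicated()._setup()
# prints: Trainer._setup (assumed: no input-source handling here)

This is an inference from the diff alone; the real Trainer._setup may do more than the stub suggests, but the distributed trainer now handles its input-source setup itself either way.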