Shashank Suhas / seminar-breakout · Commits

Commit 621a1bbd
Authored Sep 27, 2017 by Yuxin Wu
Parent: 7632ca9f

Fix wrong variable collection in distributed training (#431)

Showing 2 changed files with 13 additions and 1 deletion:

    tensorpack/tfutils/summary.py    +2  -0
    tensorpack/train/distributed.py  +11 -1

tensorpack/tfutils/summary.py
@@ -33,6 +33,8 @@ def _get_cached_vs(name):
@contextmanager
def _enter_vs_reuse_ns(name):
    vs = _get_cached_vs(name)
    # XXX Not good to enter the cached vs directly, because this will clean-up custom getter
    # with tf.variable_scope(name, reuse=tf.AUTO_REUSE):   # available in 1.4 only
    with tf.variable_scope(vs):
        with tf.name_scope(vs.original_name_scope):
            yield vs
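For context on the summary.py change: _enter_vs_reuse_ns re-enters a cached VariableScope object together with its original_name_scope rather than reopening the scope by name, because in TF 1.x reopening a scope by string opens a fresh name scope each time and therefore changes op and summary names. A minimal TF 1.x sketch of the difference; the scope names ns and cached are illustrative, not from the commit:

import tensorflow as tf

# Re-entering a variable scope by its string name opens a fresh name scope
# every time, so ops created on the second entry get a different prefix.
with tf.variable_scope("ns"):
    a = tf.add(1, 1, name="a")    # op name: "ns/a"
with tf.variable_scope("ns"):
    b = tf.add(1, 1, name="a")    # op name: "ns_1/a"

# Caching the VariableScope object and re-entering it together with its
# original_name_scope (the pattern used by _enter_vs_reuse_ns) keeps the
# prefix stable across entries.
with tf.variable_scope("cached") as vs:
    pass
with tf.variable_scope(vs), tf.name_scope(vs.original_name_scope):
    c = tf.add(1, 1, name="a")    # op name: "cached/a"
with tf.variable_scope(vs), tf.name_scope(vs.original_name_scope):
    d = tf.add(1, 1, name="a")    # op name: "cached/a_1", still under "cached/"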
tensorpack/train/distributed.py
@@ -13,6 +13,7 @@ from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.common import get_global_step_var, get_op_tensor_name
from .multigpu import MultiGPUTrainerBase
from .utility import override_to_local_variable

__all__ = ['DistributedTrainerReplicated']
@@ -180,6 +181,7 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
        opt = self.model.get_optimizer()
        var_update_ops = []
        for vid, (g, v) in enumerate(ps_var_grads):
            # TODO do we put momentum variables into local or global?
            apply_gradient_op = opt.apply_gradients([(g, v)])
            barrier = self._add_sync_queues_and_barrier(
                'param_update_barrier_{}'.format(vid), [apply_gradient_op])
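_add_sync_queues_and_barrier is internal to tensorpack and not shown in this diff. The usual way to build such a cross-worker barrier in TF 1.x distributed training is one shared token queue per worker: after its own update, each worker enqueues a token into every other worker's queue and then dequeues num_workers - 1 tokens from its own. The sketch below is an assumption about that shape, not the actual implementation; num_workers, task_index and the parameter-server device are illustrative.

import tensorflow as tf

def sync_barrier(name, dependencies, num_workers, task_index):
    # One shared FIFO queue per worker, all pinned to the same parameter server
    # so that every worker refers to the same queues.
    with tf.device('/job:ps/task:0/cpu:0'):
        queues = [
            tf.FIFOQueue(num_workers, [tf.bool], shapes=[[]],
                         shared_name='{}_{}'.format(name, i),
                         name='{}_{}'.format(name, i))
            for i in range(num_workers)
        ]
    with tf.control_dependencies(dependencies):
        # Signal every other worker that this worker has finished its update ...
        token_ops = [q.enqueue(True) for i, q in enumerate(queues) if i != task_index]
    with tf.control_dependencies(token_ops):
        # ... then block until every other worker has signalled us.
        return queues[task_index].dequeue_many(num_workers - 1)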
@@ -201,6 +203,9 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
        gs = get_global_step_var()
        assert gs.device, gs.device
        # do this before inputsource.setup because input_source may need the global step
        with override_to_local_variable():
            # input source may create variables (queue size summary)
            cbs = self._input_source.setup(self.model.get_inputs_desc())
            self.config.callbacks.extend(cbs)
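The fix wraps the input-source setup in override_to_local_variable so that any variable it creates (such as a queue-size summary variable) lands in the local collection rather than the global one, and is therefore handled by local_init_op instead of being shared through the parameter servers. override_to_local_variable itself is not shown in this diff; the following is an assumed equivalent built from a custom getter, with the names to_local_collection and queue_size purely illustrative.

from contextlib import contextmanager
import tensorflow as tf

@contextmanager
def to_local_collection():
    # Route every variable created inside the block into LOCAL_VARIABLES
    # instead of GLOBAL_VARIABLES.
    def custom_getter(getter, name, *args, **kwargs):
        kwargs['collections'] = [tf.GraphKeys.LOCAL_VARIABLES]
        return getter(name, *args, **kwargs)

    with tf.variable_scope(tf.get_variable_scope(), custom_getter=custom_getter):
        yield

with to_local_collection():
    queue_size = tf.get_variable('queue_size', shape=[], dtype=tf.int32,
                                 initializer=tf.zeros_initializer(),
                                 trainable=False)

assert queue_size in tf.local_variables()
assert queue_size not in tf.global_variables()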
@@ -258,6 +263,11 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
            local_init_op=local_init_op,
            ready_op=ready_op, graph=tf.get_default_graph())

        # to debug wrong variable collection
        # print("GLOBAL:")
        # print(tf.global_variables())
        # print("LOCAL:")
        # print(tf.local_variables())

        def _create_session():
            if self.is_chief:
                return sm.prepare_session(master=self.server.target, init_op=init_op)
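The hunk ends inside _create_session and only the chief branch is visible. With tf.train.SessionManager the usual split is that the chief creates and initializes the session via prepare_session, while the other workers call wait_for_session and block until the chief has made the model ready (as reported by ready_op). A rough sketch of the complete function under that assumption; the non-chief branch is not part of this hunk.

def _create_session():
    if self.is_chief:
        # The chief initializes the global variables and readies the model.
        return sm.prepare_session(master=self.server.target, init_op=init_op)
    else:
        # Non-chief workers wait until the chief reports the model as ready.
        return sm.wait_for_session(master=self.server.target)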