seminar-breakout / Commits / f1e3b3ae

Commit f1e3b3ae, authored Jun 01, 2017 by Yuxin Wu

    use sessioncreator API for supervisor

Parent: c04c1ef8

Showing 1 changed file with 39 additions and 48 deletions:

    tensorpack/train/distributed.py    (+39, -48)
--- a/tensorpack/train/distributed.py
+++ b/tensorpack/train/distributed.py
@@ -4,19 +4,14 @@
 import tensorflow as tf
 from six.moves import range
-import weakref
-from tensorflow.python.training.monitored_session \
-    import _HookedSession as HookedSession
 
 from ..utils import logger
-from .input_source import StagingInputWrapper, FeedfreeInput
+from .input_source import StagingInputWrapper
 from .feedfree import SingleCostFeedfreeTrainer
 from .multigpu import MultiGPUTrainerBase
-from ..tfutils.model_utils import describe_model
-from ..callbacks import Callbacks, ProgressBar
-from ..tfutils.sesscreate import ReuseSessionCreator
-from ..tfutils.common import get_default_sess_config, get_global_step_var, get_op_tensor_name
-from ..callbacks.monitor import Monitors
+from ..callbacks import RunOp
+from ..tfutils.sesscreate import NewSessionCreator
+from ..tfutils.common import get_global_step_var, get_op_tensor_name
 
 __all__ = ['DistributedReplicatedTrainer']
@@ -160,6 +155,10 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
             return  # TODO exit and skip mainloop how?
 
         super(DistributedReplicatedTrainer, self)._setup()
+        with tf.device(self.param_server_device):
+            get_global_step_var()
+            self.model.get_optimizer()    # TODO in global scope, not local
+
         with tf.variable_scope(
                 tf.get_variable_scope(),
                 custom_getter=OverrideToLocalVariableIfNotPsVar()):
@@ -177,44 +176,36 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
             main_fetch = tf.group(*var_update_ops, name='main_fetches')
             self.train_op = self.add_sync_queues_and_barrier(
                 'sync_queues_step_end', [main_fetch])
-        self.post_init_op = self.get_post_init_ops()
-
-    def setup(self):
-        with tf.device(self.param_server_device):
-            gs = get_global_step_var()
-            opt = self.model.get_optimizer()    # in global scope, not local
-        self._setup()
-        self.monitors = Monitors(self.monitors)
-        self.register_callback(self.monitors)
-        describe_model()
-
-        logger.info("Setup callbacks graph ...")
-        self._callbacks = Callbacks(self._callbacks)
-        self._callbacks.setup_graph(weakref.proxy(self))
-
-        logger.info("Finalize the graph, create the session ...")
-        self.sv = tf.train.Supervisor(
-            is_chief=self.is_chief,
-            logdir=None,
-            saver=None,
-            global_step=gs,
-            summary_op=None,
-            save_model_secs=0,
-            summary_writer=None)
-        sess = self.sv.prepare_or_wait_for_session(
-            master=self.server.target,
-            start_standard_services=False)
-        self.sess = sess
-        logger.info("Running post init op...")
-        sess.run(self.post_init_op)
-        logger.info("Post init op finished.")
-
-        self._monitored_sess = tf.train.MonitoredSession(
-            session_creator=ReuseSessionCreator(self.sess), hooks=None)
-        hooks = self._callbacks.get_hooks()
-        self.hooked_sess = HookedSession(self.sess, hooks)
+        self.register_callback(RunOp(
+            self.get_post_init_ops, run_before=True, run_as_trigger=False))
+
+    def _set_session_creator(self):
+        old_sess_creator = self.config.session_creator
+        if not isinstance(old_sess_creator, NewSessionCreator) \
+                or self.config.session_config is not None:
+            raise ValueError(
+                "Cannot set session_creator or session_config for distributed training! "
+                "To use a custom session config, pass it to the tf.train.Server constructor.")
+
+        class SupervisedSessionCreator(tf.train.SessionCreator):
+            def __init__(self, is_chief, target):
+                self.is_chief = is_chief
+                self.target = target
+
+            def create_session(self):
+                # supervisor will finalize the graph..
+                self.sv = tf.train.Supervisor(
+                    is_chief=self.is_chief,
+                    logdir=None,
+                    saver=None,
+                    global_step=get_global_step_var(),
+                    summary_op=None,
+                    save_model_secs=0,
+                    summary_writer=None)
+                return self.sv.prepare_or_wait_for_session(
+                    master=self.target,
+                    start_standard_services=False)
+
+        self.config.session_creator = SupervisedSessionCreator(
+            self.is_chief, self.server.target)
 
     def add_sync_queues_and_barrier(self, name_prefix, enqueue_after_list):
         """Adds ops to enqueue on all worker queues.
@@ -272,5 +263,5 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
                 copy_to = local_var_by_name[name]
                 post_init_ops.append(copy_to.assign(v.read_value()))
             else:
-                logger.warn("Global var {} doesn't match local var".format(v.name))
+                logger.warn("Global variable {} doesn't match a corresponding local var".format(v.name))
         return tf.group(*post_init_ops, name='post_init_ops')
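The commit replaces the hand-rolled setup() override (which built a tf.train.Supervisor, created the session itself, and wrapped it in a MonitoredSession) with a tf.train.SessionCreator subclass, so the Supervisor-backed session is produced through the same session-creator interface the rest of the trainer already uses. As a minimal sketch of that interface (not part of this commit; assumes TF 1.x, and the class name SimpleInitSessionCreator and the toy variable are made up for illustration): a creator's create_session() returns a ready-to-use session, and tf.train.MonitoredSession calls it whenever it needs one.

    import tensorflow as tf

    class SimpleInitSessionCreator(tf.train.SessionCreator):
        """Hypothetical creator: makes a plain session and runs a given init op."""

        def __init__(self, init_op, target='', config=None):
            self.init_op = init_op
            self.target = target
            self.config = config

        def create_session(self):
            # MonitoredSession calls this whenever it needs a (new) session.
            sess = tf.Session(target=self.target, config=self.config)
            sess.run(self.init_op)
            return sess

    x = tf.Variable(3.0)
    creator = SimpleInitSessionCreator(tf.global_variables_initializer())
    # MonitoredSession defers session construction to the creator; the commit
    # uses the same hook to hide tf.train.Supervisor behind a SessionCreator.
    with tf.train.MonitoredSession(session_creator=creator) as sess:
        print(sess.run(x))   # prints 3.0

Under this pattern the trainer no longer needs to manage self.sess, the post-init op, or the hooked session by hand; it only assigns self.config.session_creator and lets the common training loop drive session creation.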