seminar-breakout: commit 12cd6e6c
authored Oct 18, 2017 by Yuxin Wu
parent a673974c

[Trainerv2] fix interface for distributed trainer.

Showing 5 changed files with 24 additions and 26 deletions:

    tensorpack/tfutils/sesscreate.py    +6   -0
    tensorpack/train/config.py          +0   -2
    tensorpack/train/distributed.py     +1   -1
    tensorpack/trainv2/base.py          +5   -2
    tensorpack/trainv2/trainers.py      +12  -21
tensorpack/tfutils/sesscreate.py

@@ -24,7 +24,13 @@ class NewSessionCreator(tf.train.SessionCreator):
         """
         self.target = target
         if config is None:
+            # distributd trainer doesn't support user-provided config
+            # we set this attribute so that they can check
+            self.user_provided_config = False
             config = get_default_sess_config()
+        else:
+            self.user_provided_config = True
         self.config = config
         self.graph = graph
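Note (not part of the commit): a minimal sketch of how the new flag behaves, assuming only the NewSessionCreator constructor shown above. The surrounding script is illustrative, not tensorpack code.

    import tensorflow as tf
    from tensorpack.tfutils.sesscreate import NewSessionCreator

    # built without a config: falls back to get_default_sess_config()
    creator = NewSessionCreator()
    assert creator.user_provided_config is False

    # built with an explicit config: the flag is set, so distributed trainers can reject it
    custom = NewSessionCreator(config=tf.ConfigProto(allow_soft_placement=True))
    assert custom.user_provided_config is True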
tensorpack/train/config.py

@@ -101,8 +101,6 @@ class TrainConfig(object):
         else:
             self.session_creator = session_creator
             assert session_config is None, "Cannot set both session_creator and session_config!"

-        # only used by DistributedTrainer for assertion!
-        self.session_config = session_config

         if steps_per_epoch is None:
             try:
tensorpack/train/distributed.py

@@ -85,7 +85,7 @@ class DistributedTrainerReplicated(Trainer):
     def _set_session_creator(self):
         old_sess_creator = self._config.session_creator
         if not isinstance(old_sess_creator, NewSessionCreator) \
-                or self._config.session_config is not None:
+                or old_sess_creator.user_provided_config:
             raise ValueError(
                 "Cannot set session_creator or session_config for distributed training! "
                 "To use a custom session config, pass it to tf.train.Server.")
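Note (not part of the commit): as the error message says, a custom session config for distributed training belongs to the tf.train.Server process rather than to the session creator. A hedged sketch using the stock TensorFlow 1.x API; hostnames, ports and the cluster layout are illustrative.

    import tensorflow as tf

    cluster = tf.train.ClusterSpec({
        'ps':     ['host0:2222'],
        'worker': ['host1:2222', 'host2:2222'],
    })
    # the custom session config is attached to the server process,
    # not to the trainer's session creator
    server = tf.train.Server(
        cluster, job_name='worker', task_index=0,
        config=tf.ConfigProto(allow_soft_placement=True))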
tensorpack/trainv2/base.py

@@ -231,8 +231,7 @@ class SingleCostTrainer(Trainer):
         These callbacks will be automatically added when you call `train()`.
         So you can usually ignore the return value.
         """
-        assert not input.setup_done()
-        input_callbacks = input.setup(inputs_desc)
+        input_callbacks = self._setup_input(inputs_desc, input)
         train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
         self._internal_callbacks = input_callbacks + train_callbacks
         return self._internal_callbacks

@@ -240,3 +239,7 @@ class SingleCostTrainer(Trainer):
     @abstractmethod
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
         pass
+
+    def _setup_input(self, inputs_desc, input):
+        assert not input.setup_done()
+        return input.setup(inputs_desc)
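Note (not part of the commit): a sketch of how a concrete trainer plugs into the split hooks after this change. The hook names and the inherited _setup_input() behaviour follow the diff above; DummyBuilder is a hypothetical stand-in for a real graph builder such as the SimpleBuilder used in trainers.py below, and get_input_tensors() is assumed to be the InputSource accessor.

    from tensorpack.trainv2.base import SingleCostTrainer

    class DummyBuilder(object):
        """Hypothetical stand-in for a real graph builder such as SimpleBuilder."""
        def build(self, input, get_cost_fn, get_opt_fn):
            cost = get_cost_fn(*input.get_input_tensors())
            return get_opt_fn().minimize(cost, name='train_op')

    class MyTrainer(SingleCostTrainer):
        # _setup_input() is inherited from SingleCostTrainer:
        #     assert not input.setup_done(); return input.setup(inputs_desc)

        def _setup_graph(self, input, get_cost_fn, get_opt_fn):
            # build the train op from the already-set-up InputSource;
            # callbacks returned here are merged with the input callbacks by setup_graph()
            self.train_op = DummyBuilder().build(input, get_cost_fn, get_opt_fn)
            return []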
tensorpack/trainv2/trainers.py

 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 # File: trainers.py

+import os

 from ..callbacks.graph import RunOp

@@ -17,7 +18,7 @@ from ..tfutils import get_global_step_var
 from ..tfutils.distributed import get_distributed_session_creator
 from ..input_source import QueueInput

-from .base import Trainer, SingleCostTrainer
+from .base import SingleCostTrainer

 __all__ = ['SimpleTrainer', 'QueueInputTrainer',
@@ -32,8 +33,7 @@ class SimpleTrainer(SingleCostTrainer):
     Single-GPU single-cost single-tower trainer.
     """
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
-        self.train_op = SimpleBuilder().build(
-            input, get_cost_fn, get_opt_fn)
+        self.train_op = SimpleBuilder().build(input, get_cost_fn, get_opt_fn)
         return []
@@ -126,17 +126,13 @@ class DistributedTrainerReplicated(SingleCostTrainer):
             self.is_chief = False

         logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))

-    def train(self,
-              inputs_desc, input, get_cost_fn, get_opt_fn,
-              callbacks, monitors,
-              session_creator, session_init,
-              steps_per_epoch, starting_epoch, max_epoch):
+    def _setup_input(self, inputs_desc, input):
         if self.job_name == 'ps':
+            # ps shouldn't setup input either
             logger.info("Running ps {}".format(self.server.server_def.task_index))
             logger.info("Kill me with 'kill {}'".format(os.getpid()))
-            self.server.join()  # this will never return tensorflow#4713
-            return
+            self.server.join()  # this function will never return tensorflow#4713
+            raise RuntimeError("This is a bug in tensorpack. Server.join() for ps should never return!")

         with override_to_local_variable():
             get_global_step_var()  # gs should be local
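Note (not part of the commit): a standalone sketch of the TensorFlow 1.x pattern the ps branch above wraps. server.join() blocks forever (tensorflow#4713), which is why the new code raises unconditionally after it; hostnames and ports are illustrative.

    import os
    import tensorflow as tf

    cluster = tf.train.ClusterSpec({
        'ps':     ['host0:2222'],
        'worker': ['host1:2222', 'host2:2222'],
    })
    server = tf.train.Server(cluster, job_name='ps', task_index=0)
    print("Running ps 0. Kill me with 'kill {}'".format(os.getpid()))
    server.join()  # blocks forever; see tensorflow#4713
    raise RuntimeError("server.join() for a ps task should never return")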
@@ -144,14 +140,8 @@ class DistributedTrainerReplicated(SingleCostTrainer):
         # TODO This is not good because we don't know from here
         # whether something should be global or local. We now assume
         # they should be local.
-        input_callbacks = input.setup(inputs_desc)
-        train_callbacks = self.setup_graph(input, get_cost_fn, get_opt_fn)
-        Trainer.train(
-            self,
-            callbacks + input_callbacks + train_callbacks,
-            monitors,
-            session_creator, session_init,
-            steps_per_epoch, starting_epoch, max_epoch)
+        assert not input.setup_done()
+        return input.setup(inputs_desc)

     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
         self.train_op, initial_sync_op, model_sync_op = self._builder.build(
@@ -174,9 +164,10 @@ class DistributedTrainerReplicated(SingleCostTrainer):
         return callbacks

     def initialize(self, session_creator, session_init):
-        if not isinstance(session_creator, NewSessionCreator):
+        if not isinstance(session_creator, NewSessionCreator) or \
+                session_creator.user_provided_config:
             raise ValueError(
-                "Cannot set session_creator for distributed training! "
+                "Cannot set session_creator or session_config for distributed training! "
                 "To use a custom session config, pass it to tf.train.Server.")
         super(DistributedTrainerReplicated, self).initialize(
             get_distributed_session_creator(), session_init)