Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
091568ec
Commit
091568ec
authored
Dec 06, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix DistributedTrainer (fix #505)
parent
2be64ce0
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
9 additions
and
22 deletions
+9
-22
tensorpack/graph_builder/predict.py
tensorpack/graph_builder/predict.py
+0
-10
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+3
-1
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+3
-2
tensorpack/train/interface.py
tensorpack/train/interface.py
+1
-7
tensorpack/train/trainers.py
tensorpack/train/trainers.py
+2
-2
No files found.
tensorpack/graph_builder/predict.py
View file @
091568ec
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
# File: predict.py
# File: predict.py
import
tensorflow
as
tf
import
tensorflow
as
tf
from
contextlib
import
contextmanager
from
..utils
import
logger
from
..utils
import
logger
from
..tfutils.tower
import
TowerContext
from
..tfutils.tower
import
TowerContext
...
@@ -29,14 +28,6 @@ class SimplePredictBuilder(GraphBuilder):
...
@@ -29,14 +28,6 @@ class SimplePredictBuilder(GraphBuilder):
device
=
'/gpu:{}'
.
format
(
device
)
if
device
>=
0
else
'/cpu:0'
device
=
'/gpu:{}'
.
format
(
device
)
if
device
>=
0
else
'/cpu:0'
self
.
_device
=
device
self
.
_device
=
device
@
contextmanager
def
_maybe_open_vs
(
self
):
if
len
(
self
.
_vs_name
):
with
tf
.
variable_scope
(
self
.
_vs_name
):
yield
else
:
yield
def
build
(
self
,
input
,
tower_fn
):
def
build
(
self
,
input
,
tower_fn
):
"""
"""
Args:
Args:
...
@@ -51,7 +42,6 @@ class SimplePredictBuilder(GraphBuilder):
...
@@ -51,7 +42,6 @@ class SimplePredictBuilder(GraphBuilder):
self
.
_ns_name
,
self
.
_device
))
self
.
_ns_name
,
self
.
_device
))
with
tf
.
device
(
self
.
_device
),
\
with
tf
.
device
(
self
.
_device
),
\
self
.
_maybe_open_vs
(),
\
TowerContext
(
TowerContext
(
self
.
_ns_name
,
is_training
=
False
,
vs_name
=
self
.
_vs_name
):
self
.
_ns_name
,
is_training
=
False
,
vs_name
=
self
.
_vs_name
):
inputs
=
input
.
get_input_tensors
()
inputs
=
input
.
get_input_tensors
()
...
...
tensorpack/input_source/input_source.py
View file @
091568ec
...
@@ -34,7 +34,9 @@ __all__ = ['PlaceholderInput', 'FeedInput',
...
@@ -34,7 +34,9 @@ __all__ = ['PlaceholderInput', 'FeedInput',
def
_get_reset_callback
(
df
):
def
_get_reset_callback
(
df
):
return
CallbackFactory
(
setup_graph
=
lambda
_
:
df
.
reset_state
())
ret
=
CallbackFactory
(
setup_graph
=
lambda
_
:
df
.
reset_state
())
ret
.
chief_only
=
False
return
ret
class
PlaceholderInput
(
InputSource
):
class
PlaceholderInput
(
InputSource
):
...
...
tensorpack/tfutils/tower.py
View file @
091568ec
...
@@ -124,8 +124,9 @@ class TowerContext(object):
...
@@ -124,8 +124,9 @@ class TowerContext(object):
global
_CurrentTowerContext
global
_CurrentTowerContext
assert
_CurrentTowerContext
is
None
,
"Cannot nest TowerContext!"
assert
_CurrentTowerContext
is
None
,
"Cannot nest TowerContext!"
_CurrentTowerContext
=
self
_CurrentTowerContext
=
self
curr_vs
=
tf
.
get_variable_scope
()
if
self
.
is_training
:
assert
curr_vs
.
name
==
''
,
"Cannot nest TowerContext with an existing variable scope!"
curr_vs
=
tf
.
get_variable_scope
()
assert
curr_vs
.
name
==
''
,
"In training, cannot nest TowerContext with an existing variable scope!"
self
.
_ctxs
=
self
.
_get_scopes
()
self
.
_ctxs
=
self
.
_get_scopes
()
self
.
_ctxs
.
append
(
self
.
_collection_guard
)
self
.
_ctxs
.
append
(
self
.
_collection_guard
)
...
...
tensorpack/train/interface.py
View file @
091568ec
...
@@ -9,7 +9,7 @@ from ..input_source import (
...
@@ -9,7 +9,7 @@ from ..input_source import (
from
.config
import
TrainConfig
from
.config
import
TrainConfig
from
.tower
import
SingleCostTrainer
from
.tower
import
SingleCostTrainer
from
.trainers
import
SimpleTrainer
,
DistributedTrainerReplicated
from
.trainers
import
SimpleTrainer
__all__
=
[
'launch_train_with_config'
,
'apply_default_prefetch'
]
__all__
=
[
'launch_train_with_config'
,
'apply_default_prefetch'
]
...
@@ -77,12 +77,6 @@ def launch_train_with_config(config, trainer):
...
@@ -77,12 +77,6 @@ def launch_train_with_config(config, trainer):
input
=
config
.
data
or
config
.
dataflow
input
=
config
.
data
or
config
.
dataflow
input
=
apply_default_prefetch
(
input
,
trainer
,
config
.
tower
)
input
=
apply_default_prefetch
(
input
,
trainer
,
config
.
tower
)
if
isinstance
(
trainer
,
DistributedTrainerReplicated
)
and
\
config
.
session_config
is
not
None
:
raise
ValueError
(
"Cannot set session_config for distributed training! "
"To use a custom session config, pass it to tf.train.Server."
)
trainer
.
setup_graph
(
trainer
.
setup_graph
(
inputs_desc
,
input
,
inputs_desc
,
input
,
model
.
_build_graph_get_cost
,
model
.
get_optimizer
)
model
.
_build_graph_get_cost
,
model
.
get_optimizer
)
...
...
tensorpack/train/trainers.py
View file @
091568ec
...
@@ -165,7 +165,7 @@ class DistributedTrainerReplicated(SingleCostTrainer):
...
@@ -165,7 +165,7 @@ class DistributedTrainerReplicated(SingleCostTrainer):
logger
.
info
(
"Running ps {}"
.
format
(
self
.
server
.
server_def
.
task_index
))
logger
.
info
(
"Running ps {}"
.
format
(
self
.
server
.
server_def
.
task_index
))
logger
.
info
(
"Kill me with 'kill {}'"
.
format
(
os
.
getpid
()))
logger
.
info
(
"Kill me with 'kill {}'"
.
format
(
os
.
getpid
()))
self
.
server
.
join
()
# this function will never return tensorflow#4713
self
.
server
.
join
()
# this function will never return tensorflow#4713
raise
RuntimeError
(
"This is a bug
in tensorpack
. Server.join() for ps should never return!"
)
raise
RuntimeError
(
"This is a bug. Server.join() for ps should never return!"
)
with
override_to_local_variable
():
with
override_to_local_variable
():
get_global_step_var
()
# gs should be local
get_global_step_var
()
# gs should be local
...
@@ -204,7 +204,7 @@ class DistributedTrainerReplicated(SingleCostTrainer):
...
@@ -204,7 +204,7 @@ class DistributedTrainerReplicated(SingleCostTrainer):
"You are not allowed to set session_creator or session_config for distributed training! "
"You are not allowed to set session_creator or session_config for distributed training! "
"To use a custom session config, pass it to tf.train.Server."
)
"To use a custom session config, pass it to tf.train.Server."
)
super
(
DistributedTrainerReplicated
,
self
)
.
initialize
(
super
(
DistributedTrainerReplicated
,
self
)
.
initialize
(
get_distributed_session_creator
(),
session_init
)
get_distributed_session_creator
(
self
.
server
),
session_init
)
@
property
@
property
def
_main_tower_vs_name
(
self
):
def
_main_tower_vs_name
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment