seminar-breakout (Shashank Suhas) · Commits · afa11399

Commit afa11399, authored Feb 25, 2018 by Yuxin Wu
Fix distributed trainer (fix #671)

parent 5626e04d
Showing 5 changed files with 33 additions and 11 deletions.
tensorpack/callbacks/monitor.py          +7   -1
tensorpack/graph_builder/distributed.py  +15  -6
tensorpack/graph_builder/utils.py        +1   -1
tensorpack/train/base.py                 +9   -2
tensorpack/train/trainers.py             +1   -1
tensorpack/callbacks/monitor.py

@@ -104,8 +104,11 @@ class Monitors(Callback):
     You should use `trainer.monitors` for logging and it will dispatch your
     logs to each sub-monitor.
     """
+
+    _chief_only = False
+
     def __init__(self, monitors):
-        self._scalar_history = ScalarHistory()
+        self._scalar_history = ScalarHistory().set_chief_only(False)
         self._monitors = monitors + [self._scalar_history]
         for m in self._monitors:
             assert isinstance(m, TrainingMonitor), m

@@ -325,6 +328,9 @@ class ScalarPrinter(TrainingMonitor):
     """
     Print scalar data into terminal.
     """
+
+    _chief_only = False
+
     def __init__(self, enable_step=False, enable_epoch=True,
                  whitelist=None, blacklist=None):
         """
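The `_chief_only = False` markers above let `Monitors` and `ScalarPrinter` be registered on non-chief workers as well, instead of being skipped by the chief-only filter added in base.py below. A minimal, self-contained sketch of the dispatch pattern the `Monitors` docstring describes; this is illustrative Python only, not tensorpack's actual classes, and every name in it is hypothetical:

class PrintMonitor:
    # Stand-in for something like ScalarPrinter: print scalars to the terminal.
    def put_scalar(self, name, value):
        print("{}: {}".format(name, value))

class HistoryMonitor:
    # Stand-in for ScalarHistory: keep values around for later queries.
    def __init__(self):
        self.history = {}
    def put_scalar(self, name, value):
        self.history.setdefault(name, []).append(value)

class MonitorDispatcher:
    # Plays the role of Monitors: fan one logging call out to every sub-monitor.
    def __init__(self, monitors):
        self._monitors = monitors
    def put_scalar(self, name, value):
        for m in self._monitors:
            m.put_scalar(name, value)

monitors = MonitorDispatcher([PrintMonitor(), HistoryMonitor()])
monitors.put_scalar("train/loss", 0.42)   # reaches both sub-monitors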
tensorpack/graph_builder/distributed.py

@@ -6,6 +6,7 @@ import tensorflow as tf
 import re
 from six.moves import range

+from ..utils import logger
 from ..utils.argtools import memoized
 from ..tfutils.common import get_op_tensor_name, get_global_step_var

@@ -230,19 +231,26 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
         Returns:
             list of (shadow_model_var, local_model_var) used for syncing.
         """
+        G = tf.get_default_graph()
         curr_shadow_vars = set([v.name for v in shadow_vars])
         model_vars = tf.model_variables()
         shadow_model_vars = []
         for v in model_vars:
-            assert v.name.startswith('tower'), "Found some MODEL_VARIABLES created outside of the model!"
-            stripped_name = get_op_tensor_name(re.sub('tower[0-9]+/', '', v.name))[0]
-            if stripped_name in curr_shadow_vars:
+            assert v.name.startswith('tower'), "Found some MODEL_VARIABLES created outside of the tower function!"
+            stripped_op_name, stripped_var_name = get_op_tensor_name(re.sub('^tower[0-9]+/', '', v.name))
+            if stripped_op_name in curr_shadow_vars:
                 continue
-            new_v = tf.get_variable(stripped_name, dtype=v.dtype.base_dtype,
+            try:
+                G.get_tensor_by_name(stripped_var_name)
+                logger.warn("Model Variable {} also appears in other collections.".format(stripped_var_name))
+                continue
+            except KeyError:
+                pass
+            new_v = tf.get_variable(stripped_op_name, dtype=v.dtype.base_dtype,
                                     initializer=v.initial_value,
                                     trainable=False)
-            curr_shadow_vars.add(stripped_name)   # avoid duplicated shadow_model_vars
+            curr_shadow_vars.add(stripped_op_name)   # avoid duplicated shadow_model_vars
             shadow_vars.append(new_v)
             shadow_model_vars.append((new_v, v))  # only need to sync model_var from one tower
         return shadow_model_vars

@@ -279,7 +287,8 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
             use_vs=[True] * len(self.towers))  # open vs at each tower
         DataParallelBuilder._check_grad_list(grad_list)

-        avg_grads = average_grads(grad_list, devices=self.raw_devices)
+        avg_grads = average_grads(
+            grad_list, colocation=False, devices=self.raw_devices)
         with tf.device(self.param_server_device):
             ps_var_grads = DistributedReplicatedBuilder._apply_shadow_vars(avg_grads)
             var_update_ops = self._apply_gradients_and_copy(
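The heart of the change above is how a tower variable's name is normalized before a shadow copy is created on the parameter server: the regex is now anchored at the start of the name, and both the op name and the tensor name returned by `get_op_tensor_name` are kept so the graph can be probed for an existing tensor. A rough sketch of that normalization, using a simplified hypothetical stand-in for `get_op_tensor_name` rather than the tensorpack function itself:

import re

def split_op_tensor_name(name):
    # Simplified stand-in for tensorpack's get_op_tensor_name:
    # 'conv1/W:0' -> ('conv1/W', 'conv1/W:0'); a name without ':k' gets ':0' appended.
    if ':' in name:
        return name[:name.rfind(':')], name
    return name, name + ':0'

def strip_tower_prefix(var_name):
    # Mirrors the patched regex: only a *leading* 'towerN/' is removed
    # (note the '^' anchor added in this commit), so inner occurrences survive.
    return re.sub('^tower[0-9]+/', '', var_name)

op_name, var_name = split_op_tensor_name(strip_tower_prefix('tower0/conv1/W:0'))
print(op_name)   # 'conv1/W'   -> name used for the shadow variable
print(var_name)  # 'conv1/W:0' -> name used to check the graph for a duplicate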
tensorpack/graph_builder/utils.py

@@ -56,7 +56,7 @@ def override_to_local_variable(enable=True):
         ns = orig_vs.original_name_scope
         with tf.variable_scope(
                 orig_vs, custom_getter=custom_getter):
-            with tf.name_scope(ns + '/'):
+            with tf.name_scope(ns + '/' if ns else ''):
                 yield
     else:
         yield
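This one-line change guards against an empty `original_name_scope`: when `ns` is empty, `ns + '/'` would become a bare `'/'`, which is not a usable name scope, so the code falls back to the empty string (the motivation is inferred; the guard itself is straight from the diff). The expression in isolation:

def guarded_scope(ns):
    # Same conditional as the patched line: append '/' only when ns is non-empty.
    return ns + '/' if ns else ''

assert guarded_scope('tower0/opt') == 'tower0/opt/'
assert guarded_scope('') == ''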
tensorpack/train/base.py

@@ -143,6 +143,9 @@ class Trainer(object):
         Args:
             cb (Callback or [Callback]): a callback or a list of callbacks
+
+        Returns:
+            succeed or not
         """
         if isinstance(cb, (list, tuple)):
             for x in cb:

@@ -153,8 +156,10 @@ class Trainer(object):
             "Cannot register more callbacks after trainer was setup!"
         if not self.is_chief and cb.chief_only:
             logger.warn("Callback {} is chief-only, skipped.".format(str(cb)))
+            return False
         else:
             self._callbacks.append(cb)
+            return True

     register_callback = _register_callback

@@ -188,9 +193,11 @@ class Trainer(object):
             self.register_callback(cb)
         for cb in self._callbacks:
             assert not isinstance(cb, TrainingMonitor), "Monitor cannot be pre-registered for now!"
+        registered_monitors = []
         for m in monitors:
-            self.register_callback(m)
-        self.monitors = Monitors(monitors)
+            if self.register_callback(m):
+                registered_monitors.append(m)
+        self.monitors = Monitors(registered_monitors)
         self.register_callback(self.monitors)  # monitors is also a callback

         # some final operations that might modify the graph
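With `_register_callback` now reporting success, setup can keep only the monitors that were actually registered, so a non-chief worker never wraps a chief-only monitor it skipped. A self-contained sketch of that pattern, with illustrative classes rather than tensorpack's real `Trainer`/`Monitors`:

class SketchTrainer:
    # Illustrative only: mimics the register-then-filter flow from the diff.
    def __init__(self, is_chief):
        self.is_chief = is_chief
        self._callbacks = []

    def register_callback(self, cb):
        # Skip chief-only callbacks on non-chief workers and report the outcome,
        # like the patched _register_callback.
        if not self.is_chief and getattr(cb, "chief_only", False):
            print("Callback {} is chief-only, skipped.".format(cb.__class__.__name__))
            return False
        self._callbacks.append(cb)
        return True

    def setup_monitors(self, monitors):
        # Only successfully registered monitors go into the monitor group.
        registered_monitors = [m for m in monitors if self.register_callback(m)]
        self.monitors = registered_monitors   # stands in for Monitors(registered_monitors)


class ChiefOnlyMonitor:
    chief_only = True

class EveryWorkerMonitor:
    chief_only = False

worker = SketchTrainer(is_chief=False)
worker.setup_monitors([ChiefOnlyMonitor(), EveryWorkerMonitor()])
print(len(worker.monitors))   # 1 -- the chief-only monitor was filtered out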
tensorpack/train/trainers.py

@@ -214,7 +214,7 @@ class DistributedTrainerParameterServer(DistributedTrainerBase):
         return []


-class DistributedTrainerReplicated(SingleCostTrainer):
+class DistributedTrainerReplicated(DistributedTrainerBase):

     __doc__ = DistributedReplicatedBuilder.__doc__
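Re-parenting `DistributedTrainerReplicated` onto `DistributedTrainerBase` lets it share whatever setup the other distributed trainer (`DistributedTrainerParameterServer`, shown above) already inherits. A toy illustration of that kind of hierarchy, with hypothetical attributes rather than tensorpack's actual class layout:

class SingleCostTrainerSketch:
    # Plain single-process training: this worker is always the chief.
    is_chief = True

class DistributedTrainerBaseSketch(SingleCostTrainerSketch):
    # Shared distributed plumbing lives once, in the common base.
    def __init__(self, task_index):
        self.is_chief = (task_index == 0)

class DistributedTrainerReplicatedSketch(DistributedTrainerBaseSketch):
    pass

print(DistributedTrainerReplicatedSketch(task_index=3).is_chief)   # False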