Shashank Suhas / seminar-breakout / Commits

Commit a53da5ab, authored Jun 04, 2017 by Yuxin Wu
Parent: 5b18f8be

[distributed] Can save and sync MODEL_VARIABLES

Showing 3 changed files, with 69 additions and 21 deletions:

    tensorpack/tfutils/varmanip.py     +3   -2
    tensorpack/train/distributed.py    +61  -13
    tensorpack/train/feedfree.py       +5   -6

tensorpack/tfutils/varmanip.py

@@ -160,7 +160,7 @@ def get_checkpoint_path(model_path):
     new_path = model_path.split('.index')[0]
     if new_path != model_path:
         logger.warn(
-            "[SaverRestore] {} is corrected to {} when restoring the model.".format(model_path, new_path))
+            "Checkpoint path {} is auto-corrected to {}.".format(model_path, new_path))
         model_path = new_path
     assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index'), model_path
     return model_path

@@ -183,7 +183,8 @@ def dump_chkpt_vars(model_path):
 def is_training_name(name):
     """
-    This is a hack temporarily used to improve logging. Do not use this function.
+    Guess if a name belongs to a training-only variables.
+    Only used internally to avoid too many logging. Do not use it.

     Returns:
         bool: Guess whether this tensor is something only used in training.

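For reference, the correction that the re-worded warning describes is just a split on the `.index` suffix of a TensorFlow checkpoint file. A minimal sketch of that logic with a hypothetical path (`get_checkpoint_path` itself additionally asserts that the file exists on disk):

    # Hypothetical checkpoint path; mirrors the string logic in get_checkpoint_path above.
    model_path = '/tmp/train_log/model-5000.index'
    new_path = model_path.split('.index')[0]
    print(new_path)   # /tmp/train_log/model-5000
    # would log: "Checkpoint path /tmp/train_log/model-5000.index is auto-corrected to /tmp/train_log/model-5000."
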
tensorpack/train/distributed.py

@@ -3,6 +3,7 @@
 # File: distributed.py

 import tensorflow as tf
+import re
 from six.moves import range

 from ..utils import logger

@@ -110,6 +111,31 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
             ps_var_grads.append((grad, new_v))
         return ps_var_grads

+    @staticmethod
+    def _shadow_model_variables(shadow_vars):
+        """
+        Create shadow vars for model_variables as well, and add to the list of ``shadow_vars``.
+
+        Returns:
+            list of (shadow_model_var, local_model_var) used for syncing.
+        """
+        curr_shadow_vars = set([v.name for v in shadow_vars])
+        model_vars = tf.model_variables()
+        shadow_model_vars = []
+        for v in model_vars:
+            assert v.name.startswith('tower'), "Found some MODEL_VARIABLES created outside of the model!"
+            stripped_name = get_op_tensor_name(re.sub('tower[0-9]+/', '', v.name))[0]
+            if stripped_name in curr_shadow_vars:
+                continue
+            new_v = tf.get_variable(stripped_name, dtype=v.dtype.base_dtype,
+                                    initializer=v.initial_value, trainable=False)
+            curr_shadow_vars.add(stripped_name)   # avoid duplicated shadow_model_vars
+            shadow_vars.append(new_v)
+            shadow_model_vars.append((new_v, v))  # only need to sync model_var from one tower
+        return shadow_model_vars
+
     def _apply_gradients_and_copy(self, raw_grad_list, ps_var_grads):
         """
         Args:

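To make the docstring above concrete: MODEL_VARIABLES is a separate TensorFlow collection (typically EMA statistics such as BatchNorm's moving mean/variance) that `tf.trainable_variables()` does not include, so the gradient-driven shadow variables never cover them. A minimal, self-contained sketch, with invented variable names and the collection populated by hand for illustration (layer implementations normally do this for you):

    import re
    import tensorflow as tf

    with tf.variable_scope('tower0'):
        w = tf.get_variable('conv1/W', shape=[3, 3])                       # trainable
        ema = tf.get_variable('bn/mean/EMA', shape=[3], trainable=False)   # not trainable
        tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, ema)

    print([v.name for v in tf.trainable_variables()])   # ['tower0/conv1/W:0'] -- no EMA here
    print([v.name for v in tf.model_variables()])       # ['tower0/bn/mean/EMA:0']

    # The tower prefix is stripped before creating the PS-side shadow copy, as
    # _shadow_model_variables does (get_op_tensor_name(...)[0] just drops the ':0'):
    stripped = re.sub('tower[0-9]+/', '', ema.name).rsplit(':', 1)[0]
    print(stripped)   # bn/mean/EMA
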
@@ -142,7 +168,6 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
         with tf.device(self.param_server_device):
             gs = get_global_step_var()
             assert gs.device, gs.device
-        self.model.get_optimizer()    # TODO in global scope, not local
         # do this before super.setup because input_source my need global step
         super(DistributedReplicatedTrainer, self)._setup()

@@ -161,16 +186,27 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
         with tf.device(self.param_server_device):
             ps_var_grads = DistributedReplicatedTrainer._apply_shadow_vars(avg_grads)
             var_update_ops = self._apply_gradients_and_copy(grad_list, ps_var_grads)
             self._shadow_vars = [v for (_, v) in ps_var_grads]
+            self._shadow_model_vars = DistributedReplicatedTrainer._shadow_model_variables(self._shadow_vars)

         main_fetch = tf.group(*var_update_ops, name='main_fetches')
         self.train_op = self.add_sync_queues_and_barrier('post_copy_barrier', [main_fetch])

-        cb = RunOp(self.get_post_init_ops,
+        # initial local_vars syncing
+        cb = RunOp(self.get_initial_sync_op,
                    run_before=True, run_as_trigger=False, verbose=True)
         cb.chief_only = False
         self.register_callback(cb)

+        # model_variables syncing
+        if len(self._shadow_model_vars) and self.is_chief:
+            cb = RunOp(self.get_sync_model_vars_op,
+                       run_before=False, run_as_trigger=True, verbose=True)
+            logger.warn("For efficiency, local MODEL_VARIABLES are only synced to PS once "
+                        "every epoch. Be careful if you save the model more frequenctly.")
+            self.register_callback(cb)
+
         self._set_session_creator()

     def _set_session_creator(self):

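The two callbacks registered above move data in opposite directions: `get_initial_sync_op` copies PS values into the towers once before training starts, while `get_sync_model_vars_op` pushes one tower's MODEL_VARIABLES back to the PS once per epoch (chief only), which is what the new warning about save frequency refers to. A toy, plain-TF illustration of that schedule, not the tensorpack callback machinery itself; the variable names and the three-epoch loop are invented:

    import tensorflow as tf

    ps_ema = tf.get_variable('bn/mean/EMA', shape=[], trainable=False,
                             initializer=tf.zeros_initializer())
    with tf.variable_scope('tower0'):
        local_ema = tf.get_variable('bn/mean/EMA', shape=[], trainable=False,
                                    initializer=tf.zeros_initializer())

    initial_sync = local_ema.assign(ps_ema.read_value())      # PS -> local, run once before training
    push_model_vars = ps_ema.assign(local_ema.read_value())   # local -> PS, run as an epoch trigger

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(initial_sync)
        for epoch in range(3):
            sess.run(local_ema.assign_add(1.0))   # stand-in for BN updates made during the epoch
            sess.run(push_model_vars)             # epoch trigger: the PS copy catches up here
        print(sess.run(ps_ema))   # 3.0; a checkpoint written mid-epoch would hold the older value
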
@@ -230,26 +266,38 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
         return tf.group(*queue_ops)

-    def get_post_init_ops(self):
-        # Copy initialized variables for variables on the parameter server
-        # to the local copy of the variable.
+    def get_initial_sync_op(self):
+        """
+        Get the op to copy-initialized all local variables from PS.
+        """
         def strip_port(s):
             if s.endswith(':0'):
                 return s[:-2]
             return s
         local_vars = tf.local_variables()
         local_var_by_name = dict([(strip_port(v.name), v) for v in local_vars])
-        post_init_ops = []
+        ops = []
+        nr_shadow_vars = len(self._shadow_vars)
         for v in self._shadow_vars:
             vname = strip_port(v.name)
             for i in range(self.nr_gpu):
                 name = 'tower%s/%s' % (i, vname)
-                if name in local_var_by_name:
-                    copy_to = local_var_by_name[name]
-                    post_init_ops.append(copy_to.assign(v.read_value()))
-                else:
-                    logger.warn("Global variable {} doesn't match a corresponding local var".format(v.name))
-        return tf.group(*post_init_ops, name='sync_variables_from_ps')
+                assert name in local_var_by_name, \
+                    "Shadow variable {} doesn't match a corresponding local variable!".format(v.name)
+                copy_to = local_var_by_name[name]
+                # logger.info("{} -> {}".format(v.name, copy_to.name))
+                ops.append(copy_to.assign(v.read_value()))
+        return tf.group(*ops, name='sync_{}_variables_from_ps'.format(nr_shadow_vars))
+
+    def get_sync_model_vars_op(self):
+        """
+        Get the op to sync local model_variables to PS.
+        """
+        ops = []
+        for (shadow_v, local_v) in self._shadow_model_vars:
+            ops.append(shadow_v.assign(local_v.read_value()))
+        assert len(ops)
+        return tf.group(*ops, name='sync_{}_model_variables_to_ps'.format(len(ops)))

     @property
     def vs_name_for_predictor(self):

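A runnable sketch of how `get_initial_sync_op` matches each PS shadow variable to its per-tower replicas purely by name. The variable names, shapes, and `nr_gpu = 2` are assumptions for illustration; the real trainer looks the replicas up in the local-variables collection instead of a plain list:

    import tensorflow as tf

    def strip_port(s):
        # 'conv1/W:0' -> 'conv1/W', as in the helper above
        return s[:-2] if s.endswith(':0') else s

    shadow = tf.get_variable('conv1/W', shape=[3, 3],
                             initializer=tf.zeros_initializer())   # plays the PS copy
    towers = []
    for i in range(2):                                             # nr_gpu = 2
        with tf.variable_scope('tower%d' % i):
            towers.append(tf.get_variable('conv1/W', shape=[3, 3],
                                          initializer=tf.ones_initializer()))

    local_var_by_name = dict([(strip_port(v.name), v) for v in towers])
    ops = []
    vname = strip_port(shadow.name)
    for i in range(2):
        name = 'tower%s/%s' % (i, vname)
        assert name in local_var_by_name, name
        ops.append(local_var_by_name[name].assign(shadow.read_value()))
    sync_op = tf.group(*ops, name='sync_1_variables_from_ps')

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(sync_op)
        print(sess.run(towers[0]))   # all zeros: the tower copy now equals the PS value
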
tensorpack/train/feedfree.py

@@ -4,6 +4,7 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

 import tensorflow as tf
+from six.moves import zip

 from ..tfutils.tower import TowerContext, get_current_tower_context
 from .input_source import QueueInput, FeedfreeInput

@@ -64,20 +65,18 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
         """ get the cost and gradient"""
         self.build_train_tower()
         cost = self.model.get_cost()    # assume single cost
-        # opt may be created under first-tower variable scope (which is '')
-        opt = self.model.get_optimizer()
-        # GATE_NONE faster?
         varlist = tf.trainable_variables()
         ctx = get_current_tower_context()
         if ctx is not None and ctx.has_own_variables and ctx.vs_name:
             # only optimize w.r.t vars in this tower
             # TODO use ctx.vars?
             varlist = [v for v in varlist if v.op.name.startswith(ctx.vs_name + '/')]
-        grads = opt.compute_gradients(
-            cost, var_list=varlist,
-            gate_gradients=tf.train.Optimizer.GATE_NONE,
+        grads = tf.gradients(
+            cost, varlist,
+            gate_gradients=False,
             colocate_gradients_with_ops=True)
+        grads = list(zip(grads, varlist))
         return cost, grads

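The swap from `opt.compute_gradients` to `tf.gradients` changes the return shape, not the math: `compute_gradients` yields `(grad, var)` pairs while `tf.gradients` yields only the gradients, hence the added `zip` with `varlist`. A small equivalence check on a toy cost; the variable and the optimizer here are invented for illustration:

    import tensorflow as tf
    from six.moves import zip

    x = tf.get_variable('x', shape=[], initializer=tf.constant_initializer(3.0))
    cost = x * x
    varlist = tf.trainable_variables()

    opt = tf.train.GradientDescentOptimizer(0.1)
    pairs_old = opt.compute_gradients(               # old path: returns (grad, var) pairs
        cost, var_list=varlist,
        gate_gradients=tf.train.Optimizer.GATE_NONE,
        colocate_gradients_with_ops=True)

    grads = tf.gradients(                            # new path: returns gradients only
        cost, varlist,
        gate_gradients=False,
        colocate_gradients_with_ops=True)
    pairs_new = list(zip(grads, varlist))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run([g for g, _ in pairs_old]))   # [6.0]
        print(sess.run([g for g, _ in pairs_new]))   # [6.0]
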