Shashank Suhas / seminar-breakout · Commits

Commit a77cc508
authored Jul 12, 2017 by Yuxin Wu
remove vs_strategy from tower. Use vs_name in a cleaner way.
parent 4d2a7b4c
Showing 2 changed files with 24 additions and 28 deletions (+24, -28):
tensorpack/tfutils/tower.py    +5  -20
tensorpack/train/multigpu.py   +19 -8
tensorpack/tfutils/tower.py
@@ -14,20 +14,15 @@ _CurrentTowerContext = None
 class TowerContext(object):
     """ A context where the current model is being built in. """

-    def __init__(self, tower_name, is_training=None, index=0, var_strategy='shared', vs_name=None):
+    def __init__(self, tower_name, is_training=None, index=0, vs_name=''):
         """
         Args:
             tower_name (str): The name scope of the tower. Currently used
                 values are like: 'tower0', 'towerp0', or ''
             is_training (bool): if None, automatically determine from tower_name.
             index (int): index of this tower
-            var_strategy (str): either 'shared' or 'replicated'.
-            vs_name (str): the variable scope name to open. Only valid in
-                'replicated' mode. Defaults to be tower_name.
+            vs_name (str): Open a variable scope with this name, if given.
         """
         self._name = tower_name
@@ -37,17 +32,7 @@ class TowerContext(object):
         self._index = int(index)
-        assert var_strategy in ['replicated', 'shared'], var_strategy
-        self._var_strategy = var_strategy
-        if self._var_strategy == 'replicated':
-            assert self._name
-            if vs_name is None:
-                self._vs_name = self._name
-            else:
-                self._vs_name = vs_name
-        else:
-            assert vs_name is None, "vs_name is only valid in 'replicated' mode!"
-            self._vs_name = ''
+        self._vs_name = vs_name

     @property
     def is_main_training_tower(self):
@@ -63,7 +48,7 @@ class TowerContext(object):
     @property
     def has_own_variables(self):
-        return self._var_strategy == 'replicated'
+        return len(self._vs_name) > 0

     @property
     def name(self):
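To illustrate the new contract, here is a rough, self-contained sketch (not part of the commit; TowerContextSketch is an invented stand-in that keeps only the pieces shown in the diff): a tower now "owns" its variables exactly when a non-empty vs_name is passed, rather than when var_strategy == 'replicated'.

# Rough standalone sketch of the behaviour after this commit.
class TowerContextSketch(object):
    def __init__(self, tower_name, is_training=None, index=0, vs_name=''):
        self._name = tower_name
        self._index = int(index)
        self._vs_name = vs_name

    @property
    def has_own_variables(self):
        # non-empty vs_name <=> this tower keeps its own copy of the variables
        return len(self._vs_name) > 0


# shared mode: no vs_name, variables are reused across towers
assert not TowerContextSketch('tower0').has_own_variables
# replicated mode: each tower opens its own variable scope
assert TowerContextSketch('tower1', vs_name='tower1').has_own_variables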
tensorpack/train/multigpu.py
@@ -70,22 +70,29 @@ class MultiGPUTrainerBase(Trainer):
         if devices is not None:
             assert len(devices) == len(towers)
+        tower_names = ['tower{}'.format(idx) for idx in range(len(towers))]
         keys_to_freeze = TOWER_FREEZE_KEYS[:]
         if var_strategy == 'replicated':
             # TODO ugly
             logger.info("In replicated mode, UPDATE_OPS from all GPUs will be run.")
             keys_to_freeze.remove(tf.GraphKeys.UPDATE_OPS)
+            # fix all Nones. TODO ugly
+            if vs_names is not None:
+                assert len(vs_names) == len(towers)
+                for idx, name in enumerate(vs_names):
+                    if name is None:
+                        vs_names[idx] = tower_names[idx]
+            else:
+                vs_names = tower_names
         else:
             assert vs_names is None
-        if vs_names is None:
-            vs_names = [None] * len(towers)
+            vs_names = [''] * len(towers)

         for idx, t in enumerate(towers):
             device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
             with tf.device(device), TowerContext(
-                    'tower{}'.format(idx),
+                    tower_names[idx],
                     is_training=True,
                     index=idx,
-                    var_strategy=var_strategy,
                     vs_name=vs_names[idx]):
                 if idx == t:
                     logger.info("Building graph for training tower {}...".format(idx))
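The vs_names handling added above can be read in isolation. The following is a hedged, pure-Python sketch (normalize_vs_names is an invented helper, not tensorpack API): in 'replicated' mode any None entry falls back to the tower's own name, while 'shared' mode forces empty names so no private variable scope is opened.

def normalize_vs_names(vs_names, num_towers, var_strategy):
    # Sketch of the logic in the diff above, under the stated assumptions.
    tower_names = ['tower{}'.format(idx) for idx in range(num_towers)]
    if var_strategy == 'replicated':
        if vs_names is None:
            vs_names = list(tower_names)
        else:
            assert len(vs_names) == num_towers
            vs_names = [n if n is not None else tower_names[idx]
                        for idx, n in enumerate(vs_names)]
    else:
        assert vs_names is None, "vs_names is only meaningful in 'replicated' mode"
        vs_names = [''] * num_towers
    return tower_names, vs_names


print(normalize_vs_names(None, 2, 'replicated'))          # (['tower0', 'tower1'], ['tower0', 'tower1'])
print(normalize_vs_names([None, 'ps'], 2, 'replicated'))  # (['tower0', 'tower1'], ['tower0', 'ps'])
print(normalize_vs_names(None, 2, 'shared'))              # (['tower0', 'tower1'], ['', ''])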
@@ -279,17 +286,21 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
     @staticmethod
     def get_post_init_ops():
         # Copy initialized values for variables on GPU 0 to other GPUs.
-        global_vars = tf.global_variables()
-        var_by_name = dict([(v.name, v) for v in global_vars])
+        all_vars = tf.trainable_variables()
+        # TODO model_variables?
+        var_by_name = dict([(v.name, v) for v in all_vars])
         post_init_ops = []
-        for v in global_vars:
+        for v in all_vars:
             split_name = v.name.split('/')
             if not v.name.startswith('tower'):
                 continue
-            # the master name doesn't have the towerx/ prefix
             if v.name.startswith('tower0'):
                 continue
+            # TODO some vars (EMA) may still startswith tower0
+            # in this trainer, the master name doesn't have the towerx/ prefix
             split_name = split_name[1:]
             copy_from = var_by_name['/'.join(split_name)]
             post_init_ops.append(v.assign(copy_from.read_value()))
         logger.info("'sync_variables_from_tower0' includes {} operations.".format(len(post_init_ops)))
         return tf.group(*post_init_ops, name='sync_variables_from_tower0')
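As a sanity check on the renaming logic in get_post_init_ops, here is a small sketch (master_name is an invented helper, not part of the commit) of how a replicated variable's name maps back to the master copy by dropping the 'towerN/' prefix; variables outside a tower scope, or already on tower0, get no sync op.

def master_name(var_name):
    # Mirror of the loop body above: return the name of the variable to copy
    # from, or None if this variable needs no sync op.
    if not var_name.startswith('tower'):
        return None            # not a per-tower variable
    if var_name.startswith('tower0'):
        return None            # tower0 already holds the master copy
    return '/'.join(var_name.split('/')[1:])


assert master_name('tower1/conv0/W:0') == 'conv0/W:0'
assert master_name('tower0/conv0/W:0') is None
assert master_name('global_step:0') is None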