Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
dadd971c
Commit
dadd971c
authored
Jun 22, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
clean-up TowerContext, pass tower index into it. (a better solution to #310)
parent
c2edd999
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
37 deletions
+21
-37
tensorpack/predict/base.py
tensorpack/predict/base.py
+3
-1
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+15
-34
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+3
-2
No files found.
tensorpack/predict/base.py
View file @
dadd971c
...
@@ -176,6 +176,7 @@ class PredictorTowerBuilder(object):
...
@@ -176,6 +176,7 @@ class PredictorTowerBuilder(object):
tower (int): the tower will be built on device '/gpu:{tower}', or
tower (int): the tower will be built on device '/gpu:{tower}', or
'/cpu:0' if tower is -1.
'/cpu:0' if tower is -1.
"""
"""
toweridx
=
max
(
tower
,
0
)
# if CPU, named the tower as 0
towername
=
TowerContext
.
get_predict_tower_name
(
tower
,
self
.
_prefix
)
towername
=
TowerContext
.
get_predict_tower_name
(
tower
,
self
.
_prefix
)
if
self
.
_prefix
:
if
self
.
_prefix
:
msg
=
"Building predictor graph {} on gpu={} with prefix='{}' ..."
.
format
(
msg
=
"Building predictor graph {} on gpu={} with prefix='{}' ..."
.
format
(
...
@@ -187,7 +188,8 @@ class PredictorTowerBuilder(object):
...
@@ -187,7 +188,8 @@ class PredictorTowerBuilder(object):
device
=
'/gpu:{}'
.
format
(
tower
)
if
tower
>=
0
else
'/cpu:0'
device
=
'/gpu:{}'
.
format
(
tower
)
if
tower
>=
0
else
'/cpu:0'
with
tf
.
name_scope
(
None
),
\
with
tf
.
name_scope
(
None
),
\
freeze_collection
(
TOWER_FREEZE_KEYS
),
\
freeze_collection
(
TOWER_FREEZE_KEYS
),
\
TowerContext
(
towername
,
device
=
device
,
is_training
=
False
):
tf
.
device
(
device
),
\
TowerContext
(
towername
,
is_training
=
False
,
index
=
toweridx
):
self
.
_fn
(
tower
)
self
.
_fn
(
tower
)
# useful only when the placeholders don't have tower prefix
# useful only when the placeholders don't have tower prefix
...
...
tensorpack/tfutils/tower.py
View file @
dadd971c
...
@@ -4,7 +4,6 @@
...
@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
import
re
from
..utils.naming
import
PREDICT_TOWER
from
..utils.naming
import
PREDICT_TOWER
__all__
=
[
'get_current_tower_context'
,
'TowerContext'
]
__all__
=
[
'get_current_tower_context'
,
'TowerContext'
]
...
@@ -16,24 +15,27 @@ class TowerContext(object):
...
@@ -16,24 +15,27 @@ class TowerContext(object):
""" A context where the current model is being built in. """
""" A context where the current model is being built in. """
def
__init__
(
self
,
tower_name
,
def
__init__
(
self
,
tower_name
,
device
=
None
,
is_training
=
None
,
is_training
=
None
,
index
=
0
,
var_strategy
=
'shared'
,
var_strategy
=
'shared'
,
vs_name
=
None
):
vs_name
=
None
):
"""
"""
Args:
Args:
tower_name (str):
'tower0', 'towerp0', or ''
tower_name (str):
The name scope of the tower. Currently used
device (str or device function): the device to use. Defaults to either cpu0 or gpu0.
values are like: 'tower0', 'towerp0', or ''
is_training (bool): if None, automatically determine from tower_name.
is_training (bool): if None, automatically determine from tower_name.
index (int): index of this tower
var_strategy (str): either 'shared' or 'replicated'.
var_strategy (str): either 'shared' or 'replicated'.
vs_name (str): the variable scope name to open. Only valid in
vs_name (str): the variable scope name to open. Only valid in
'replicated' mode. Defaults to be tower_name.
'replicated' mode. Defaults to be tower_name.
"""
"""
self
.
_name
=
tower_name
self
.
_name
=
tower_name
self
.
_device
=
device
if
is_training
is
None
:
if
is_training
is
None
:
is_training
=
not
self
.
_name
.
startswith
(
PREDICT_TOWER
)
is_training
=
not
self
.
_name
.
startswith
(
PREDICT_TOWER
)
self
.
_is_training
=
is_training
self
.
_is_training
=
bool
(
is_training
)
self
.
_index
=
index
assert
var_strategy
in
[
'replicated'
,
'shared'
],
var_strategy
assert
var_strategy
in
[
'replicated'
,
'shared'
],
var_strategy
self
.
_var_strategy
=
var_strategy
self
.
_var_strategy
=
var_strategy
...
@@ -49,11 +51,11 @@ class TowerContext(object):
...
@@ -49,11 +51,11 @@ class TowerContext(object):
@
property
@
property
def
is_main_training_tower
(
self
):
def
is_main_training_tower
(
self
):
return
self
.
is_training
and
(
self
.
_name
==
''
or
self
.
_name
==
'tower0'
)
return
self
.
is_training
and
self
.
_index
==
0
@
property
@
property
def
is_main_tower
(
self
):
def
is_main_tower
(
self
):
return
self
.
_
name
==
''
or
self
.
_name
==
'tower0'
return
self
.
_
index
==
0
@
property
@
property
def
is_training
(
self
):
def
is_training
(
self
):
...
@@ -67,37 +69,17 @@ class TowerContext(object):
...
@@ -67,37 +69,17 @@ class TowerContext(object):
def
name
(
self
):
def
name
(
self
):
return
self
.
_name
return
self
.
_name
# TODO remove this and add something like `tower.variables`
# variable_scope name
# variable_scope name
@
property
@
property
def
vs_name
(
self
):
def
vs_name
(
self
):
return
self
.
_vs_name
return
self
.
_vs_name
# TODO pass index into the constructor
@
property
@
property
def
index
(
self
):
def
index
(
self
):
if
self
.
_name
==
''
:
return
self
.
_index
return
0
idx
=
re
.
findall
(
'[0-9]+$'
,
self
.
_name
)
if
len
(
idx
)
==
0
:
return
0
return
int
(
idx
[
0
])
@
property
def
device
(
self
):
return
self
.
_device
def
find_tensor_in_main_tower
(
self
,
graph
,
name
):
if
self
.
is_main_tower
:
return
graph
.
get_tensor_by_name
(
name
)
if
name
.
startswith
(
PREDICT_TOWER
):
predict_tower_prefix
=
'{}[0-9]+/'
.
format
(
PREDICT_TOWER
)
newname
=
re
.
sub
(
predict_tower_prefix
,
''
,
name
)
try
:
return
graph
.
get_tensor_by_name
(
newname
)
except
KeyError
:
newname
=
re
.
sub
(
predict_tower_prefix
,
'tower0/'
,
name
)
return
graph
.
get_tensor_by_name
(
newname
)
# TODO something similar for training
@
staticmethod
@
staticmethod
def
get_predict_tower_name
(
towerid
=
0
,
prefix
=
''
):
def
get_predict_tower_name
(
towerid
=
0
,
prefix
=
''
):
"""
"""
...
@@ -124,15 +106,14 @@ class TowerContext(object):
...
@@ -124,15 +106,14 @@ class TowerContext(object):
self
.
_ctxs
.
append
(
tf
.
variable_scope
(
self
.
vs_name
))
self
.
_ctxs
.
append
(
tf
.
variable_scope
(
self
.
vs_name
))
else
:
else
:
if
self
.
is_training
:
if
self
.
is_training
:
reuse
=
self
.
index
>
0
reuse
=
self
.
_
index
>
0
if
reuse
is
True
:
if
reuse
is
True
:
# clear old name_scope and re-enter the current variable_scope
self
.
_ctxs
.
append
(
tf
.
name_scope
(
None
))
self
.
_ctxs
.
append
(
tf
.
name_scope
(
None
))
self
.
_ctxs
.
append
(
tf
.
variable_scope
(
self
.
_ctxs
.
append
(
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
))
tf
.
get_variable_scope
(),
reuse
=
True
))
# if not training, should handle vs outside (TODO not good)
# if not training, should handle vs outside (TODO not good)
self
.
_ctxs
.
append
(
tf
.
name_scope
(
self
.
_name
))
self
.
_ctxs
.
append
(
tf
.
name_scope
(
self
.
_name
))
if
self
.
_device
is
not
None
:
self
.
_ctxs
.
append
(
tf
.
device
(
self
.
_device
))
for
c
in
self
.
_ctxs
:
for
c
in
self
.
_ctxs
:
c
.
__enter__
()
c
.
__enter__
()
...
...
tensorpack/train/multigpu.py
View file @
dadd971c
...
@@ -81,9 +81,10 @@ class MultiGPUTrainerBase(Trainer):
...
@@ -81,9 +81,10 @@ class MultiGPUTrainerBase(Trainer):
for
idx
,
t
in
enumerate
(
towers
):
for
idx
,
t
in
enumerate
(
towers
):
device
=
devices
[
idx
]
if
devices
is
not
None
else
'/gpu:{}'
.
format
(
t
)
device
=
devices
[
idx
]
if
devices
is
not
None
else
'/gpu:{}'
.
format
(
t
)
with
TowerContext
(
with
tf
.
device
(
device
),
TowerContext
(
'tower{}'
.
format
(
idx
),
'tower{}'
.
format
(
idx
),
device
=
device
,
is_training
=
True
,
is_training
=
True
,
index
=
idx
,
var_strategy
=
var_strategy
,
var_strategy
=
var_strategy
,
vs_name
=
vs_names
[
idx
]):
vs_name
=
vs_names
[
idx
]):
if
idx
==
t
:
if
idx
==
t
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment