Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e121701a
Commit
e121701a
authored
Oct 20, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Allow TowerContext to use both ns_name and vs_name
parent
0addcdc6
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
52 additions
and
36 deletions
+52
-36
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+2
-2
tensorpack/graph_builder/predictor_factory.py
tensorpack/graph_builder/predictor_factory.py
+2
-2
tensorpack/graph_builder/training.py
tensorpack/graph_builder/training.py
+1
-1
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+38
-30
tensorpack/trainv2/base.py
tensorpack/trainv2/base.py
+9
-1
No files found.
tensorpack/callbacks/inference_runner.py
View file @
e121701a
...
...
@@ -149,7 +149,7 @@ class InferenceRunner(InferenceRunnerBase):
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
SimplePredictBuilder
(
ns_name
=
self
.
_tower_name
,
vs_name
=
''
,
device
=
0
)
.
build
(
# TODO fix vs_name and maybe device
vs_name
=
self
.
trainer
.
_main_tower_vs_name
,
device
=
0
)
.
build
(
self
.
_input_source
,
self
.
trainer
.
tower_func
)
self
.
_tower_handle
=
self
.
trainer
.
tower_func
.
towers
[
-
1
]
...
...
@@ -224,7 +224,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
tower_name
=
self
.
_tower_names
[
idx
]
SimplePredictBuilder
(
ns_name
=
tower_name
,
vs_name
=
''
,
device
=
t
)
.
build
(
# TODO fix vs_name and maybe device
vs_name
=
self
.
trainer
.
_main_tower_vs_name
,
device
=
t
)
.
build
(
self
.
_input_source
,
self
.
trainer
.
tower_func
)
self
.
_handles
.
append
(
self
.
trainer
.
tower_func
.
towers
[
-
1
])
...
...
tensorpack/graph_builder/predictor_factory.py
View file @
e121701a
...
...
@@ -26,7 +26,6 @@ class SimplePredictBuilder(GraphBuilder):
vs_name (str):
device (int):
"""
# TODO does vs_name work properly here when different from ns_name?
self
.
_ns_name
=
ns_name
self
.
_vs_name
=
vs_name
...
...
@@ -56,7 +55,8 @@ class SimplePredictBuilder(GraphBuilder):
with
tf
.
device
(
self
.
_device
),
\
self
.
_maybe_open_vs
(),
\
TowerContext
(
self
.
_ns_name
,
is_training
=
False
),
\
TowerContext
(
self
.
_ns_name
,
is_training
=
False
,
vs_name
=
self
.
_vs_name
),
\
freeze_collection
(
TOWER_FREEZE_KEYS
+
[
tf
.
GraphKeys
.
UPDATE_OPS
]):
# also freeze UPDATE_OPS in inference, because they should never be used
# TODO a better way to log and warn about collection change during build_graph.
...
...
tensorpack/graph_builder/training.py
View file @
e121701a
...
...
@@ -88,7 +88,7 @@ class DataParallelBuilder(GraphBuilder):
tower_names
[
idx
],
is_training
=
True
,
index
=
idx
,
use_vs
=
usevs
):
vs_name
=
tower_names
[
idx
]
if
usevs
else
''
):
logger
.
info
(
"Building graph for training tower {} on device {}..."
.
format
(
idx
,
device
))
# When use_vs is True, use LOCAL_VARIABLES,
...
...
tensorpack/tfutils/tower.py
View file @
e121701a
...
...
@@ -7,6 +7,7 @@ import tensorflow as tf
from
six.moves
import
zip
from
..utils
import
logger
from
..utils.argtools
import
call_only_once
from
.common
import
get_tf_version_number
,
get_op_or_tensor_by_name
,
get_op_tensor_name
__all__
=
[
'get_current_tower_context'
,
'TowerContext'
,
'TowerFuncWrapper'
]
...
...
@@ -17,30 +18,30 @@ _CurrentTowerContext = None
class
TowerContext
(
object
):
""" A context where the current model is being built in. """
def
__init__
(
self
,
tower_name
,
is_training
,
index
=
0
,
use_vs
=
False
):
def
__init__
(
self
,
tower_name
,
is_training
,
index
=
0
,
vs_name
=
''
):
"""
Args:
tower_name (str): The name scope of the tower.
is_training (bool):
index (int): index of this tower, only used in training.
use_vs (bool
): Open a new variable scope with this name.
vs_name (str
): Open a new variable scope with this name.
"""
self
.
_name
=
tower_name
self
.
_is_training
=
bool
(
is_training
)
if
not
self
.
_is_training
:
assert
index
==
0
and
not
use_vs
,
\
"
use_vs and index are only use
d in training!"
assert
index
==
0
,
\
"
TowerContext(index) is only vali
d in training!"
self
.
_index
=
int
(
index
)
if
use_vs
:
self
.
_vs_name
=
self
.
_name
assert
len
(
self
.
_name
)
else
:
self
.
_vs_name
=
''
self
.
_vs_name
=
vs_name
if
len
(
vs_name
):
assert
len
(
tower_name
),
"TowerContext(vs_name) cannot be used with an empty tower_name!"
self
.
_initial_vs_reuse
=
tf
.
get_variable_scope
()
.
reuse
if
self
.
has_own_variables
:
assert
not
tf
.
get_variable_scope
()
.
reuse
,
"reuse=True in tower {}!"
.
format
(
tower_name
)
assert
not
self
.
_initial_vs_reuse
,
\
"Cannot create tower {} with reuse=True!"
.
format
(
tower_name
)
@
property
def
is_main_training_tower
(
self
):
...
...
@@ -55,7 +56,9 @@ class TowerContext(object):
"""
Whether this tower is supposed to have its own variables.
"""
return
self
.
is_main_training_tower
or
len
(
self
.
_vs_name
)
>
0
return
self
.
is_main_training_tower
or
\
(
self
.
is_training
and
len
(
self
.
_vs_name
)
>
0
)
or
\
(
not
self
.
is_training
and
len
(
self
.
_vs_name
)
>
0
and
not
self
.
_initial_vs_reuse
)
# TODO clarify the interface on name/vs_name/ns_name.
# TODO in inference, vs_name may need to be different from ns_name.i
...
...
@@ -72,6 +75,7 @@ class TowerContext(object):
def
ns_name
(
self
):
return
self
.
_name
# TODO another method to filter by ns_name
def
filter_vars_by_vs_name
(
self
,
varlist
):
"""
Filter the list and only keep those under the current variable scope.
...
...
@@ -93,32 +97,36 @@ class TowerContext(object):
def
index
(
self
):
return
self
.
_index
@
call_only_once
def
_get_scopes
(
self
):
if
not
len
(
self
.
_name
):
return
[]
ret
=
[]
# either the Tower was originally created with reuse,
# or a training tower without vs has to use reuse.
reuse
=
(
self
.
is_training
and
self
.
_index
>
0
and
not
self
.
has_own_variables
)
or
self
.
_initial_vs_reuse
if
len
(
self
.
_vs_name
):
ret
.
append
(
tf
.
variable_scope
(
self
.
_vs_name
,
reuse
=
reuse
))
else
:
if
reuse
:
ret
.
append
(
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
))
# always clear existing ns # TODO check existing ns
if
len
(
self
.
_name
)
and
self
.
_name
!=
self
.
_vs_name
:
ret
.
append
(
tf
.
name_scope
(
self
.
_name
+
'/'
))
return
ret
def
__enter__
(
self
):
global
_CurrentTowerContext
assert
_CurrentTowerContext
is
None
,
"Cannot nest TowerContext!"
_CurrentTowerContext
=
self
self
.
_ctxs
=
[]
curr_vs
=
tf
.
get_variable_scope
()
assert
curr_vs
.
name
==
''
,
"Cannot nest TowerContext with an existing variable scope!"
if
len
(
self
.
_name
):
if
not
self
.
is_training
:
# if not training, should handle reuse outside
# but still good to clear name_scope first
self
.
_ctxs
.
append
(
tf
.
name_scope
(
None
))
self
.
_ctxs
.
append
(
tf
.
name_scope
(
self
.
_name
))
else
:
if
self
.
has_own_variables
:
if
len
(
self
.
_vs_name
):
self
.
_ctxs
.
append
(
tf
.
variable_scope
(
self
.
_vs_name
))
else
:
self
.
_ctxs
.
append
(
tf
.
name_scope
(
self
.
_name
))
else
:
reuse
=
self
.
_index
>
0
if
reuse
:
self
.
_ctxs
.
append
(
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
))
self
.
_ctxs
.
append
(
tf
.
name_scope
(
self
.
_name
))
self
.
_ctxs
=
self
.
_get_scopes
()
for
c
in
self
.
_ctxs
:
c
.
__enter__
()
...
...
tensorpack/trainv2/base.py
View file @
e121701a
...
...
@@ -269,13 +269,21 @@ class TowerTrainer(Trainer):
input
.
setup
(
self
.
inputs_desc
)
SimplePredictBuilder
(
ns_name
=
tower_name
,
vs_name
=
''
,
ns_name
=
tower_name
,
vs_name
=
self
.
_main_tower_vs_name
,
device
=
device
)
.
build
(
input
,
self
.
tower_func
)
tower
=
self
.
tower_func
.
towers
[
tower_name
]
input_tensors
=
tower
.
get_tensors
(
input_names
)
output_tensors
=
tower
.
get_tensors
(
output_names
)
return
OnlinePredictor
(
input_tensors
,
output_tensors
)
@
property
def
_main_tower_vs_name
(
self
):
"""
The vs name for the "main" copy of the model,
to be used to build predictors.
"""
return
""
@
six
.
add_metaclass
(
ABCMeta
)
class
SingleCostTrainer
(
TowerTrainer
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment