Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
787be08e
Commit
787be08e
authored
Aug 22, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Deprecate warning for old modeldesc interface.
parent
3700a803
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
26 deletions
+39
-26
tensorpack/graph_builder/model_desc.py
tensorpack/graph_builder/model_desc.py
+36
-23
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+3
-3
No files found.
tensorpack/graph_builder/model_desc.py
View file @
787be08e
...
@@ -9,7 +9,6 @@ from ..utils import logger
...
@@ -9,7 +9,6 @@ from ..utils import logger
from
..utils.argtools
import
memoized
from
..utils.argtools
import
memoized
from
..utils.develop
import
log_deprecated
from
..utils.develop
import
log_deprecated
from
..tfutils.tower
import
get_current_tower_context
from
..tfutils.tower
import
get_current_tower_context
from
..input_source
import
InputSource
from
..models.regularize
import
regularize_cost_from_collection
from
..models.regularize
import
regularize_cost_from_collection
__all__
=
[
'InputDesc'
,
'ModelDesc'
,
'ModelDescBase'
]
__all__
=
[
'InputDesc'
,
'ModelDesc'
,
'ModelDescBase'
]
...
@@ -95,10 +94,16 @@ class ModelDescBase(object):
...
@@ -95,10 +94,16 @@ class ModelDescBase(object):
def
get_inputs_desc
(
self
):
def
get_inputs_desc
(
self
):
"""
"""
Returns:
Returns:
a list of :class:`InputDesc`.
A list of :class:`InputDesc`, which describes the inputs of this model.
The result is cached for each instance of :class:`ModelDescBase`.
"""
"""
try
:
try
:
return
self
.
_get_inputs
()
ret
=
self
.
_get_inputs
()
log_deprecated
(
"ModelDescBase._get_inputs() interface"
,
"Use inputs() instead!"
,
"2019-03-30"
)
return
ret
except
NotImplementedError
:
except
NotImplementedError
:
with
tf
.
Graph
()
.
as_default
()
as
G
:
# create these placeholder in a temporary graph
with
tf
.
Graph
()
.
as_default
()
as
G
:
# create these placeholder in a temporary graph
inputs
=
self
.
inputs
()
inputs
=
self
.
inputs
()
...
@@ -106,6 +111,14 @@ class ModelDescBase(object):
...
@@ -106,6 +111,14 @@ class ModelDescBase(object):
assert
p
.
graph
==
G
,
"Placeholders returned by inputs() should be created inside inputs()!"
assert
p
.
graph
==
G
,
"Placeholders returned by inputs() should be created inside inputs()!"
return
[
InputDesc
.
from_placeholder
(
p
)
for
p
in
inputs
]
return
[
InputDesc
.
from_placeholder
(
p
)
for
p
in
inputs
]
@
property
def
input_names
(
self
):
"""
Returns:
[str]: the names of all the inputs.
"""
return
[
k
.
name
for
k
in
self
.
get_inputs_desc
()]
def
_get_inputs
(
self
):
def
_get_inputs
(
self
):
raise
NotImplementedError
()
raise
NotImplementedError
()
...
@@ -116,7 +129,8 @@ class ModelDescBase(object):
...
@@ -116,7 +129,8 @@ class ModelDescBase(object):
The placeholders __have to__ be created inside this method.
The placeholders __have to__ be created inside this method.
Don't return placeholders created in other methods.
Don't return placeholders created in other methods.
Also, you should not call this method by yourself.
Also, you should never call this method by yourself.
Returns:
Returns:
a list of `tf.placeholder`, to be converted to :class:`InputDesc`.
a list of `tf.placeholder`, to be converted to :class:`InputDesc`.
...
@@ -128,7 +142,7 @@ class ModelDescBase(object):
...
@@ -128,7 +142,7 @@ class ModelDescBase(object):
Build the whole symbolic graph.
Build the whole symbolic graph.
This is supposed to be part of the "tower function" when used with :class:`TowerTrainer`.
This is supposed to be part of the "tower function" when used with :class:`TowerTrainer`.
A subclass is expected to
overwrite
this method.
A subclass is expected to
implement
this method.
Args:
Args:
args ([tf.Tensor]): tensors that matches the list of inputs defined by ``inputs()``.
args ([tf.Tensor]): tensors that matches the list of inputs defined by ``inputs()``.
...
@@ -138,24 +152,14 @@ class ModelDescBase(object):
...
@@ -138,24 +152,14 @@ class ModelDescBase(object):
may require it to return necessary information to build the trainer.
may require it to return necessary information to build the trainer.
For example, `SingleCostTrainer` expect this method to return the cost tensor.
For example, `SingleCostTrainer` expect this method to return the cost tensor.
"""
"""
if
len
(
args
)
==
1
:
assert
len
(
args
)
==
len
(
self
.
get_inputs_desc
()),
\
arg
=
args
[
0
]
if
isinstance
(
arg
,
InputSource
):
inputs
=
arg
.
get_input_tensors
()
# remove in the future?
log_deprecated
(
"build_graph(InputSource)"
,
"Call with tensors in positional args instead."
,
"2018-03-31"
)
elif
isinstance
(
arg
,
(
list
,
tuple
)):
inputs
=
arg
log_deprecated
(
"build_graph([Tensor])"
,
"Call with positional args instead."
,
"2018-03-31"
)
else
:
inputs
=
[
arg
]
else
:
inputs
=
args
assert
len
(
inputs
)
==
len
(
self
.
get_inputs_desc
()),
\
"Number of inputs passed to the graph != number of inputs defined "
\
"Number of inputs passed to the graph != number of inputs defined "
\
"in ModelDesc! ({} != {})"
.
format
(
len
(
inputs
),
len
(
self
.
get_inputs_desc
()))
"in ModelDesc! ({} != {})"
.
format
(
len
(
args
),
len
(
self
.
get_inputs_desc
()))
return
self
.
_build_graph
(
inputs
)
log_deprecated
(
"ModelDescBase._build_graph() interface"
,
"Use build_graph() instead!"
,
"2019-03-30"
)
return
self
.
_build_graph
(
args
)
def
_build_graph
(
self
,
inputs
):
def
_build_graph
(
self
,
inputs
):
"""
"""
...
@@ -187,6 +191,10 @@ class ModelDesc(ModelDescBase):
...
@@ -187,6 +191,10 @@ class ModelDesc(ModelDescBase):
and applies the collection
and applies the collection
``tf.GraphKeys.REGULARIZATION_LOSSES`` to the cost automatically.
``tf.GraphKeys.REGULARIZATION_LOSSES`` to the cost automatically.
"""
"""
log_deprecated
(
"get_cost() and self.cost"
,
"Return the cost tensor directly in build_graph() instead!"
,
"2019-03-30"
)
cost
=
self
.
_get_cost
()
cost
=
self
.
_get_cost
()
reg_cost
=
regularize_cost_from_collection
()
reg_cost
=
regularize_cost_from_collection
()
if
reg_cost
.
op
.
type
!=
'Const'
:
if
reg_cost
.
op
.
type
!=
'Const'
:
...
@@ -211,7 +219,12 @@ class ModelDesc(ModelDescBase):
...
@@ -211,7 +219,12 @@ class ModelDesc(ModelDescBase):
a :class:`tf.train.Optimizer` instance.
a :class:`tf.train.Optimizer` instance.
"""
"""
try
:
try
:
return
self
.
_get_optimizer
()
ret
=
self
.
_get_optimizer
()
log_deprecated
(
"ModelDescBase._get_optimizer() interface"
,
"Use optimizer() instead!"
,
"2019-03-30"
)
return
ret
except
NotImplementedError
:
except
NotImplementedError
:
pass
pass
return
self
.
optimizer
()
return
self
.
optimizer
()
...
...
tensorpack/tfutils/tower.py
View file @
787be08e
...
@@ -263,9 +263,9 @@ class TowerFuncWrapper(object):
...
@@ -263,9 +263,9 @@ class TowerFuncWrapper(object):
They are used to figure out the names for the input tensors.
They are used to figure out the names for the input tensors.
"""
"""
assert
callable
(
tower_fn
),
tower_fn
assert
callable
(
tower_fn
),
tower_fn
inputs_desc_names
=
[
k
.
name
for
k
in
inputs_desc
]
self
.
_
inputs_desc_names
=
[
k
.
name
for
k
in
inputs_desc
]
assert
len
(
set
(
inputs_desc_names
))
==
len
(
inputs_desc_names
),
\
assert
len
(
set
(
self
.
_inputs_desc_names
))
==
len
(
self
.
_
inputs_desc_names
),
\
"Duplicated names in inputs_desc! "
+
str
(
inputs_desc_names
)
"Duplicated names in inputs_desc! "
+
str
(
self
.
_
inputs_desc_names
)
self
.
_tower_fn
=
tower_fn
self
.
_tower_fn
=
tower_fn
self
.
_inputs_desc
=
inputs_desc
self
.
_inputs_desc
=
inputs_desc
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment