Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
39fa4656
Commit
39fa4656
authored
Mar 16, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Use inputs() and tf.placeholder in ModelDesc (#318)
parent
215a4d6d
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
67 additions
and
31 deletions
+67
-31
docs/tutorial/training-interface.md
docs/tutorial/training-interface.md
+2
-2
examples/basics/cifar-convnet.py
examples/basics/cifar-convnet.py
+3
-4
examples/basics/mnist-convnet.py
examples/basics/mnist-convnet.py
+4
-5
tensorpack/graph_builder/model_desc.py
tensorpack/graph_builder/model_desc.py
+57
-19
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+1
-1
No files found.
docs/tutorial/training-interface.md
View file @
39fa4656
...
...
@@ -17,8 +17,8 @@ expects 4 arguments to setup the graph: `InputDesc`, `InputSource`, get_cost fun
```
python
class
MyModel
(
ModelDesc
):
def
_get_
inputs
(
self
):
return
[
InputDesc
(
...
),
InputDesc
(
...
)
]
def
inputs
(
self
):
return
[
tf
.
placeholder
(
dtype
,
shape
,
name
),
tf
.
placeholder
(
dtype
,
shape
,
name
),
...
]
def
_build_graph
(
self
,
inputs
):
tensorA
,
tensorB
=
inputs
...
...
examples/basics/cifar-convnet.py
View file @
39fa4656
...
...
@@ -26,10 +26,9 @@ class Model(ModelDesc):
super
(
Model
,
self
)
.
__init__
()
self
.
cifar_classnum
=
cifar_classnum
def
_get_inputs
(
self
):
return
[
InputDesc
(
tf
.
float32
,
(
None
,
30
,
30
,
3
),
'input'
),
InputDesc
(
tf
.
int32
,
(
None
,),
'label'
)
]
def
inputs
(
self
):
return
[
tf
.
placeholder
(
tf
.
float32
,
(
None
,
30
,
30
,
3
),
'input'
),
tf
.
placeholder
(
tf
.
int32
,
(
None
,),
'label'
)]
def
_build_graph
(
self
,
inputs
):
image
,
label
=
inputs
...
...
examples/basics/mnist-convnet.py
View file @
39fa4656
...
...
@@ -20,13 +20,12 @@ IMAGE_SIZE = 28
class
Model
(
ModelDesc
):
def
_get_
inputs
(
self
):
def
inputs
(
self
):
"""
Define all the inputs (with type, shape, name) that
the graph will need.
Define all the inputs (with type, shape, name) that the graph will need.
"""
return
[
InputDesc
(
tf
.
float32
,
(
None
,
IMAGE_SIZE
,
IMAGE_SIZE
),
'input'
),
InputDesc
(
tf
.
int32
,
(
None
,),
'label'
)]
return
[
tf
.
placeholder
(
tf
.
float32
,
(
None
,
IMAGE_SIZE
,
IMAGE_SIZE
),
'input'
),
tf
.
placeholder
(
tf
.
int32
,
(
None
,),
'label'
)]
def
_build_graph
(
self
,
inputs
):
"""This function should build the model which takes the input variables
...
...
tensorpack/graph_builder/model_desc.py
View file @
39fa4656
...
...
@@ -3,11 +3,10 @@
# File: model_desc.py
from
abc
import
ABCMeta
,
abstractmethod
from
collections
import
namedtuple
import
tensorflow
as
tf
import
six
from
..utils
import
logger
from
..utils.argtools
import
memoized
from
..utils.develop
import
log_deprecated
from
..tfutils.gradproc
import
FilterNoneGrad
...
...
@@ -38,9 +37,10 @@ class InputDesc(
if
any
(
k
in
name
for
k
in
[
':'
,
'/'
,
' '
]):
raise
ValueError
(
"Invalid InputDesc name: '{}'"
.
format
(
name
))
self
=
super
(
InputDesc
,
cls
)
.
__new__
(
cls
,
type
,
shape
,
name
)
self
.
_cached_placeholder
=
None
self
.
_cached_placeholder
=
{}
return
self
# TODO this method seems unused outside this class
def
build_placeholder
(
self
):
"""
Build a tf.placeholder from the metadata.
...
...
@@ -51,8 +51,7 @@ class InputDesc(
with
tf
.
name_scope
(
None
):
# clear any name scope it might get called in
ret
=
tf
.
placeholder
(
self
.
type
,
shape
=
self
.
shape
,
name
=
self
.
name
)
if
self
.
_cached_placeholder
is
None
:
self
.
_cached_placeholder
=
ret
# cached_placeholder only caches the prefix='' case
self
.
_register_cached_placeholder
(
ret
)
return
ret
# cannot memoize here, because InputDesc is hashed by its fields.
...
...
@@ -63,28 +62,67 @@ class InputDesc(
Returns:
tf.Tensor:
"""
if
self
.
_cached_placeholder
is
not
None
:
return
self
.
_cached_placeholder
g
=
tf
.
get_default_graph
()
if
g
in
self
.
_cached_placeholder
:
return
self
.
_cached_placeholder
[
g
]
else
:
return
self
.
build_placeholder
()
def
_register_cached_placeholder
(
self
,
placeholder
):
graph
=
placeholder
.
graph
assert
graph
not
in
self
.
_cached_placeholder
,
\
"Placeholder for this InputDesc had been created before! This is a bug."
self
.
_cached_placeholder
[
graph
]
=
placeholder
@
staticmethod
def
from_placeholder
(
placeholder
):
name
=
placeholder
.
op
.
name
if
name
.
endswith
(
'_1'
)
or
name
.
endswith
(
'_2'
):
logger
.
error
(
"Creating InputDesc from a placeholder named {}."
.
format
(
name
))
logger
.
error
(
"You might have mistakenly created this placeholder multiple times!"
)
ret
=
InputDesc
(
placeholder
.
dtype
,
tuple
(
placeholder
.
shape
.
as_list
()),
name
)
ret
.
_register_cached_placeholder
(
placeholder
)
return
ret
@
six
.
add_metaclass
(
ABCMeta
)
class
ModelDescBase
(
object
):
""" Base class for a model description.
"""
Base class for a model description.
"""
@
memoized
def
get_inputs_desc
(
self
):
"""
Returns:
list[:class:`InputDesc`]: list of the underlying
:class:`InputDesc`.
a list of
:class:`InputDesc`.
"""
try
:
return
self
.
_get_inputs
()
except
NotImplementedError
:
with
tf
.
Graph
()
.
as_default
():
# create these placeholder in a temporary graph
inputs
=
self
.
inputs
()
return
[
InputDesc
.
from_placeholder
(
p
)
for
p
in
inputs
]
@
abstractmethod
def
_get_inputs
(
self
):
"""
:returns: a list of InputDesc
Returns:
a list of :class:`InputDesc`.
"""
raise
NotImplementedError
()
def
inputs
(
self
):
"""
__Create__ and returns a list of placeholders.
To be implemented by subclass.
The placeholders __have to__ be created inside this function.
Returns:
a list of `tf.placeholder`, to be converted to :class:`InputDesc`.
"""
raise
NotImplementedError
()
def
build_graph
(
self
,
*
args
):
"""
...
...
@@ -93,13 +131,12 @@ class ModelDescBase(object):
By default it will call :meth:`_build_graph` with a list of input tensors.
Args:
args ([tf.Tensor]): tensors that matches the list of
:class:`InputDesc` defined by ``_get_inputs``.
args ([tf.Tensor]): tensors that matches the list of inputs defined by ``inputs()``.
Returns:
In general it returns nothing, but a subclass (e.g.
:class:`ModelDesc` may require it to return necessary information
to build the trainer.
:class:`ModelDesc`
)
may require it to return necessary information
(e.g. cost)
to build the trainer.
"""
if
len
(
args
)
==
1
:
arg
=
args
[
0
]
...
...
@@ -122,7 +159,8 @@ class ModelDescBase(object):
def
_build_graph
(
self
,
inputs
):
"""
This is an old interface which takes a list of tensors, instead of positional arguments.
This is an alternative interface which takes a list of tensors, instead of positional arguments.
By default :meth:`build_graph` will call this method.
"""
pass
...
...
tensorpack/input_source/input_source.py
View file @
39fa4656
...
...
@@ -41,7 +41,7 @@ class PlaceholderInput(InputSource):
Just produce placeholders as input tensors.
"""
def
_setup
(
self
,
inputs
):
self
.
_all_placehdrs
=
[
v
.
build_placeholder
()
for
v
in
inputs
]
self
.
_all_placehdrs
=
[
v
.
build_placeholder
_reuse
()
for
v
in
inputs
]
def
_get_input_tensors
(
self
):
return
self
.
_all_placehdrs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment