Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
39fa4656
Commit
39fa4656
authored
Mar 16, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Use inputs() and tf.placeholder in ModelDesc (#318)
parent
215a4d6d
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
67 additions
and
31 deletions
+67
-31
docs/tutorial/training-interface.md
docs/tutorial/training-interface.md
+2
-2
examples/basics/cifar-convnet.py
examples/basics/cifar-convnet.py
+3
-4
examples/basics/mnist-convnet.py
examples/basics/mnist-convnet.py
+4
-5
tensorpack/graph_builder/model_desc.py
tensorpack/graph_builder/model_desc.py
+57
-19
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+1
-1
No files found.
docs/tutorial/training-interface.md
View file @
39fa4656
...
@@ -17,8 +17,8 @@ expects 4 arguments to setup the graph: `InputDesc`, `InputSource`, get_cost fun
...
@@ -17,8 +17,8 @@ expects 4 arguments to setup the graph: `InputDesc`, `InputSource`, get_cost fun
```
python
```
python
class
MyModel
(
ModelDesc
):
class
MyModel
(
ModelDesc
):
def
_get_
inputs
(
self
):
def
inputs
(
self
):
return
[
InputDesc
(
...
),
InputDesc
(
...
)
]
return
[
tf
.
placeholder
(
dtype
,
shape
,
name
),
tf
.
placeholder
(
dtype
,
shape
,
name
),
...
]
def
_build_graph
(
self
,
inputs
):
def
_build_graph
(
self
,
inputs
):
tensorA
,
tensorB
=
inputs
tensorA
,
tensorB
=
inputs
...
...
examples/basics/cifar-convnet.py
View file @
39fa4656
...
@@ -26,10 +26,9 @@ class Model(ModelDesc):
...
@@ -26,10 +26,9 @@ class Model(ModelDesc):
super
(
Model
,
self
)
.
__init__
()
super
(
Model
,
self
)
.
__init__
()
self
.
cifar_classnum
=
cifar_classnum
self
.
cifar_classnum
=
cifar_classnum
def
_get_inputs
(
self
):
def
inputs
(
self
):
return
[
InputDesc
(
tf
.
float32
,
(
None
,
30
,
30
,
3
),
'input'
),
return
[
tf
.
placeholder
(
tf
.
float32
,
(
None
,
30
,
30
,
3
),
'input'
),
InputDesc
(
tf
.
int32
,
(
None
,),
'label'
)
tf
.
placeholder
(
tf
.
int32
,
(
None
,),
'label'
)]
]
def
_build_graph
(
self
,
inputs
):
def
_build_graph
(
self
,
inputs
):
image
,
label
=
inputs
image
,
label
=
inputs
...
...
examples/basics/mnist-convnet.py
View file @
39fa4656
...
@@ -20,13 +20,12 @@ IMAGE_SIZE = 28
...
@@ -20,13 +20,12 @@ IMAGE_SIZE = 28
class
Model
(
ModelDesc
):
class
Model
(
ModelDesc
):
def
_get_
inputs
(
self
):
def
inputs
(
self
):
"""
"""
Define all the inputs (with type, shape, name) that
Define all the inputs (with type, shape, name) that the graph will need.
the graph will need.
"""
"""
return
[
InputDesc
(
tf
.
float32
,
(
None
,
IMAGE_SIZE
,
IMAGE_SIZE
),
'input'
),
return
[
tf
.
placeholder
(
tf
.
float32
,
(
None
,
IMAGE_SIZE
,
IMAGE_SIZE
),
'input'
),
InputDesc
(
tf
.
int32
,
(
None
,),
'label'
)]
tf
.
placeholder
(
tf
.
int32
,
(
None
,),
'label'
)]
def
_build_graph
(
self
,
inputs
):
def
_build_graph
(
self
,
inputs
):
"""This function should build the model which takes the input variables
"""This function should build the model which takes the input variables
...
...
tensorpack/graph_builder/model_desc.py
View file @
39fa4656
...
@@ -3,11 +3,10 @@
...
@@ -3,11 +3,10 @@
# File: model_desc.py
# File: model_desc.py
from
abc
import
ABCMeta
,
abstractmethod
from
collections
import
namedtuple
from
collections
import
namedtuple
import
tensorflow
as
tf
import
tensorflow
as
tf
import
six
from
..utils
import
logger
from
..utils.argtools
import
memoized
from
..utils.argtools
import
memoized
from
..utils.develop
import
log_deprecated
from
..utils.develop
import
log_deprecated
from
..tfutils.gradproc
import
FilterNoneGrad
from
..tfutils.gradproc
import
FilterNoneGrad
...
@@ -38,9 +37,10 @@ class InputDesc(
...
@@ -38,9 +37,10 @@ class InputDesc(
if
any
(
k
in
name
for
k
in
[
':'
,
'/'
,
' '
]):
if
any
(
k
in
name
for
k
in
[
':'
,
'/'
,
' '
]):
raise
ValueError
(
"Invalid InputDesc name: '{}'"
.
format
(
name
))
raise
ValueError
(
"Invalid InputDesc name: '{}'"
.
format
(
name
))
self
=
super
(
InputDesc
,
cls
)
.
__new__
(
cls
,
type
,
shape
,
name
)
self
=
super
(
InputDesc
,
cls
)
.
__new__
(
cls
,
type
,
shape
,
name
)
self
.
_cached_placeholder
=
None
self
.
_cached_placeholder
=
{}
return
self
return
self
# TODO this method seems unused outside this class
def
build_placeholder
(
self
):
def
build_placeholder
(
self
):
"""
"""
Build a tf.placeholder from the metadata.
Build a tf.placeholder from the metadata.
...
@@ -51,8 +51,7 @@ class InputDesc(
...
@@ -51,8 +51,7 @@ class InputDesc(
with
tf
.
name_scope
(
None
):
# clear any name scope it might get called in
with
tf
.
name_scope
(
None
):
# clear any name scope it might get called in
ret
=
tf
.
placeholder
(
ret
=
tf
.
placeholder
(
self
.
type
,
shape
=
self
.
shape
,
name
=
self
.
name
)
self
.
type
,
shape
=
self
.
shape
,
name
=
self
.
name
)
if
self
.
_cached_placeholder
is
None
:
self
.
_register_cached_placeholder
(
ret
)
self
.
_cached_placeholder
=
ret
# cached_placeholder only caches the prefix='' case
return
ret
return
ret
# cannot memoize here, because InputDesc is hashed by its fields.
# cannot memoize here, because InputDesc is hashed by its fields.
...
@@ -63,28 +62,67 @@ class InputDesc(
...
@@ -63,28 +62,67 @@ class InputDesc(
Returns:
Returns:
tf.Tensor:
tf.Tensor:
"""
"""
if
self
.
_cached_placeholder
is
not
None
:
g
=
tf
.
get_default_graph
()
return
self
.
_cached_placeholder
if
g
in
self
.
_cached_placeholder
:
return
self
.
build_placeholder
()
return
self
.
_cached_placeholder
[
g
]
else
:
return
self
.
build_placeholder
()
def
_register_cached_placeholder
(
self
,
placeholder
):
graph
=
placeholder
.
graph
assert
graph
not
in
self
.
_cached_placeholder
,
\
"Placeholder for this InputDesc had been created before! This is a bug."
self
.
_cached_placeholder
[
graph
]
=
placeholder
@
staticmethod
def
from_placeholder
(
placeholder
):
name
=
placeholder
.
op
.
name
if
name
.
endswith
(
'_1'
)
or
name
.
endswith
(
'_2'
):
logger
.
error
(
"Creating InputDesc from a placeholder named {}."
.
format
(
name
))
logger
.
error
(
"You might have mistakenly created this placeholder multiple times!"
)
ret
=
InputDesc
(
placeholder
.
dtype
,
tuple
(
placeholder
.
shape
.
as_list
()),
name
)
ret
.
_register_cached_placeholder
(
placeholder
)
return
ret
@
six
.
add_metaclass
(
ABCMeta
)
class
ModelDescBase
(
object
):
class
ModelDescBase
(
object
):
""" Base class for a model description.
"""
"""
Base class for a model description.
"""
@
memoized
@
memoized
def
get_inputs_desc
(
self
):
def
get_inputs_desc
(
self
):
"""
"""
Returns:
Returns:
list[:class:`InputDesc`]: list of the underlying
:class:`InputDesc`.
a list of
:class:`InputDesc`.
"""
"""
return
self
.
_get_inputs
()
try
:
return
self
.
_get_inputs
()
except
NotImplementedError
:
with
tf
.
Graph
()
.
as_default
():
# create these placeholder in a temporary graph
inputs
=
self
.
inputs
()
return
[
InputDesc
.
from_placeholder
(
p
)
for
p
in
inputs
]
@
abstractmethod
def
_get_inputs
(
self
):
def
_get_inputs
(
self
):
"""
"""
:returns: a list of InputDesc
Returns:
a list of :class:`InputDesc`.
"""
raise
NotImplementedError
()
def
inputs
(
self
):
"""
__Create__ and returns a list of placeholders.
To be implemented by subclass.
The placeholders __have to__ be created inside this function.
Returns:
a list of `tf.placeholder`, to be converted to :class:`InputDesc`.
"""
"""
raise
NotImplementedError
()
def
build_graph
(
self
,
*
args
):
def
build_graph
(
self
,
*
args
):
"""
"""
...
@@ -93,13 +131,12 @@ class ModelDescBase(object):
...
@@ -93,13 +131,12 @@ class ModelDescBase(object):
By default it will call :meth:`_build_graph` with a list of input tensors.
By default it will call :meth:`_build_graph` with a list of input tensors.
Args:
Args:
args ([tf.Tensor]): tensors that matches the list of
args ([tf.Tensor]): tensors that matches the list of inputs defined by ``inputs()``.
:class:`InputDesc` defined by ``_get_inputs``.
Returns:
Returns:
In general it returns nothing, but a subclass (e.g.
In general it returns nothing, but a subclass (e.g.
:class:`ModelDesc` may require it to return necessary information
:class:`ModelDesc`
)
may require it to return necessary information
to build the trainer.
(e.g. cost)
to build the trainer.
"""
"""
if
len
(
args
)
==
1
:
if
len
(
args
)
==
1
:
arg
=
args
[
0
]
arg
=
args
[
0
]
...
@@ -122,7 +159,8 @@ class ModelDescBase(object):
...
@@ -122,7 +159,8 @@ class ModelDescBase(object):
def
_build_graph
(
self
,
inputs
):
def
_build_graph
(
self
,
inputs
):
"""
"""
This is an old interface which takes a list of tensors, instead of positional arguments.
This is an alternative interface which takes a list of tensors, instead of positional arguments.
By default :meth:`build_graph` will call this method.
"""
"""
pass
pass
...
...
tensorpack/input_source/input_source.py
View file @
39fa4656
...
@@ -41,7 +41,7 @@ class PlaceholderInput(InputSource):
...
@@ -41,7 +41,7 @@ class PlaceholderInput(InputSource):
Just produce placeholders as input tensors.
Just produce placeholders as input tensors.
"""
"""
def
_setup
(
self
,
inputs
):
def
_setup
(
self
,
inputs
):
self
.
_all_placehdrs
=
[
v
.
build_placeholder
()
for
v
in
inputs
]
self
.
_all_placehdrs
=
[
v
.
build_placeholder
_reuse
()
for
v
in
inputs
]
def
_get_input_tensors
(
self
):
def
_get_input_tensors
(
self
):
return
self
.
_all_placehdrs
return
self
.
_all_placehdrs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment