Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
a266459e
Commit
a266459e
authored
Jul 10, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use InputDesc inside input_source, instead of ModelDesc
parent
8dcf454d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
67 additions
and
49 deletions
+67
-49
examples/cifar-convnet.py
examples/cifar-convnet.py
+2
-2
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+37
-27
tensorpack/train/input_source.py
tensorpack/train/input_source.py
+26
-18
tensorpack/train/simple.py
tensorpack/train/simple.py
+2
-2
No files found.
examples/cifar-convnet.py
View file @
a266459e
...
...
@@ -29,8 +29,8 @@ class Model(ModelDesc):
self
.
cifar_classnum
=
cifar_classnum
def
_get_inputs
(
self
):
return
[
InputDesc
(
tf
.
float32
,
[
None
,
30
,
30
,
3
]
,
'input'
),
InputDesc
(
tf
.
int32
,
[
None
]
,
'label'
)
return
[
InputDesc
(
tf
.
float32
,
(
None
,
30
,
30
,
3
)
,
'input'
),
InputDesc
(
tf
.
int32
,
(
None
,)
,
'label'
)
]
def
_build_graph
(
self
,
inputs
):
...
...
tensorpack/models/model_desc.py
View file @
a266459e
...
...
@@ -6,14 +6,12 @@
from
abc
import
ABCMeta
,
abstractmethod
from
collections
import
namedtuple
import
tensorflow
as
tf
import
pickle
import
six
from
..utils
import
logger
from
..utils.argtools
import
memoized
from
.regularize
import
regularize_cost_from_collection
__all__
=
[
'InputDesc'
,
'
InputVar'
,
'
ModelDesc'
]
__all__
=
[
'InputDesc'
,
'ModelDesc'
]
class
InputDesc
(
...
...
@@ -24,23 +22,36 @@ class InputDesc(
input source.
"""
def
dumps
(
self
):
"""
Returns:
str: serialized string
"""
return
pickle
.
dumps
(
self
)
_cached_placeholder
=
None
@
staticmethod
def
loads
(
buf
):
def
__init__
(
self
,
type
,
shape
,
name
):
"""
Args:
buf (str): serialized string
Returns:
InputDesc:
"""
return
pickle
.
loads
(
buf
)
type (tf.DType):
shape (tuple):
name (str):
"""
shape
=
tuple
(
shape
)
# has to be tuple for self to be hashable
super
(
InputDesc
,
self
)
.
__init__
(
type
,
shape
,
name
)
# TODO in serialization, skip _cached_placeholder
# def dumps(self):
# """
# Returns:
# str: serialized string
# """
# return pickle.dumps(self)
# @staticmethod
# def loads(buf):
# """
# Args:
# buf (str): serialized string
# Returns:
# InputDesc:
# """
# return pickle.loads(buf)
def
build_placeholder
(
self
,
prefix
=
''
):
"""
...
...
@@ -53,11 +64,13 @@ class InputDesc(
tf.Tensor:
"""
with
tf
.
name_scope
(
None
):
# clear any name scope it might get called in
ret
urn
tf
.
placeholder
(
ret
=
tf
.
placeholder
(
self
.
type
,
shape
=
self
.
shape
,
name
=
prefix
+
self
.
name
)
if
prefix
==
''
and
self
.
_cached_placeholder
is
None
:
self
.
_cached_placeholder
=
ret
return
ret
# TODO cache results from build_placeholder, and skip it in serialization
@
memoized
def
build_placeholder_reuse
(
self
):
"""
...
...
@@ -66,21 +79,18 @@ class InputDesc(
Returns:
tf.Tensor:
"""
if
self
.
_cached_placeholder
is
not
None
:
return
self
.
_cached_placeholder
return
self
.
build_placeholder
()
class
InputVar
(
InputDesc
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
logger
.
warn
(
"[Deprecated] InputVar was renamed to InputDesc!"
)
super
(
InputVar
,
self
)
.
__init__
(
*
args
,
**
kwargs
)
@
six
.
add_metaclass
(
ABCMeta
)
class
ModelDesc
(
object
):
""" Base class for a model description.
"""
# inputs:
# TODO remove this method?
@
memoized
def
get_reused_placehdrs
(
self
):
"""
...
...
@@ -89,7 +99,7 @@ class ModelDesc(object):
Returns:
list[tf.Tensor]: the list of input placeholders in the graph.
"""
return
self
.
build_placeholders
()
return
[
v
.
build_placeholder_reuse
()
for
v
in
self
.
get_inputs_desc
()]
def
build_placeholders
(
self
,
prefix
=
''
):
"""
...
...
@@ -99,7 +109,7 @@ class ModelDesc(object):
Returns:
list[tf.Tensor]: the list of built placeholders.
"""
inputs
=
self
.
_get_inputs
()
inputs
=
self
.
get_inputs_desc
()
ret
=
[]
for
v
in
inputs
:
ret
.
append
(
v
.
build_placeholder
(
prefix
))
...
...
tensorpack/train/input_source.py
View file @
a266459e
...
...
@@ -83,7 +83,8 @@ class FeedInput(InputSource):
return
self
.
ds
.
size
()
def
setup
(
self
,
model
):
self
.
_all_placehdrs
=
model
.
get_reused_placehdrs
()
inputs
=
model
.
get_inputs_desc
()
self
.
_all_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
if
self
.
_input_names
is
None
:
self
.
_placehdrs_to_feed
=
self
.
_all_placehdrs
else
:
...
...
@@ -115,13 +116,13 @@ class DataParallelFeedInput(FeedInput):
self
.
_nr_tower
=
len
(
tower_names
)
def
setup
(
self
,
model
):
inputs
=
model
.
get_inputs_desc
()
self
.
_placehdrs_per_tower
=
[]
self
.
_feed_placehdrs_per_tower
=
[]
for
tname
in
self
.
_tower_names
:
# build a list of placeholders for each tower
self
.
_placehdrs_per_tower
.
append
(
model
.
build_placeholders
(
prefix
=
tname
+
'/'
))
[
v
.
build_placeholder
(
prefix
=
tname
+
'/'
)
for
v
in
inputs
])
# apply input mapping and store results in feed_placehdrs_per_tower
if
self
.
_input_names
is
None
:
...
...
@@ -232,7 +233,8 @@ class QueueInput(FeedfreeInput):
# TODO use input data mapping. not all placeholders are needed
def
setup
(
self
,
model
):
logger
.
info
(
"Setting up the queue for CPU prefetching ..."
)
self
.
input_placehdrs
=
model
.
get_reused_placehdrs
()
inputs
=
model
.
get_inputs_desc
()
self
.
input_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
if
self
.
_names
is
None
:
self
.
_queue_feedpoint
=
self
.
input_placehdrs
else
:
...
...
@@ -289,7 +291,8 @@ class BatchQueueInput(FeedfreeInput):
def
setup
(
self
,
model
):
logger
.
info
(
"Setting up the queue for CPU prefetching ..."
)
self
.
input_placehdrs
=
model
.
get_reused_placehdrs
()
inputs
=
model
.
get_inputs_desc
()
self
.
input_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
assert
len
(
self
.
input_placehdrs
)
>
0
,
\
"BatchQueueInput has to be used with some InputDesc!"
...
...
@@ -377,39 +380,42 @@ class DummyConstantInput(TensorInput):
tlist
=
[]
ctx
=
get_current_tower_context
()
assert
ctx
is
not
None
assert
len
(
self
.
shapes
)
==
len
(
self
.
input
_placehdrs
)
for
idx
,
p
in
enumerate
(
self
.
input
_placehdrs
):
assert
len
(
self
.
shapes
)
==
len
(
self
.
input
s_desc
)
for
idx
,
p
in
enumerate
(
self
.
input
s_desc
):
tlist
.
append
(
tf
.
constant
(
0
,
dtype
=
p
.
d
type
,
name
=
'dummy-{}-{}'
.
format
(
p
.
op
.
name
,
ctx
.
index
),
0
,
dtype
=
p
.
type
,
name
=
'dummy-{}-{}'
.
format
(
p
.
name
,
ctx
.
index
),
shape
=
self
.
shapes
[
idx
]))
return
tlist
super
(
DummyConstantInput
,
self
)
.
__init__
(
fn
)
def
setup
(
self
,
model
):
self
.
input
_placehdrs
=
model
.
get_reused_placehdrs
()
self
.
input
s_desc
=
model
.
get_inputs_desc
()
# TODO doesn't support remapping
class
ZMQInput
(
TensorInput
):
"""
Not well implemented yet. Don't use.
"""
def
__init__
(
self
,
endpoint
):
self
.
_endpoint
=
endpoint
from
tensorpack.user_ops
import
zmq_recv
def
fn
():
ret
=
zmq_recv
(
self
.
_endpoint
,
[
x
.
dtype
for
x
in
self
.
input
_placehdrs
])
ret
=
zmq_recv
(
self
.
_endpoint
,
[
x
.
dtype
for
x
in
self
.
input
s_desc
])
if
isinstance
(
ret
,
tf
.
Tensor
):
ret
=
[
ret
]
assert
len
(
ret
)
==
len
(
self
.
input
_placehdrs
)
for
qv
,
v
in
zip
(
ret
,
self
.
input
_placehdrs
):
qv
.
set_shape
(
v
.
get_shape
()
)
assert
len
(
ret
)
==
len
(
self
.
input
s_desc
)
for
qv
,
v
in
zip
(
ret
,
self
.
input
s_desc
):
qv
.
set_shape
(
v
.
shape
)
return
ret
super
(
ZMQInput
,
self
)
.
__init__
(
fn
)
def
setup
(
self
,
model
):
self
.
input
_placehdrs
=
model
.
get_reused_placehdrs
()
assert
len
(
self
.
input
_placehdrs
)
>
0
,
\
self
.
input
s_desc
=
model
.
get_inputs_desc
()
assert
len
(
self
.
input
s_desc
)
>
0
,
\
"ZMQInput has to be used with InputDesc!"
...
...
@@ -522,11 +528,13 @@ class ReorderInputSource(FeedfreeInput):
return
self
.
_input
.
size
()
def
setup
(
self
,
model
):
self
.
_all_placehdrs
=
model
.
get_reused_placehdrs
()
inputs
=
model
.
get_inputs_desc
()
self
.
_all_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
self
.
_input
.
setup
(
model
)
def
setup_training
(
self
,
trainer
):
self
.
_all_placehdrs
=
trainer
.
model
.
get_reused_placehdrs
()
inputs
=
trainer
.
model
.
get_inputs_desc
()
self
.
_all_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
self
.
_input
.
setup_training
(
trainer
)
def
reset_state
(
self
):
...
...
tensorpack/train/simple.py
View file @
a266459e
# -*- coding: UTF-8 -*-
# File:
trainer
.py
# File:
simple
.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
...
@@ -37,7 +37,7 @@ class SimpleTrainer(Trainer):
def
_setup
(
self
):
self
.
_input_source
.
setup_training
(
self
)
model
=
self
.
model
self
.
inputs
=
model
.
get_reused_placehd
rs
()
self
.
inputs
=
self
.
_input_source
.
get_input_tenso
rs
()
with
TowerContext
(
''
,
is_training
=
True
):
model
.
build_graph
(
self
.
inputs
)
cost_var
=
model
.
get_cost
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment