Shashank Suhas / seminar-breakout · Commits

Commit a266459e
authored Jul 10, 2017 by Yuxin Wu
use InputDesc inside input_source, instead of ModelDesc
parent 8dcf454d
Showing 4 changed files with 67 additions and 49 deletions (+67, -49)
examples/cifar-convnet.py           +2   -2
tensorpack/models/model_desc.py     +37  -27
tensorpack/train/input_source.py    +26  -18
tensorpack/train/simple.py          +2   -2
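This commit moves placeholder construction out of ModelDesc and into the input-source layer: each InputSource.setup() now asks the model only for its list of InputDesc metadata via get_inputs_desc() and builds (or reuses) the placeholders from those descriptors itself. Below is a minimal pure-Python sketch of that wiring; the Fake* classes are illustrative stand-ins, not tensorpack's real implementations.

    # Sketch of the refactor: the input source asks the model for InputDesc
    # metadata and builds/reuses placeholders itself, instead of calling
    # model.get_reused_placehdrs().  Stand-in classes only.

    class FakeInputDesc(object):
        def __init__(self, type, shape, name):
            self.type, self.shape, self.name = type, tuple(shape), name
            self._cached_placeholder = None

        def build_placeholder_reuse(self):
            # stands in for tf.placeholder(...); cached so repeated setup()
            # calls hand back the same object, like InputDesc._cached_placeholder
            if self._cached_placeholder is None:
                self._cached_placeholder = 'placeholder:%s' % self.name
            return self._cached_placeholder

    class FakeModel(object):
        def get_inputs_desc(self):
            return [FakeInputDesc('float32', (None, 30, 30, 3), 'input'),
                    FakeInputDesc('int32', (None,), 'label')]

    class FakeFeedInput(object):
        def setup(self, model):
            # old: self._all_placehdrs = model.get_reused_placehdrs()
            # new: build them from the descriptors returned by get_inputs_desc()
            inputs = model.get_inputs_desc()
            self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]

    src = FakeFeedInput()
    src.setup(FakeModel())
    print(src._all_placehdrs)   # ['placeholder:input', 'placeholder:label']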
examples/cifar-convnet.py

...
@@ -29,8 +29,8 @@ class Model(ModelDesc):
         self.cifar_classnum = cifar_classnum

     def _get_inputs(self):
-        return [InputDesc(tf.float32, [None, 30, 30, 3], 'input'),
-                InputDesc(tf.int32, [None], 'label')
+        return [InputDesc(tf.float32, (None, 30, 30, 3), 'input'),
+                InputDesc(tf.int32, (None,), 'label')
                 ]

     def _build_graph(self, inputs):
...
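The shapes in _get_inputs() change from lists to tuples because the reworked InputDesc.__init__ (next file) coerces shape to a tuple so that a descriptor stays hashable. A small illustration of the underlying hashability issue, using a plain namedtuple as a stand-in for InputDesc:

    from collections import namedtuple

    # stand-in for InputDesc: a namedtuple is hashable only if all its fields are
    Desc = namedtuple('Desc', ['type', 'shape', 'name'])

    ok = Desc('float32', (None, 30, 30, 3), 'input')
    print(hash(ok))                 # fine: every field is hashable

    bad = Desc('float32', [None, 30, 30, 3], 'input')
    try:
        hash(bad)
    except TypeError as e:
        print('unhashable:', e)     # TypeError: unhashable type: 'list'

This matters because the diff below memoizes methods on InputDesc (@memoized), which typically requires the descriptor to be usable as a dictionary key.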
tensorpack/models/model_desc.py

...
@@ -6,14 +6,12 @@
 from abc import ABCMeta, abstractmethod
 from collections import namedtuple
 import tensorflow as tf
-import pickle
 import six

-from ..utils import logger
 from ..utils.argtools import memoized
 from .regularize import regularize_cost_from_collection

-__all__ = ['InputDesc', 'InputVar', 'ModelDesc']
+__all__ = ['InputDesc', 'ModelDesc']


 class InputDesc(
...
@@ -24,23 +22,36 @@ class InputDesc(
     input source.
     """

-    def dumps(self):
-        """
-        Returns:
-            str: serialized string
-        """
-        return pickle.dumps(self)
+    _cached_placeholder = None

-    @staticmethod
-    def loads(buf):
+    def __init__(self, type, shape, name):
         """
         Args:
-            buf (str): serialized string
-        Returns:
-            InputDesc:
+            type (tf.DType):
+            shape (tuple):
+            name (str):
         """
-        return pickle.loads(buf)
+        shape = tuple(shape)    # has to be tuple for self to be hashable
+        super(InputDesc, self).__init__(type, shape, name)
+
+    # TODO in serialization, skip _cached_placeholder
+    # def dumps(self):
+    #     """
+    #     Returns:
+    #         str: serialized string
+    #     """
+    #     return pickle.dumps(self)
+
+    # @staticmethod
+    # def loads(buf):
+    #     """
+    #     Args:
+    #         buf (str): serialized string
+    #     Returns:
+    #         InputDesc:
+    #     """
+    #     return pickle.loads(buf)

     def build_placeholder(self, prefix=''):
         """
...
@@ -53,11 +64,13 @@ class InputDesc(
             tf.Tensor:
         """
         with tf.name_scope(None):   # clear any name scope it might get called in
-            return tf.placeholder(
+            ret = tf.placeholder(
                 self.type, shape=self.shape,
                 name=prefix + self.name)
+        if prefix == '' and self._cached_placeholder is None:
+            self._cached_placeholder = ret
+        return ret

-    # TODO cache results from build_placeholder, and skip it in serialization
     @memoized
     def build_placeholder_reuse(self):
         """
...
@@ -66,21 +79,18 @@ class InputDesc(
         Returns:
             tf.Tensor:
         """
+        if self._cached_placeholder is not None:
+            return self._cached_placeholder
         return self.build_placeholder()


-class InputVar(InputDesc):
-    def __init__(self, *args, **kwargs):
-        logger.warn("[Deprecated] InputVar was renamed to InputDesc!")
-        super(InputVar, self).__init__(*args, **kwargs)
-
-
 @six.add_metaclass(ABCMeta)
 class ModelDesc(object):
     """ Base class for a model description.
     """

     # inputs:
+    # TODO remove this method?
     @memoized
     def get_reused_placehdrs(self):
         """
...
@@ -89,7 +99,7 @@ class ModelDesc(object):
         Returns:
             list[tf.Tensor]: the list of input placeholders in the graph.
         """
-        return self.build_placeholders()
+        return [v.build_placeholder_reuse() for v in self.get_inputs_desc()]

     def build_placeholders(self, prefix=''):
         """
...
@@ -99,7 +109,7 @@ class ModelDesc(object):
         Returns:
             list[tf.Tensor]: the list of built placeholders.
         """
-        inputs = self._get_inputs()
+        inputs = self.get_inputs_desc()
         ret = []
         for v in inputs:
             ret.append(v.build_placeholder(prefix))
...
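A self-contained sketch of the caching pattern added above: the first un-prefixed build_placeholder() call is recorded in _cached_placeholder and build_placeholder_reuse() hands that cached tensor back, while prefixed (per-tower) calls always build a fresh one. CachingDesc and make_tensor are stand-ins so the sketch runs without TensorFlow; the real class is tensorpack's InputDesc shown in the diff.

    import itertools

    _counter = itertools.count()

    def make_tensor(type, shape, name):
        # hypothetical stand-in for tf.placeholder(type, shape=shape, name=name)
        return ('tensor%d' % next(_counter), type, shape, name)

    class CachingDesc(object):              # illustrative, not the real InputDesc
        _cached_placeholder = None

        def __init__(self, type, shape, name):
            self.type, self.shape, self.name = type, tuple(shape), name

        def build_placeholder(self, prefix=''):
            ret = make_tensor(self.type, self.shape, prefix + self.name)
            if prefix == '' and self._cached_placeholder is None:
                self._cached_placeholder = ret
            return ret

        def build_placeholder_reuse(self):
            if self._cached_placeholder is not None:
                return self._cached_placeholder
            return self.build_placeholder()

    d = CachingDesc('float32', (None, 30, 30, 3), 'input')
    a = d.build_placeholder()               # built and cached (prefix == '')
    b = d.build_placeholder_reuse()         # cache hit: same object
    c = d.build_placeholder('tower0/')      # prefixed tower copy, never cached
    assert a is b and c is not a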
tensorpack/train/input_source.py

...
@@ -83,7 +83,8 @@ class FeedInput(InputSource):
         return self.ds.size()

     def setup(self, model):
-        self._all_placehdrs = model.get_reused_placehdrs()
+        inputs = model.get_inputs_desc()
+        self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
         if self._input_names is None:
             self._placehdrs_to_feed = self._all_placehdrs
         else:
...
@@ -115,13 +116,13 @@ class DataParallelFeedInput(FeedInput):
         self._nr_tower = len(tower_names)

     def setup(self, model):
+        inputs = model.get_inputs_desc()
         self._placehdrs_per_tower = []
         self._feed_placehdrs_per_tower = []
         for tname in self._tower_names:
             # build a list of placeholders for each tower
             self._placehdrs_per_tower.append(
-                model.build_placeholders(
-                    prefix=tname + '/'))
+                [v.build_placeholder(prefix=tname + '/') for v in inputs])
             # apply input mapping and store results in feed_placehdrs_per_tower
             if self._input_names is None:
...
@@ -232,7 +233,8 @@ class QueueInput(FeedfreeInput):
     # TODO use input data mapping. not all placeholders are needed
     def setup(self, model):
         logger.info("Setting up the queue for CPU prefetching ...")
-        self.input_placehdrs = model.get_reused_placehdrs()
+        inputs = model.get_inputs_desc()
+        self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
         if self._names is None:
             self._queue_feedpoint = self.input_placehdrs
         else:
...
@@ -289,7 +291,8 @@ class BatchQueueInput(FeedfreeInput):
     def setup(self, model):
         logger.info("Setting up the queue for CPU prefetching ...")
-        self.input_placehdrs = model.get_reused_placehdrs()
+        inputs = model.get_inputs_desc()
+        self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
         assert len(self.input_placehdrs) > 0, \
             "BatchQueueInput has to be used with some InputDesc!"
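The DataParallelFeedInput hunk applies the same idea per tower: one list of descriptors from the model, one freshly prefixed set of placeholders per tower, without going through model.build_placeholders(). A small sketch; ToyDesc and make_placeholder are hypothetical stand-ins, not tensorpack code.

    def make_placeholder(name):
        return 'tensor:' + name             # hypothetical stand-in for tf.placeholder

    class ToyDesc(object):
        def __init__(self, name):
            self.name = name

        def build_placeholder(self, prefix=''):
            return make_placeholder(prefix + self.name)

    inputs = [ToyDesc('input'), ToyDesc('label')]   # what model.get_inputs_desc() returns
    tower_names = ['tower0', 'tower1']

    # new code path: placeholders built per tower directly from the descriptors
    placehdrs_per_tower = [
        [v.build_placeholder(prefix=tname + '/') for v in inputs]
        for tname in tower_names
    ]
    print(placehdrs_per_tower)
    # [['tensor:tower0/input', 'tensor:tower0/label'],
    #  ['tensor:tower1/input', 'tensor:tower1/label']]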
...
@@ -377,39 +380,42 @@ class DummyConstantInput(TensorInput):
             tlist = []
             ctx = get_current_tower_context()
             assert ctx is not None
-            assert len(self.shapes) == len(self.input_placehdrs)
-            for idx, p in enumerate(self.input_placehdrs):
+            assert len(self.shapes) == len(self.inputs_desc)
+            for idx, p in enumerate(self.inputs_desc):
                 tlist.append(tf.constant(
-                    0, dtype=p.dtype,
-                    name='dummy-{}-{}'.format(p.op.name, ctx.index),
+                    0, dtype=p.type,
+                    name='dummy-{}-{}'.format(p.name, ctx.index),
                     shape=self.shapes[idx]))
             return tlist
         super(DummyConstantInput, self).__init__(fn)

     def setup(self, model):
-        self.input_placehdrs = model.get_reused_placehdrs()
+        self.inputs_desc = model.get_inputs_desc()


 # TODO doesn't support remapping
 class ZMQInput(TensorInput):
+    """
+    Not well implemented yet. Don't use.
+    """
     def __init__(self, endpoint):
         self._endpoint = endpoint

         from tensorpack.user_ops import zmq_recv

         def fn():
-            ret = zmq_recv(self._endpoint, [x.dtype for x in self.input_placehdrs])
+            ret = zmq_recv(self._endpoint, [x.dtype for x in self.inputs_desc])
             if isinstance(ret, tf.Tensor):
                 ret = [ret]
-            assert len(ret) == len(self.input_placehdrs)
-            for qv, v in zip(ret, self.input_placehdrs):
-                qv.set_shape(v.get_shape())
+            assert len(ret) == len(self.inputs_desc)
+            for qv, v in zip(ret, self.inputs_desc):
+                qv.set_shape(v.shape)
             return ret
         super(ZMQInput, self).__init__(fn)

     def setup(self, model):
-        self.input_placehdrs = model.get_reused_placehdrs()
-        assert len(self.input_placehdrs) > 0, \
+        self.inputs_desc = model.get_inputs_desc()
+        assert len(self.inputs_desc) > 0, \
             "ZMQInput has to be used with InputDesc!"
...
@@ -522,11 +528,13 @@ class ReorderInputSource(FeedfreeInput):
         return self._input.size()

     def setup(self, model):
-        self._all_placehdrs = model.get_reused_placehdrs()
+        inputs = model.get_inputs_desc()
+        self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
         self._input.setup(model)

     def setup_training(self, trainer):
-        self._all_placehdrs = trainer.model.get_reused_placehdrs()
+        inputs = trainer.model.get_inputs_desc()
+        self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
         self._input.setup_training(trainer)

     def reset_state(self):
...
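The DummyConstantInput and ZMQInput hunks also stop reading graph attributes off built placeholders (p.dtype, p.op.name, v.get_shape()) and instead read the static metadata carried by the descriptors (p.type, p.name, v.shape). A tiny sketch of that attribute switch; MetaDesc is an illustrative stand-in for InputDesc.

    from collections import namedtuple

    # stand-in for InputDesc: only the fields used by the new code
    MetaDesc = namedtuple('MetaDesc', ['type', 'shape', 'name'])

    inputs_desc = [MetaDesc('float32', (None, 30, 30, 3), 'input'),
                   MetaDesc('int32', (None,), 'label')]

    ctx_index = 0   # stands in for get_current_tower_context().index
    for p in inputs_desc:
        # old: dtype=p.dtype, name='dummy-{}-{}'.format(p.op.name, ctx.index)
        # new: dtype=p.type,  name='dummy-{}-{}'.format(p.name, ctx.index)
        print('tf.constant(0, dtype={}, name=dummy-{}-{})'.format(p.type, p.name, ctx_index))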
tensorpack/train/simple.py

 # -*- coding: UTF-8 -*-
-# File: trainer.py
+# File: simple.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
...
@@ -37,7 +37,7 @@ class SimpleTrainer(Trainer):
     def _setup(self):
         self._input_source.setup_training(self)
         model = self.model
-        self.inputs = model.get_reused_placehdrs()
+        self.inputs = self._input_source.get_input_tensors()
         with TowerContext('', is_training=True):
             model.build_graph(self.inputs)
             cost_var = model.get_cost()
...
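With the input sources now owning their placeholders, SimpleTrainer simply consumes whatever tensors its InputSource exposes through get_input_tensors() rather than asking the model for reused placeholders. A minimal sketch of that control flow; the Toy* classes are illustrative stand-ins, not tensorpack's real classes.

    class ToyFeedInput(object):
        def setup_training(self, trainer):
            # a real FeedInput would build these from the model's InputDesc list
            self._tensors = ['placeholder:input', 'placeholder:label']

        def get_input_tensors(self):
            return self._tensors

    class ToyModel(object):
        def build_graph(self, inputs):
            print('building graph on', inputs)

    class ToySimpleTrainer(object):
        def __init__(self, model, input_source):
            self.model, self._input_source = model, input_source

        def _setup(self):
            self._input_source.setup_training(self)
            # old: self.inputs = model.get_reused_placehdrs()
            # new: self.inputs = self._input_source.get_input_tensors()
            self.inputs = self._input_source.get_input_tensors()
            self.model.build_graph(self.inputs)

    ToySimpleTrainer(ToyModel(), ToyFeedInput())._setup()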