Shashank Suhas / seminar-breakout · Commits · adf51f22

Commit adf51f22, authored Jul 11, 2017 by Yuxin Wu
parent 20d7fe7f

    Allow remapping on every InputSource, therefore remove the old 'names' options

Showing 3 changed files, with 65 additions and 83 deletions (+65 -83):

    tensorpack/callbacks/inference_runner.py    +21  -13
    tensorpack/train/input_source.py            +33  -59
    tensorpack/train/utils.py                   +11  -11
tensorpack/callbacks/inference_runner.py

@@ -16,7 +16,8 @@ from ..utils import logger, get_tqdm_kwargs
 from ..dataflow import DataFlow
 from ..tfutils.common import get_tensors_by_names
 from ..tfutils.tower import TowerContext
-from ..train.input_source import FeedInput, DataParallelFeedInput, FeedfreeInput
+from ..train.input_source import (
+    FeedInput, DataParallelFeedInput, FeedfreeInput, InputSource)
 from ..predict import PredictorTowerBuilder
 from .base import Callback
@@ -128,16 +129,15 @@ class InferenceRunner(InferenceRunnerBase):
         :class:`DataFlow`.
     """

-    def __init__(self, ds, infs, input_names=None, extra_hooks=None):
+    def __init__(self, input, infs, extra_hooks=None):
         """
         Args:
-            ds (DataFlow): the DataFlow to run inferencer on.
+            input (FeedInput or DataFlow): the FeedInput, or
+                the DataFlow to run inferencer on.
             infs (list): a list of `Inferencer` instances.
-            input_names (list[str]): list of names to match ``input``, must be a subset of the names in
-                InputDesc of the model. Defaults to be all the inputs of the model.
         """
-        assert isinstance(ds, DataFlow), ds
-        input = FeedInput(ds, input_names)
+        if isinstance(input, DataFlow):
+            input = FeedInput(input)
+        assert isinstance(input, FeedInput), input
         super(InferenceRunner, self).__init__(
             input, infs, prefix='', extra_hooks=extra_hooks)
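For context, here is how the changed constructor would be called after this commit: a DataFlow is accepted directly and wrapped into a FeedInput internally. A minimal usage sketch; `val_ds` and `my_inferencer` are hypothetical stand-ins for a real DataFlow and Inferencer:

    from tensorpack.callbacks.inference_runner import InferenceRunner
    from tensorpack.train.input_source import FeedInput

    # Both forms are now equivalent; the explicit input_names argument is gone.
    runner = InferenceRunner(val_ds, infs=[my_inferencer])
    runner = InferenceRunner(FeedInput(val_ds), infs=[my_inferencer])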
@@ -158,7 +158,6 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
         Args:
             input (FeedfreeInput): the input to use. Must have ``size()``.
             infs (list): list of :class:`Inferencer` to run.
-            input_names (list[str]): same as in :class:`InferenceRunnerBase`.
             prefix(str): an prefix used to build the tower. Must be set
                 differently if more than one :class:`FeedfreeInferenceRunner` are used.
         """
@@ -180,11 +179,20 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
 class DataParallelInferenceRunner(InferenceRunnerBase):
-    def __init__(self, ds, infs, gpus, input_names=None):
-        self._tower_names = [TowerContext.get_predict_tower_name(k)
-                             for k in range(len(gpus))]
-        input = DataParallelFeedInput(ds, self._tower_names, input_names=input_names)
+    """
+    Not tested. Don't use.
+    """
+    # TODO some scripts to test
+    def __init__(self, input, infs, gpus):
+        """
+        Args:
+            input (DataParallelFeedInput or DataFlow)
+        """
+        if isinstance(input, DataFlow):
+            tower_names = [TowerContext.get_predict_tower_name(k)
+                           for k in range(len(gpus))]
+            input = DataParallelFeedInput(input, tower_names)
+        assert isinstance(input, InputSource), input
         super(DataParallelInferenceRunner, self).__init__(input, infs)
         self._gpus = gpus
tensorpack/train/input_source.py

@@ -14,7 +14,7 @@ from abc import ABCMeta, abstractmethod
 import six
 from six.moves import range, zip

-from .utils import get_placeholders_by_names, get_tensors_inputs
+from .utils import get_sublist_by_names, get_tensors_inputs
 from ..dataflow import DataFlow, RepeatedData
 from ..tfutils.summary import add_moving_summary
 from ..tfutils import get_op_tensor_name
@@ -30,7 +30,7 @@ __all__ = ['InputSource',
            'FeedfreeInput',
            'QueueInput', 'BatchQueueInput',
            'ZMQInput', 'DummyConstantInput', 'TensorInput',
-           'StagingInputWrapper', 'ReorderInputSource']
+           'StagingInputWrapper', 'remap_input_source']


 @six.add_metaclass(ABCMeta)
@@ -82,29 +82,19 @@ class InputSource(object):
 class FeedInput(InputSource):
     """ Input by iterating over a DataFlow and feed datapoints. """
-    def __init__(self, ds, input_names=None):
+    def __init__(self, ds):
         """
         Args:
             ds (DataFlow): the input DataFlow.
-            input_names (list[str]): input names this DataFlow maps to
         """
         assert isinstance(ds, DataFlow), ds
-        if input_names is not None:
-            assert isinstance(input_names, (list, tuple)), input_names
         self.ds = ds
-        self._input_names = input_names

     def size(self):
         return self.ds.size()

     def setup(self, inputs):
         self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
-        if self._input_names is None:
-            self._placehdrs_to_feed = self._all_placehdrs
-        else:
-            self._placehdrs_to_feed = get_placeholders_by_names(
-                self._all_placehdrs, self._input_names)
-
         self.reset_state()

     def reset_state(self):
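With `input_names` removed, FeedInput always feeds every placeholder of the model, positionally. A minimal sketch under that assumption; `ds` is a hypothetical DataFlow yielding datapoints in the model's InputDesc order:

    from tensorpack.train.input_source import FeedInput

    # ds must now yield one component per model input, in InputDesc order.
    input = FeedInput(ds)
    # Feeding only a named subset is handled by the new wrapper instead
    # (defined later in this diff):
    # input = remap_input_source(FeedInput(ds), ['image', 'label'])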
@@ -117,37 +107,25 @@ class FeedInput(InputSource):
     def next_feed(self):
         dp = next(self.data_producer)
-        return dict(zip(self._placehdrs_to_feed, dp))
+        assert len(dp) == len(self._all_placehdrs), \
+            "[FeedInput] datapoints and inputs are of different length!"
+        return dict(zip(self._all_placehdrs, dp))


 class DataParallelFeedInput(FeedInput):
     """
     Input by feeding k datapoints to k copies of placeholders located on k towers.
     """
-    def __init__(self, ds, tower_names, input_names=None):
-        super(DataParallelFeedInput, self).__init__(ds, input_names)
+    def __init__(self, ds, tower_names):
+        super(DataParallelFeedInput, self).__init__(ds)
         self._tower_names = tower_names
         self._nr_tower = len(tower_names)

     def setup(self, inputs):
         self._placehdrs_per_tower = []
-        self._feed_placehdrs_per_tower = []
         for tname in self._tower_names:
             # build a list of placeholders for each tower
             self._placehdrs_per_tower.append(
                 [v.build_placeholder(prefix=tname + '/') for v in inputs])
-
-        # apply input mapping and store results in feed_placehdrs_per_tower
-        if self._input_names is None:
-            self._feed_placehdrs_per_tower = self._placehdrs_per_tower
-        else:
-            for phdrs, tname in zip(
-                    self._placehdrs_per_tower, self._tower_names):
-                input_names = [tname + '/' + n for n in self._input_names]
-                # input_names to be used for this specific tower
-                self._feed_placehdrs_per_tower.append(
-                    get_placeholders_by_names(phdrs, input_names))
-                print(self._feed_placehdrs_per_tower[-1])
         self.reset_state()

     def get_input_tensors(self):
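The data-parallel variant now takes only the DataFlow and the tower names; each tower gets a full copy of the model's placeholders. A sketch mirroring how DataParallelInferenceRunner constructs it above, assuming two predict towers and a hypothetical DataFlow `ds`:

    from tensorpack.tfutils.tower import TowerContext
    from tensorpack.train.input_source import DataParallelFeedInput

    tower_names = [TowerContext.get_predict_tower_name(k) for k in range(2)]
    input = DataParallelFeedInput(ds, tower_names)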
@@ -165,7 +143,7 @@ class DataParallelFeedInput(FeedInput):
         feed = {}
         for t in range(cnt):
             dp = next(self.data_producer)
-            f = dict(zip(self._feed_placehdrs_per_tower[t], dp))
+            f = dict(zip(self._placehdrs_per_tower[t], dp))
             feed.update(f)
         return feed
@@ -175,7 +153,6 @@ class FeedfreeInput(InputSource):
     e.g. by queue or other operations. """

     def reset_state(self):
-        # TODO no state to reset
         pass

     def next_feed(self):
@@ -226,38 +203,31 @@ class QueueInput(FeedfreeInput):
     And the model receives dequeued tensors.
     """

-    def __init__(self, ds, queue=None, names=None):
+    def __init__(self, ds, queue=None):
         """
         Args:
             ds(DataFlow): the input DataFlow.
             queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
                 should match the corresponding InputDesc of the model.
                 Defaults to a FIFO queue of size 50.
-            names(list[str]): list of input names corresponding to the dataflow.
         """
         assert isinstance(ds, DataFlow), ds
         self.queue = queue
         self.ds = ds
-        self._names = names

     def size(self):
         return self.ds.size()

-    # TODO use input data mapping. not all placeholders are needed
     def setup(self, inputs):
         logger.info("Setting up the queue for CPU prefetching ...")
-        self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
-        if self._names is None:
-            self._queue_feedpoint = self.input_placehdrs
-        else:
-            self._queue_feedpoint = get_placeholders_by_names(
-                self.input_placehdrs, self._names)
-
-        assert len(self._queue_feedpoint) > 0, \
+        self._input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
+        assert len(self._input_placehdrs) > 0, \
             "QueueInput has to be used with some inputs!"
         if self.queue is None:
             self.queue = tf.FIFOQueue(
-                50, [x.dtype for x in self._queue_feedpoint],
+                50, [x.dtype for x in self._input_placehdrs],
                 name='input_queue')
-        self.thread = EnqueueThread(self.queue, self.ds, self._queue_feedpoint)
+        self.thread = EnqueueThread(self.queue, self.ds, self._input_placehdrs)

     def get_callbacks(self):
         cb = StartProcOrThread(self.thread)
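After this change QueueInput unconditionally enqueues one element per model input; the `names` remapping is gone from this class too. A usage sketch; `ds` is again a hypothetical DataFlow, and the dtypes passed to the custom queue are illustrative:

    import tensorflow as tf
    from tensorpack.train.input_source import QueueInput

    # Default: a FIFO queue of size 50 whose dtypes are taken from the model inputs.
    input = QueueInput(ds)
    # Or supply a queue whose types match the model's InputDesc:
    input = QueueInput(ds, queue=tf.FIFOQueue(100, [tf.float32, tf.int32]))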
@@ -269,13 +239,10 @@ class QueueInput(FeedfreeInput):
         ret = self.queue.dequeue(name='input_deque')
         if isinstance(ret, tf.Tensor):  # only one input
             ret = [ret]
-        assert len(ret) == len(self._queue_feedpoint)
-        for qv, v in zip(ret, self._queue_feedpoint):
+        assert len(ret) == len(self._input_placehdrs)
+        for qv, v in zip(ret, self._input_placehdrs):
             qv.set_shape(v.get_shape())
-        if self._names is None:
-            return ret
-        else:
-            return get_tensors_inputs(
-                self.input_placehdrs, ret, self._names)
+        return ret


 class BatchQueueInput(QueueInput):
@@ -351,8 +318,6 @@ class TensorInput(FeedfreeInput):
             get_tensor_fn: a function which returns a list of input tensors
                 when called. It will be called under a TowerContext.
             size(int): size of this input. Use None to leave it undefined.
-            input_names (list[str]): input names the tensors maps to. Defaults
-                to be all the inputs of the model.
         """
         self.get_tensor_fn = get_tensor_fn
         if size is not None:
@@ -397,7 +362,6 @@ class DummyConstantInput(TensorInput):
         self.inputs_desc = inputs


-# TODO doesn't support remapping
 class ZMQInput(TensorInput):
     """
     Not well implemented yet. Don't use.
@@ -511,29 +475,35 @@ class StagingInputWrapper(FeedfreeInput):
         return tf.group(*all_outputs)


-class ReorderInputSource(FeedfreeInput):
+# TODO dynamically generate inheritance
+# TODO make it a function, not a class
+class remap_input_source(FeedInput, FeedfreeInput):
     """
-    When an InputSource only maps to a subset of the InputDesc of the model,
-    wrap it with :class:`ReorderInputSource`.
+    When you have some :class:`InputSource` which doesn't match the inputs in
+    your :class:`ModelDesc`, use `RemapInputSource`.
+    It produces placeholders for all the inputs in your model,
+    except that the corresponding ones are replaced with the tensor produced
+    by the given :class:`InputSource`.
     """
     def __init__(self, input, names):
         """
         Args:
-            input(TensorInput): a TensorInput, whose tensors will get mapped.
+            input(InputSource): a :class:`InputSource`, whose tensors will get mapped.
             names(list[str]): list of input names corresponding to the tensors
                 produced by ``input``.
         """
-        assert isinstance(input, TensorInput), input
+        assert isinstance(input, InputSource), input
         self._input = input
         assert isinstance(names, (list, tuple)), names
-        self._names = names
+        self._names = tuple(names)

     def size(self):
         return self._input.size()

     def setup(self, inputs):
         self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
-        self._input.setup(inputs)
+        inputs_subset = get_sublist_by_names(inputs, self._names)
+        self._input.setup(inputs_subset)

     def get_callbacks(self):
         return self._input.get_callbacks()
@@ -541,7 +511,11 @@ class ReorderInputSource(FeedfreeInput):
     def reset_state(self):
         self._input.reset_state()

+    def next_feed(self):
+        return self._input.next_feed()
+
     def get_input_tensors(self):
         ret = self._input.get_input_tensors()
+        assert len(ret) == len(self._names)
         return get_tensors_inputs(
             self._all_placehdrs, ret, self._names)
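Taken together, `remap_input_source` is the single replacement for all of the removed `names`/`input_names` options: it builds placeholders for every model input and substitutes the wrapped source's tensors for the named subset. A sketch assuming a hypothetical model declaring inputs 'image', 'label' and 'weight', where the queue only produces the first two:

    from tensorpack.train.input_source import QueueInput, remap_input_source

    inner = QueueInput(ds)  # ds yields [image, label] only
    input = remap_input_source(inner, ['image', 'label'])
    # 'image' and 'label' come from the queue; 'weight' remains an ordinary
    # placeholder built by setup().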
tensorpack/train/utils.py

@@ -8,14 +8,11 @@ from six.moves import zip
 from ..tfutils.common import get_op_tensor_name
 from ..utils import logger

-__all__ = ['get_tensors_inputs', 'get_placeholders_by_names']
+__all__ = ['get_tensors_inputs', 'get_sublist_by_names']


 def get_tensors_inputs(placeholders, tensors, names):
     """
-    Quite often we want to `build_graph()` with normal tensors
-    (rather than placeholders).
-
     Args:
         placeholders (list[Tensor]):
         tensors (list[Tensor]): list of tf.Tensor
@@ -41,19 +38,22 @@ def get_tensors_inputs(placeholders, tensors, names):
     return ret


-def get_placeholders_by_names(placeholders, names):
+def get_sublist_by_names(lst, names):
     """
     Args:
+        lst (list): list of objects with "name" property.
+
     Returns:
-        list[Tensor]: a sublist of placeholders, matching names
+        list: a sublist of objects, matching names
     """
-    placeholder_names = [p.name for p in placeholders]
+    orig_names = [p.name for p in lst]
     ret = []
     for name in names:
-        tensorname = get_op_tensor_name(name)[1]
         try:
-            idx = placeholder_names.index(tensorname)
+            idx = orig_names.index(name)
         except ValueError:
-            logger.error("Name {} is not a model input!".format(tensorname))
+            logger.error("Name {} doesn't appear in lst {}!".format(
+                name, str(orig_names)))
             raise
-        ret.append(placeholders[idx])
+        ret.append(lst[idx])
     return ret
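Note the generalized helper now matches the `name` attribute verbatim (the old `get_op_tensor_name` normalization of placeholder names is dropped), so it works on any named objects such as InputDesc. A toy sketch with hypothetical objects:

    import collections
    from tensorpack.train.utils import get_sublist_by_names

    Item = collections.namedtuple('Item', ['name', 'value'])
    lst = [Item('image', 1), Item('label', 2), Item('weight', 3)]
    sub = get_sublist_by_names(lst, ['label', 'image'])
    # -> [Item(name='label', value=2), Item(name='image', value=1)]
    # A missing name logs an error and re-raises the ValueError.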