Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ddebb23c
Commit
ddebb23c
authored
May 26, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
move input_names mapping to InputSource
parent
9626ebd8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
60 additions
and
34 deletions
+60
-34
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+6
-26
tensorpack/train/input_source.py
tensorpack/train/input_source.py
+54
-8
No files found.
tensorpack/callbacks/inference_runner.py
View file @
ddebb23c
...
@@ -16,8 +16,7 @@ from ..utils import logger, get_tqdm_kwargs
...
@@ -16,8 +16,7 @@ from ..utils import logger, get_tqdm_kwargs
from
..dataflow
import
DataFlow
from
..dataflow
import
DataFlow
from
..tfutils.common
import
get_tensors_by_names
from
..tfutils.common
import
get_tensors_by_names
from
..tfutils.tower
import
TowerContext
from
..tfutils.tower
import
TowerContext
from
..train.input_source
import
TensorInput
,
FeedInput
,
DataParallelFeedInput
from
..train.input_source
import
FeedInput
,
DataParallelFeedInput
,
FeedfreeInput
from
..train.utils
import
get_tensors_inputs
from
..predict
import
PredictorTowerBuilder
from
..predict
import
PredictorTowerBuilder
from
.base
import
Callback
from
.base
import
Callback
...
@@ -89,8 +88,7 @@ class InferenceRunnerBase(Callback):
...
@@ -89,8 +88,7 @@ class InferenceRunnerBase(Callback):
self
.
_predict_tower_id
=
self
.
trainer
.
config
.
predict_tower
[
0
]
self
.
_predict_tower_id
=
self
.
trainer
.
config
.
predict_tower
[
0
]
def
fn
(
_
):
def
fn
(
_
):
in_tensors
=
self
.
_find_input_tensors
()
in_tensors
=
self
.
_input_source
.
get_input_tensors
()
assert
isinstance
(
in_tensors
,
(
list
,
tuple
)),
in_tensors
self
.
trainer
.
model
.
build_graph
(
in_tensors
)
self
.
trainer
.
model
.
build_graph
(
in_tensors
)
PredictorTowerBuilder
(
fn
,
self
.
_prefix
)
.
build
(
self
.
_predict_tower_id
)
PredictorTowerBuilder
(
fn
,
self
.
_prefix
)
.
build
(
self
.
_predict_tower_id
)
...
@@ -105,9 +103,6 @@ class InferenceRunnerBase(Callback):
...
@@ -105,9 +103,6 @@ class InferenceRunnerBase(Callback):
get_tensor_fn
=
PredictorTowerBuilder
.
get_tensors_maybe_in_tower
get_tensor_fn
=
PredictorTowerBuilder
.
get_tensors_maybe_in_tower
return
get_tensor_fn
(
placeholder_names
,
names
,
self
.
_predict_tower_id
,
prefix
=
self
.
_prefix
)
return
get_tensor_fn
(
placeholder_names
,
names
,
self
.
_predict_tower_id
,
prefix
=
self
.
_prefix
)
def
_find_input_tensors
(
self
):
pass
@
abstractmethod
@
abstractmethod
def
_build_hook
(
self
,
inf
):
def
_build_hook
(
self
,
inf
):
pass
pass
...
@@ -143,9 +138,6 @@ class InferenceRunner(InferenceRunnerBase):
...
@@ -143,9 +138,6 @@ class InferenceRunner(InferenceRunnerBase):
super
(
InferenceRunner
,
self
)
.
__init__
(
super
(
InferenceRunner
,
self
)
.
__init__
(
input
,
infs
,
prefix
=
''
,
extra_hooks
=
extra_hooks
)
input
,
infs
,
prefix
=
''
,
extra_hooks
=
extra_hooks
)
def
_find_input_tensors
(
self
):
return
self
.
_input_source
.
get_input_tensors
()
def
_build_hook
(
self
,
inf
):
def
_build_hook
(
self
,
inf
):
out_names
=
inf
.
get_output_tensors
()
out_names
=
inf
.
get_output_tensors
()
fetches
=
self
.
_get_tensors_maybe_in_tower
(
out_names
)
fetches
=
self
.
_get_tensors_maybe_in_tower
(
out_names
)
...
@@ -154,34 +146,22 @@ class InferenceRunner(InferenceRunnerBase):
...
@@ -154,34 +146,22 @@ class InferenceRunner(InferenceRunnerBase):
class
FeedfreeInferenceRunner
(
InferenceRunnerBase
):
class
FeedfreeInferenceRunner
(
InferenceRunnerBase
):
""" A callback that runs a list of :class:`Inferencer` on some
""" A callback that runs a list of :class:`Inferencer` on some
:class:`
Tensor
Input`, such as some tensor from a TensorFlow data reading
:class:`
Feedfree
Input`, such as some tensor from a TensorFlow data reading
pipeline.
pipeline.
"""
"""
def
__init__
(
self
,
input
,
infs
,
input_names
=
None
,
prefix
=
''
,
extra_hooks
=
None
):
def
__init__
(
self
,
input
,
infs
,
prefix
=
''
,
extra_hooks
=
None
):
"""
"""
Args:
Args:
input (
Tensor
Input): the input to use. Must have ``size()``.
input (
Feedfree
Input): the input to use. Must have ``size()``.
infs (list): list of :class:`Inferencer` to run.
infs (list): list of :class:`Inferencer` to run.
input_names (list[str]): same as in :class:`InferenceRunnerBase`.
input_names (list[str]): same as in :class:`InferenceRunnerBase`.
prefix(str): an prefix used to build the tower. Must be set
prefix(str): an prefix used to build the tower. Must be set
differently if more than one :class:`FeedfreeInferenceRunner` are used.
differently if more than one :class:`FeedfreeInferenceRunner` are used.
"""
"""
assert
isinstance
(
input
,
Tensor
Input
),
input
assert
isinstance
(
input
,
Feedfree
Input
),
input
super
(
FeedfreeInferenceRunner
,
self
)
.
__init__
(
super
(
FeedfreeInferenceRunner
,
self
)
.
__init__
(
input
,
infs
,
prefix
=
prefix
,
extra_hooks
=
extra_hooks
)
input
,
infs
,
prefix
=
prefix
,
extra_hooks
=
extra_hooks
)
if
input_names
is
not
None
:
assert
isinstance
(
input_names
,
list
)
self
.
input_names
=
input_names
def
_find_input_tensors
(
self
):
# TODO move mapping to InputSource
tensors
=
self
.
_input_source
.
get_input_tensors
()
placeholders
=
self
.
trainer
.
model
.
get_reused_placehdrs
()
if
self
.
input_names
is
None
:
return
tensors
else
:
return
get_tensors_inputs
(
placeholders
,
tensors
,
self
.
input_names
)
def
_build_hook
(
self
,
inf
):
def
_build_hook
(
self
,
inf
):
out_names
=
inf
.
get_output_tensors
()
# all is tensorname
out_names
=
inf
.
get_output_tensors
()
# all is tensorname
...
...
tensorpack/train/input_source.py
View file @
ddebb23c
...
@@ -14,7 +14,7 @@ from abc import ABCMeta, abstractmethod
...
@@ -14,7 +14,7 @@ from abc import ABCMeta, abstractmethod
import
six
import
six
from
six.moves
import
range
,
zip
from
six.moves
import
range
,
zip
from
.utils
import
get_placeholders_by_names
from
.utils
import
get_placeholders_by_names
,
get_tensors_inputs
from
..dataflow
import
DataFlow
,
RepeatedData
from
..dataflow
import
DataFlow
,
RepeatedData
from
..tfutils.summary
import
add_moving_summary
from
..tfutils.summary
import
add_moving_summary
from
..tfutils
import
get_op_tensor_name
from
..tfutils
import
get_op_tensor_name
...
@@ -25,11 +25,12 @@ from ..utils.concurrency import ShareSessionThread
...
@@ -25,11 +25,12 @@ from ..utils.concurrency import ShareSessionThread
from
..callbacks.concurrency
import
StartProcOrThread
from
..callbacks.concurrency
import
StartProcOrThread
from
..callbacks.base
import
Callback
from
..callbacks.base
import
Callback
__all__
=
[
'InputSource'
,
'FeedfreeInput'
,
__all__
=
[
'InputSource'
,
'FeedInput'
,
'DataParallelFeedInput'
,
'FeedInput'
,
'DataParallelFeedInput'
,
'FeedfreeInput'
,
'QueueInput'
,
'BatchQueueInput'
,
'QueueInput'
,
'BatchQueueInput'
,
'ZMQInput'
,
'ZMQInput'
,
'DummyConstantInput'
,
'TensorInput'
,
'
DummyConstantInput'
,
'TensorInput'
,
'StagingInputWrapper
'
]
'
StagingInputWrapper'
,
'ReorderInputSource
'
]
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
...
@@ -73,6 +74,8 @@ class FeedInput(InputSource):
...
@@ -73,6 +74,8 @@ class FeedInput(InputSource):
input_names (list[str]): input names this DataFlow maps to
input_names (list[str]): input names this DataFlow maps to
"""
"""
assert
isinstance
(
ds
,
DataFlow
),
ds
assert
isinstance
(
ds
,
DataFlow
),
ds
if
input_names
is
not
None
:
assert
isinstance
(
input_names
,
(
list
,
tuple
)),
input_names
self
.
ds
=
ds
self
.
ds
=
ds
self
.
_input_names
=
input_names
self
.
_input_names
=
input_names
...
@@ -213,7 +216,9 @@ class QueueInput(FeedfreeInput):
...
@@ -213,7 +216,9 @@ class QueueInput(FeedfreeInput):
"""
"""
Args:
Args:
ds(DataFlow): the input DataFlow.
ds(DataFlow): the input DataFlow.
queue (tf.QueueBase): Defaults to a FIFO queue of size 50.
queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
should match the corresponding InputDesc of the model.
Defaults to a FIFO queue of size 50.
"""
"""
assert
isinstance
(
ds
,
DataFlow
),
ds
assert
isinstance
(
ds
,
DataFlow
),
ds
self
.
queue
=
queue
self
.
queue
=
queue
...
@@ -227,7 +232,7 @@ class QueueInput(FeedfreeInput):
...
@@ -227,7 +232,7 @@ class QueueInput(FeedfreeInput):
logger
.
info
(
"Setting up the queue for CPU prefetching ..."
)
logger
.
info
(
"Setting up the queue for CPU prefetching ..."
)
self
.
input_placehdrs
=
model
.
get_reused_placehdrs
()
self
.
input_placehdrs
=
model
.
get_reused_placehdrs
()
assert
len
(
self
.
input_placehdrs
)
>
0
,
\
assert
len
(
self
.
input_placehdrs
)
>
0
,
\
"QueueInput has to be used with
input placeholders
!"
"QueueInput has to be used with
some InputDesc
!"
if
self
.
queue
is
None
:
if
self
.
queue
is
None
:
self
.
queue
=
tf
.
FIFOQueue
(
self
.
queue
=
tf
.
FIFOQueue
(
50
,
[
x
.
dtype
for
x
in
self
.
input_placehdrs
],
50
,
[
x
.
dtype
for
x
in
self
.
input_placehdrs
],
...
@@ -259,7 +264,9 @@ class BatchQueueInput(FeedfreeInput):
...
@@ -259,7 +264,9 @@ class BatchQueueInput(FeedfreeInput):
Args:
Args:
ds(DataFlow): the input DataFlow.
ds(DataFlow): the input DataFlow.
batch_size(int): the batch size.
batch_size(int): the batch size.
queue (tf.QueueBase): Defaults to a FIFO queue of size 3000.
queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
should match the corresponding InputDesc of the model.
Defaults to a FIFO queue of size 3000.
"""
"""
assert
isinstance
(
ds
,
DataFlow
),
ds
assert
isinstance
(
ds
,
DataFlow
),
ds
self
.
queue
=
queue
self
.
queue
=
queue
...
@@ -273,7 +280,7 @@ class BatchQueueInput(FeedfreeInput):
...
@@ -273,7 +280,7 @@ class BatchQueueInput(FeedfreeInput):
logger
.
info
(
"Setting up the queue for CPU prefetching ..."
)
logger
.
info
(
"Setting up the queue for CPU prefetching ..."
)
self
.
input_placehdrs
=
model
.
get_reused_placehdrs
()
self
.
input_placehdrs
=
model
.
get_reused_placehdrs
()
assert
len
(
self
.
input_placehdrs
)
>
0
,
\
assert
len
(
self
.
input_placehdrs
)
>
0
,
\
"BatchQueueInput has to be used with
input placeholders
!"
"BatchQueueInput has to be used with
some InputDesc
!"
# prepare placeholders without the first dimension
# prepare placeholders without the first dimension
placehdrs_nobatch
=
[]
placehdrs_nobatch
=
[]
...
@@ -366,6 +373,8 @@ class TensorInput(FeedfreeInput):
...
@@ -366,6 +373,8 @@ class TensorInput(FeedfreeInput):
get_tensor_fn: a function which returns a list of input tensors
get_tensor_fn: a function which returns a list of input tensors
when called. It will be called under a TowerContext.
when called. It will be called under a TowerContext.
size(int): size of this input. Use None to leave it undefined.
size(int): size of this input. Use None to leave it undefined.
input_names (list[str]): input names the tensors maps to. Defaults
to be all the inputs of the model.
"""
"""
self
.
get_tensor_fn
=
get_tensor_fn
self
.
get_tensor_fn
=
get_tensor_fn
if
size
is
not
None
:
if
size
is
not
None
:
...
@@ -491,3 +500,40 @@ class StagingInputWrapper(FeedfreeInput):
...
@@ -491,3 +500,40 @@ class StagingInputWrapper(FeedfreeInput):
def
get_unstage_op
(
self
):
def
get_unstage_op
(
self
):
all_outputs
=
list
(
chain
.
from_iterable
(
self
.
_unstage_ops
))
all_outputs
=
list
(
chain
.
from_iterable
(
self
.
_unstage_ops
))
return
tf
.
group
(
*
all_outputs
)
return
tf
.
group
(
*
all_outputs
)
class
ReorderInputSource
(
FeedfreeInput
):
"""
When an InputSource only maps to a subset of the InputDesc of the model,
wrap it with :class:`ReorderInputSource`.
"""
def
__init__
(
self
,
input
,
names
):
"""
Args:
input(TensorInput): a TensorInput, whose tensors will get mapped.
names(list[str]): list of input names corresponding to the tensors
produced by ``input``.
"""
assert
isinstance
(
input
,
TensorInput
),
input
self
.
_input
=
input
assert
isinstance
(
names
,
(
list
,
tuple
)),
names
self
.
_names
=
names
def
size
(
self
):
return
self
.
_input
.
size
()
def
setup
(
self
,
model
):
self
.
_all_placehdrs
=
model
.
get_reused_placehdrs
()
self
.
_input
.
setup
(
model
)
def
setup_training
(
self
,
trainer
):
self
.
_all_placehdrs
=
trainer
.
model
.
get_reused_placehdrs
()
self
.
_input
.
setup_training
(
trainer
)
def
reset_state
(
self
):
self
.
_input
.
reset_state
()
def
get_input_tensors
(
self
):
ret
=
self
.
_input
.
get_input_tensors
()
return
get_tensors_inputs
(
self
.
_all_placehdrs
,
ret
,
self
.
_names
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment