Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
9626ebd8
Commit
9626ebd8
authored
May 25, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
remove input_names from InferenceRunnerBase
parent
76fa8e38
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
16 deletions
+20
-16
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+10
-12
tensorpack/train/input_source.py
tensorpack/train/input_source.py
+10
-4
No files found.
tensorpack/callbacks/inference_runner.py
View file @
9626ebd8
...
@@ -56,13 +56,11 @@ def summary_inferencer(trainer, infs):
...
@@ -56,13 +56,11 @@ def summary_inferencer(trainer, infs):
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
class
InferenceRunnerBase
(
Callback
):
class
InferenceRunnerBase
(
Callback
):
""" Base methods for inference runner"""
""" Base methods for inference runner"""
def
__init__
(
self
,
input
,
infs
,
input_names
=
None
,
prefix
=
''
,
extra_hooks
=
None
):
def
__init__
(
self
,
input
,
infs
,
prefix
=
''
,
extra_hooks
=
None
):
"""
"""
Args:
Args:
input (InputSource): the input to use. Must have ``size()``.
input (InputSource): the input to use. Must have ``size()``.
infs (list[Inferencer]): list of :class:`Inferencer` to run.
infs (list[Inferencer]): list of :class:`Inferencer` to run.
input_names (list[str]): list of names to match ``input``, must be a subset of the names in
InputDesc of the model. Defaults to be all the inputs of the model.
prefix(str): an prefix used to build the tower. Must be set
prefix(str): an prefix used to build the tower. Must be set
differently if more than one :class:`InferenceRunner` are used.
differently if more than one :class:`InferenceRunner` are used.
extra_hooks (list[SessionRunHook]): extra :class:`SessionRunHook` to run with the evaluation.
extra_hooks (list[SessionRunHook]): extra :class:`SessionRunHook` to run with the evaluation.
...
@@ -74,9 +72,6 @@ class InferenceRunnerBase(Callback):
...
@@ -74,9 +72,6 @@ class InferenceRunnerBase(Callback):
self
.
infs
=
infs
self
.
infs
=
infs
for
v
in
self
.
infs
:
for
v
in
self
.
infs
:
assert
isinstance
(
v
,
Inferencer
),
v
assert
isinstance
(
v
,
Inferencer
),
v
if
input_names
is
not
None
:
assert
isinstance
(
input_names
,
list
)
self
.
input_names
=
input_names
try
:
try
:
self
.
_size
=
input
.
size
()
self
.
_size
=
input
.
size
()
...
@@ -95,7 +90,7 @@ class InferenceRunnerBase(Callback):
...
@@ -95,7 +90,7 @@ class InferenceRunnerBase(Callback):
def
fn
(
_
):
def
fn
(
_
):
in_tensors
=
self
.
_find_input_tensors
()
in_tensors
=
self
.
_find_input_tensors
()
assert
isinstance
(
in_tensors
,
list
),
in_tensors
assert
isinstance
(
in_tensors
,
(
list
,
tuple
)
),
in_tensors
self
.
trainer
.
model
.
build_graph
(
in_tensors
)
self
.
trainer
.
model
.
build_graph
(
in_tensors
)
PredictorTowerBuilder
(
fn
,
self
.
_prefix
)
.
build
(
self
.
_predict_tower_id
)
PredictorTowerBuilder
(
fn
,
self
.
_prefix
)
.
build
(
self
.
_predict_tower_id
)
...
@@ -140,12 +135,13 @@ class InferenceRunner(InferenceRunnerBase):
...
@@ -140,12 +135,13 @@ class InferenceRunner(InferenceRunnerBase):
Args:
Args:
ds (DataFlow): the DataFlow to run inferencer on.
ds (DataFlow): the DataFlow to run inferencer on.
infs (list): a list of `Inferencer` instances.
infs (list): a list of `Inferencer` instances.
input_names (list[str]): same as in :class:`InferenceRunnerBase`.
input_names (list[str]): list of names to match ``input``, must be a subset of the names in
InputDesc of the model. Defaults to be all the inputs of the model.
"""
"""
assert
isinstance
(
ds
,
DataFlow
),
ds
assert
isinstance
(
ds
,
DataFlow
),
ds
input
=
FeedInput
(
ds
,
input_names
)
input
=
FeedInput
(
ds
,
input_names
)
super
(
InferenceRunner
,
self
)
.
__init__
(
super
(
InferenceRunner
,
self
)
.
__init__
(
input
,
infs
,
input_names
,
prefix
=
''
,
extra_hooks
=
extra_hooks
)
input
,
infs
,
prefix
=
''
,
extra_hooks
=
extra_hooks
)
def
_find_input_tensors
(
self
):
def
_find_input_tensors
(
self
):
return
self
.
_input_source
.
get_input_tensors
()
return
self
.
_input_source
.
get_input_tensors
()
...
@@ -173,7 +169,10 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
...
@@ -173,7 +169,10 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
"""
"""
assert
isinstance
(
input
,
TensorInput
),
input
assert
isinstance
(
input
,
TensorInput
),
input
super
(
FeedfreeInferenceRunner
,
self
)
.
__init__
(
super
(
FeedfreeInferenceRunner
,
self
)
.
__init__
(
input
,
infs
,
input_names
,
prefix
=
prefix
,
extra_hooks
=
extra_hooks
)
input
,
infs
,
prefix
=
prefix
,
extra_hooks
=
extra_hooks
)
if
input_names
is
not
None
:
assert
isinstance
(
input_names
,
list
)
self
.
input_names
=
input_names
def
_find_input_tensors
(
self
):
def
_find_input_tensors
(
self
):
# TODO move mapping to InputSource
# TODO move mapping to InputSource
...
@@ -203,8 +202,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -203,8 +202,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
for
k
in
range
(
len
(
gpus
))]
for
k
in
range
(
len
(
gpus
))]
input
=
DataParallelFeedInput
(
input
=
DataParallelFeedInput
(
ds
,
self
.
_tower_names
,
input_names
=
input_names
)
ds
,
self
.
_tower_names
,
input_names
=
input_names
)
super
(
DataParallelInferenceRunner
,
self
)
.
__init__
(
super
(
DataParallelInferenceRunner
,
self
)
.
__init__
(
input
,
infs
)
input
,
infs
,
input_names
)
self
.
_gpus
=
gpus
self
.
_gpus
=
gpus
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
...
...
tensorpack/train/input_source.py
View file @
9626ebd8
...
@@ -25,7 +25,8 @@ from ..utils.concurrency import ShareSessionThread
...
@@ -25,7 +25,8 @@ from ..utils.concurrency import ShareSessionThread
from
..callbacks.concurrency
import
StartProcOrThread
from
..callbacks.concurrency
import
StartProcOrThread
from
..callbacks.base
import
Callback
from
..callbacks.base
import
Callback
__all__
=
[
'InputSource'
,
'FeedfreeInput'
,
'DataParallelFeedInput'
,
__all__
=
[
'InputSource'
,
'FeedfreeInput'
,
'FeedInput'
,
'DataParallelFeedInput'
,
'QueueInput'
,
'BatchQueueInput'
,
'QueueInput'
,
'BatchQueueInput'
,
'ZMQInput'
,
'ZMQInput'
,
'DummyConstantInput'
,
'TensorInput'
,
'StagingInputWrapper'
]
'DummyConstantInput'
,
'TensorInput'
,
'StagingInputWrapper'
]
...
@@ -54,12 +55,13 @@ class InputSource(object):
...
@@ -54,12 +55,13 @@ class InputSource(object):
def
reset_state
(
self
):
def
reset_state
(
self
):
pass
pass
@
abstractmethod
def
next_feed
(
self
):
def
next_feed
(
self
):
"""
"""
Returns:
Returns:
a feed_dict of {Tensor: data}, to be used to run the steps
a feed_dict of {Tensor: data}, to be used to run the steps
"""
"""
return
{}
pass
class
FeedInput
(
InputSource
):
class
FeedInput
(
InputSource
):
...
@@ -128,6 +130,7 @@ class DataParallelFeedInput(FeedInput):
...
@@ -128,6 +130,7 @@ class DataParallelFeedInput(FeedInput):
# input_names to be used for this specific tower
# input_names to be used for this specific tower
self
.
_feed_placehdrs_per_tower
.
append
(
self
.
_feed_placehdrs_per_tower
.
append
(
get_placeholders_by_names
(
phdrs
,
input_names
))
get_placeholders_by_names
(
phdrs
,
input_names
))
print
(
self
.
_feed_placehdrs_per_tower
[
-
1
])
self
.
reset_state
()
self
.
reset_state
()
def
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
...
@@ -158,10 +161,13 @@ class FeedfreeInput(InputSource):
...
@@ -158,10 +161,13 @@ class FeedfreeInput(InputSource):
# TODO cannot reset
# TODO cannot reset
pass
pass
def
next_feed
(
self
):
return
{}
# TODO enqueu_many? https://github.com/tensorflow/tensorflow/issues/7817#issuecomment-282053155
# TODO enqueu_many? https://github.com/tensorflow/tensorflow/issues/7817#issuecomment-282053155
class
EnqueueThread
(
ShareSessionThread
):
class
EnqueueThread
(
ShareSessionThread
):
def
__init__
(
self
,
queue
,
ds
,
input_
placehdrs
):
def
__init__
(
self
,
queue
,
ds
,
placehdrs
):
super
(
EnqueueThread
,
self
)
.
__init__
()
super
(
EnqueueThread
,
self
)
.
__init__
()
self
.
name
=
'EnqueueThread'
self
.
name
=
'EnqueueThread'
self
.
daemon
=
True
self
.
daemon
=
True
...
@@ -169,7 +175,7 @@ class EnqueueThread(ShareSessionThread):
...
@@ -169,7 +175,7 @@ class EnqueueThread(ShareSessionThread):
self
.
dataflow
=
ds
self
.
dataflow
=
ds
self
.
queue
=
queue
self
.
queue
=
queue
self
.
placehdrs
=
input_
placehdrs
self
.
placehdrs
=
placehdrs
self
.
op
=
self
.
queue
.
enqueue
(
self
.
placehdrs
)
self
.
op
=
self
.
queue
.
enqueue
(
self
.
placehdrs
)
self
.
close_op
=
self
.
queue
.
close
(
cancel_pending_enqueues
=
True
)
self
.
close_op
=
self
.
queue
.
close
(
cancel_pending_enqueues
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment