Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
b747c068
Commit
b747c068
authored
Aug 07, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Inferencer API renamed
parent
5b681a95
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
50 additions
and
38 deletions
+50
-38
tensorpack/callbacks/inference.py
tensorpack/callbacks/inference.py
+45
-33
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+5
-5
No files found.
tensorpack/callbacks/inference.py
View file @
b747c068
...
...
@@ -3,7 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
numpy
as
np
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
import
six
from
six.moves
import
zip
...
...
@@ -15,12 +15,11 @@ from ..tfutils.common import get_op_tensor_name
__all__
=
[
'ScalarStats'
,
'Inferencer'
,
'ClassificationError'
,
'BinaryClassificationStats'
]
# TODO rename get_output_tensors to get_output_names
@
six
.
add_metaclass
(
ABCMeta
)
class
Inferencer
(
Callback
):
""" Base class of Inferencer. To be used with :class:`InferenceRunner`. """
""" Base class of Inferencer.
Inferencer is a special kind of callback that should be called by :class:`InferenceRunner`. """
def
_before_epoch
(
self
):
self
.
_before_inference
()
...
...
@@ -31,20 +30,6 @@ class Inferencer(Callback):
"""
pass
def
datapoint
(
self
,
output
):
"""
Called after each new datapoint finished the forward inference.
Args:
output(list): list of output this inferencer needs. Has the same
length as ``self.get_output_tensors()``.
"""
self
.
_datapoint
(
output
)
@
abstractmethod
def
_datapoint
(
self
,
output
):
pass
def
_trigger_epoch
(
self
):
ret
=
self
.
_after_inference
()
if
ret
is
None
:
...
...
@@ -65,17 +50,44 @@ class Inferencer(Callback):
"""
pass
def
get_
output_tensor
s
(
self
):
def
get_
fetche
s
(
self
):
"""
Return a list of tensor names (guaranteed not op name) this inferencer needs.
"""
ret
=
self
.
_get_output_tensors
()
try
:
ret
=
self
.
_get_fetches
()
except
NotImplementedError
:
logger
.
warn
(
"Inferencer._get_output_tensors was renamed to _get_fetches"
)
ret
=
self
.
_get_output_tensors
()
return
[
get_op_tensor_name
(
n
)[
1
]
for
n
in
ret
]
@
abstractmethod
def
_get_output_tensors
(
self
):
pass
def
_get_fetches
(
self
):
raise
NotImplementedError
()
def
on_fetches
(
self
,
results
):
"""
Called after each new datapoint finished the forward inference.
Args:
results(list): list of results this inferencer fetched. Has the same
length as ``self._get_fetches()``.
"""
try
:
self
.
_on_fetches
(
results
)
except
NotImplementedError
:
logger
.
warn
(
"Inferencer._datapoint was renamed to _on_fetches"
)
self
.
_datapoint
(
results
)
def
_datapoint
(
self
,
results
):
pass
def
_on_fetches
(
self
,
results
):
raise
NotImplementedError
()
class
ScalarStats
(
Inferencer
):
"""
...
...
@@ -96,13 +108,13 @@ class ScalarStats(Inferencer):
self
.
names
=
names
self
.
prefix
=
prefix
def
_get_output_tensors
(
self
):
return
self
.
names
def
_before_inference
(
self
):
self
.
stats
=
[]
def
_datapoint
(
self
,
output
):
def
_get_fetches
(
self
):
return
self
.
names
def
_on_fetches
(
self
,
output
):
self
.
stats
.
append
(
output
)
def
_after_inference
(
self
):
...
...
@@ -142,13 +154,13 @@ class ClassificationError(Inferencer):
self
.
wrong_tensor_name
=
wrong_tensor_name
self
.
summary_name
=
summary_name
def
_get_output_tensors
(
self
):
return
[
self
.
wrong_tensor_name
]
def
_before_inference
(
self
):
self
.
err_stat
=
RatioCounter
()
def
_datapoint
(
self
,
outputs
):
def
_get_fetches
(
self
):
return
[
self
.
wrong_tensor_name
]
def
_on_fetches
(
self
,
outputs
):
vec
=
outputs
[
0
]
# TODO put shape assertion into inference-runner
assert
vec
.
ndim
==
1
,
"{} is not a vector!"
.
format
(
self
.
wrong_tensor_name
)
...
...
@@ -176,13 +188,13 @@ class BinaryClassificationStats(Inferencer):
self
.
label_tensor_name
=
label_tensor_name
self
.
prefix
=
prefix
def
_get_output_tensors
(
self
):
return
[
self
.
pred_tensor_name
,
self
.
label_tensor_name
]
def
_before_inference
(
self
):
self
.
stat
=
BinaryStatistics
()
def
_datapoint
(
self
,
outputs
):
def
_get_fetches
(
self
):
return
[
self
.
pred_tensor_name
,
self
.
label_tensor_name
]
def
_on_fetches
(
self
,
outputs
):
pred
,
label
=
outputs
self
.
stat
.
feed
(
pred
,
label
)
...
...
tensorpack/callbacks/inference_runner.py
View file @
b747c068
...
...
@@ -39,7 +39,7 @@ class InferencerToHook(tf.train.SessionRunHook):
return
tf
.
train
.
SessionRunArgs
(
fetches
=
self
.
_fetches
)
def
after_run
(
self
,
_
,
run_values
):
self
.
_inf
.
datapoint
(
run_values
.
results
)
self
.
_inf
.
on_fetches
(
run_values
.
results
)
@
six
.
add_metaclass
(
ABCMeta
)
...
...
@@ -136,7 +136,7 @@ class InferenceRunner(InferenceRunnerBase):
input
,
infs
,
tower_name
=
tower_name
,
extra_hooks
=
extra_hooks
)
def
_build_hook
(
self
,
inf
):
out_names
=
inf
.
get_
output_tensor
s
()
out_names
=
inf
.
get_
fetche
s
()
fetches
=
self
.
_tower_handle
.
get_tensors
(
out_names
)
return
InferencerToHook
(
inf
,
fetches
)
...
...
@@ -199,16 +199,16 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
res
=
run_values
.
results
for
i
in
range
(
0
,
len
(
res
),
self
.
_sz
):
vals
=
res
[
i
:
i
+
self
.
_sz
]
self
.
_inf
.
datapoint
(
vals
)
self
.
_inf
.
on_fetches
(
vals
)
def
_build_hook_parallel
(
self
,
inf
):
out_names
=
inf
.
get_
output_tensor
s
()
out_names
=
inf
.
get_
fetche
s
()
sz
=
len
(
out_names
)
fetches
=
list
(
itertools
.
chain
(
*
[
t
.
get_tensors
(
out_names
)
for
t
in
self
.
_handles
]))
return
self
.
InferencerToHookDataParallel
(
inf
,
fetches
,
sz
)
def
_build_hook
(
self
,
inf
):
out_names
=
inf
.
get_
output_tensor
s
()
out_names
=
inf
.
get_
fetche
s
()
fetches
=
self
.
_handles
[
0
]
.
get_tensors
(
out_names
)
return
InferencerToHook
(
inf
,
fetches
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment