Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
3951aaf7
Commit
3951aaf7
authored
Jul 14, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
merge Feed/Feedfree InferenceRunner
parent
05d1cbe7
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
31 additions
and
40 deletions
+31
-40
examples/PennTreebank/PTB-LSTM.py
examples/PennTreebank/PTB-LSTM.py
+2
-2
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+16
-34
tensorpack/graph_builder/predictor_factory.py
tensorpack/graph_builder/predictor_factory.py
+13
-4
No files found.
examples/PennTreebank/PTB-LSTM.py
View file @
3951aaf7
...
@@ -142,9 +142,9 @@ def get_config():
...
@@ -142,9 +142,9 @@ def get_config():
'learning_rate'
,
'learning_rate'
,
lambda
e
,
x
:
x
*
0.80
if
e
>
6
else
x
),
lambda
e
,
x
:
x
*
0.80
if
e
>
6
else
x
),
RunOp
(
lambda
:
M
.
reset_lstm_state
()),
RunOp
(
lambda
:
M
.
reset_lstm_state
()),
Feedfree
InferenceRunner
(
val_data
,
[
ScalarStats
([
'cost'
])]),
InferenceRunner
(
val_data
,
[
ScalarStats
([
'cost'
])]),
RunOp
(
lambda
:
M
.
reset_lstm_state
()),
RunOp
(
lambda
:
M
.
reset_lstm_state
()),
Feedfree
InferenceRunner
(
InferenceRunner
(
test_data
,
test_data
,
[
ScalarStats
([
'cost'
],
prefix
=
'test'
)],
prefix
=
'test'
),
[
ScalarStats
([
'cost'
],
prefix
=
'test'
)],
prefix
=
'test'
),
RunOp
(
lambda
:
M
.
reset_lstm_state
()),
RunOp
(
lambda
:
M
.
reset_lstm_state
()),
...
...
tensorpack/callbacks/inference_runner.py
View file @
3951aaf7
...
@@ -13,11 +13,14 @@ import six
...
@@ -13,11 +13,14 @@ import six
from
six.moves
import
range
from
six.moves
import
range
from
..utils
import
logger
,
get_tqdm_kwargs
from
..utils
import
logger
,
get_tqdm_kwargs
from
..utils.develop
import
deprecated
from
..dataflow
import
DataFlow
from
..dataflow
import
DataFlow
from
..tfutils.common
import
get_tensors_by_names
from
..tfutils.common
import
get_tensors_by_names
from
..tfutils.tower
import
TowerContext
from
..tfutils.tower
import
TowerContext
from
..graph_builder.input_source_base
import
InputSource
from
..graph_builder.input_source
import
(
from
..graph_builder.input_source
import
(
FeedInput
,
DataParallelFeedInput
,
FeedfreeInput
)
FeedInput
,
DataParallelFeedInput
,
FeedfreeInput
,
TensorInput
)
from
..predict
import
PredictorTowerBuilder
from
..predict
import
PredictorTowerBuilder
from
.base
import
Callback
from
.base
import
Callback
...
@@ -118,21 +121,23 @@ class InferenceRunnerBase(Callback):
...
@@ -118,21 +121,23 @@ class InferenceRunnerBase(Callback):
class
InferenceRunner
(
InferenceRunnerBase
):
class
InferenceRunner
(
InferenceRunnerBase
):
"""
"""
A callback that runs a list of :class:`Inferencer` on some
A callback that runs a list of :class:`Inferencer` on some :class:`InputSource`.
:class:`DataFlow`.
"""
"""
def
__init__
(
self
,
input
,
infs
,
extra_hooks
=
None
):
def
__init__
(
self
,
input
,
infs
,
prefix
=
''
,
extra_hooks
=
None
):
"""
"""
Args:
Args:
input (FeedInput or DataFlow): the FeedInput, or the DataFlow to run inferencer on.
input (InputSource or DataFlow): The :class:`InputSource` to run
infs (list): a list of `Inferencer` instances.
inference on. If given a DataFlow, will use :class:`FeedInput`.
infs (list): a list of :class:`Inferencer` instances.
"""
"""
if
isinstance
(
input
,
DataFlow
):
if
isinstance
(
input
,
DataFlow
):
input
=
FeedInput
(
input
)
input
=
FeedInput
(
input
)
assert
isinstance
(
input
,
FeedInput
),
input
assert
isinstance
(
input
,
InputSource
),
input
if
isinstance
(
input
,
FeedfreeInput
):
# TODO support other input
assert
isinstance
(
input
,
TensorInput
),
"InferenceRunner only accepts TensorInput or FeedInput!"
super
(
InferenceRunner
,
self
)
.
__init__
(
super
(
InferenceRunner
,
self
)
.
__init__
(
input
,
infs
,
prefix
=
''
,
extra_hooks
=
extra_hooks
)
input
,
infs
,
prefix
=
prefix
,
extra_hooks
=
extra_hooks
)
def
_build_hook
(
self
,
inf
):
def
_build_hook
(
self
,
inf
):
out_names
=
inf
.
get_output_tensors
()
out_names
=
inf
.
get_output_tensors
()
...
@@ -140,32 +145,9 @@ class InferenceRunner(InferenceRunnerBase):
...
@@ -140,32 +145,9 @@ class InferenceRunner(InferenceRunnerBase):
return
InferencerToHook
(
inf
,
fetches
)
return
InferencerToHook
(
inf
,
fetches
)
class
FeedfreeInferenceRunner
(
InferenceRunnerBase
):
@
deprecated
(
"Just use InferenceRunner since it now accepts TensorInput!"
)
""" A callback that runs a list of :class:`Inferencer` on some
def
FeedfreeInferenceRunner
(
*
args
,
**
kwargs
):
:class:`FeedfreeInput`, such as some tensor from a TensorFlow data reading
return
InferenceRunner
(
*
args
,
**
kwargs
)
pipeline.
"""
def
__init__
(
self
,
input
,
infs
,
prefix
=
''
,
extra_hooks
=
None
):
"""
Args:
input (FeedfreeInput): the input to use. Must have ``size()``.
infs (list): list of :class:`Inferencer` to run.
prefix(str): an prefix used to build the tower. Must be set
differently if more than one :class:`FeedfreeInferenceRunner` are used.
"""
assert
isinstance
(
input
,
FeedfreeInput
),
input
super
(
FeedfreeInferenceRunner
,
self
)
.
__init__
(
input
,
infs
,
prefix
=
prefix
,
extra_hooks
=
extra_hooks
)
def
_build_hook
(
self
,
inf
):
out_names
=
inf
.
get_output_tensors
()
# all is tensorname
placeholder_names
=
[
k
.
name
+
':0'
for
k
in
self
.
trainer
.
model
.
get_inputs_desc
()]
ret
=
[]
for
name
in
out_names
:
assert
name
not
in
placeholder_names
,
"Currently inferencer don't support fetching placeholders!"
ret
.
append
(
self
.
_tower_handle
.
get_tensors
([
name
])[
0
])
return
InferencerToHook
(
inf
,
ret
)
# TODO some scripts to test
# TODO some scripts to test
...
...
tensorpack/graph_builder/predictor_factory.py
View file @
3951aaf7
...
@@ -14,16 +14,22 @@ __all__ = ['PredictorFactory']
...
@@ -14,16 +14,22 @@ __all__ = ['PredictorFactory']
class
PredictorTowerHandle
(
object
):
class
PredictorTowerHandle
(
object
):
def
__init__
(
self
,
tower_name
,
input_
tensors
):
def
__init__
(
self
,
tower_name
,
input_
desc_names
,
input_tensors
=
None
):
self
.
_tower_name
=
tower_name
self
.
_tower_name
=
tower_name
self
.
_input_tensors
=
input_tensors
self
.
_input_desc_names
=
[
get_op_tensor_name
(
k
)[
1
]
for
k
in
input_desc_names
]
if
input_tensors
is
not
None
:
self
.
_input_names
=
[
get_op_tensor_name
(
k
.
name
)[
1
]
for
k
in
input_tensors
]
self
.
_input_names
=
[
get_op_tensor_name
(
k
.
name
)[
1
]
for
k
in
input_tensors
]
else
:
self
.
_input_names
=
self
.
_input_desc_names
def
get_tensors
(
self
,
names
):
def
get_tensors
(
self
,
names
):
def
maybe_inside_tower
(
name
):
def
maybe_inside_tower
(
name
):
name
=
get_op_tensor_name
(
name
)[
1
]
name
=
get_op_tensor_name
(
name
)[
1
]
if
name
in
self
.
_input_names
:
if
name
in
self
.
_input_names
:
return
name
return
name
elif
name
in
self
.
_input_desc_names
:
idx
=
self
.
_input_desc_names
.
index
(
name
)
return
self
.
_input_names
[
idx
]
else
:
else
:
# if the name is not a placeholder, use it's name in each tower
# if the name is not a placeholder, use it's name in each tower
return
self
.
_tower_name
+
'/'
+
name
return
self
.
_tower_name
+
'/'
+
name
...
@@ -62,7 +68,10 @@ class PredictorFactory(object):
...
@@ -62,7 +68,10 @@ class PredictorFactory(object):
input
=
input
.
get_input_tensors
()
input
=
input
.
get_input_tensors
()
assert
isinstance
(
input
,
(
list
,
tuple
)),
input
assert
isinstance
(
input
,
(
list
,
tuple
)),
input
self
.
_model
.
build_graph
(
input
)
self
.
_model
.
build_graph
(
input
)
self
.
_names_built
[
tower_name
]
=
PredictorTowerHandle
(
tower_name
,
input
)
desc_names
=
[
k
.
name
for
k
in
self
.
_model
.
get_inputs_desc
()]
self
.
_names_built
[
tower_name
]
=
PredictorTowerHandle
(
tower_name
,
desc_names
,
input
)
return
self
.
_names_built
[
tower_name
]
return
self
.
_names_built
[
tower_name
]
def
has_built
(
self
,
tower_name
):
def
has_built
(
self
,
tower_name
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment