Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
21c7f94a
Commit
21c7f94a
authored
Aug 07, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
InferenceRunner now works with QueueInput (#139)
parent
ac2031df
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
11 deletions
+21
-11
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+2
-2
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+17
-7
tensorpack/graph_builder/input_source.py
tensorpack/graph_builder/input_source.py
+2
-2
No files found.
tensorpack/callbacks/group.py
View file @
21c7f94a
...
@@ -44,8 +44,8 @@ class CallbackTimeLogger(object):
...
@@ -44,8 +44,8 @@ class CallbackTimeLogger(object):
class
Callbacks
(
Callback
):
class
Callbacks
(
Callback
):
"""
"""
A container to hold all callbacks, and
execute them in the right order
A container to hold all callbacks, and
trigger them iteratively.
(e.g. :class:`StatPrinter` will be executed at last)
.
Note that it does nothing to before_run/after_run
.
"""
"""
def
__init__
(
self
,
cbs
):
def
__init__
(
self
,
cbs
):
...
...
tensorpack/callbacks/inference_runner.py
View file @
21c7f94a
...
@@ -20,9 +20,10 @@ from ..dataflow.base import DataFlow, DataFlowTerminated
...
@@ -20,9 +20,10 @@ from ..dataflow.base import DataFlow, DataFlowTerminated
from
..graph_builder.input_source_base
import
InputSource
from
..graph_builder.input_source_base
import
InputSource
from
..graph_builder.input_source
import
(
from
..graph_builder.input_source
import
(
FeedInput
,
DataParallelFeedInput
,
FeedfreeInput
,
TensorInput
)
FeedInput
,
DataParallelFeedInput
)
from
.base
import
Callback
from
.base
import
Callback
from
.group
import
Callbacks
from
.inference
import
Inferencer
from
.inference
import
Inferencer
from
.hooks
import
CallbackToHook
from
.hooks
import
CallbackToHook
...
@@ -79,20 +80,29 @@ class InferenceRunnerBase(Callback):
...
@@ -79,20 +80,29 @@ class InferenceRunnerBase(Callback):
tower_id
=
self
.
trainer
.
config
.
predict_tower
[
0
]
tower_id
=
self
.
trainer
.
config
.
predict_tower
[
0
]
device
=
'/gpu:{}'
.
format
(
tower_id
)
if
tower_id
>=
0
else
'/cpu:0'
device
=
'/gpu:{}'
.
format
(
tower_id
)
if
tower_id
>=
0
else
'/cpu:0'
self
.
_input_callbacks
=
self
.
_input_source
.
setup
(
self
.
trainer
.
model
.
get_inputs_desc
())
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
trainer
.
model
.
get_inputs_desc
())
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
self
.
_tower_handle
=
self
.
trainer
.
predictor_factory
.
build
(
self
.
_tower_handle
=
self
.
trainer
.
predictor_factory
.
build
(
self
.
_tower_name
,
device
,
self
.
_input_source
)
self
.
_tower_name
,
device
,
self
.
_input_source
)
self
.
_hooks
=
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]
self
.
_hooks
=
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]
self
.
_hooks
.
extend
([
CallbackToHook
(
cb
)
for
cb
in
self
.
_input_callbacks
])
# trigger_{step,epoch}, {before,after}_epoch is ignored.
# We assume that InputSource callbacks won't use these methods
self
.
_input_callbacks
=
Callbacks
(
input_callbacks
)
self
.
_hooks
.
extend
(
self
.
_input_callbacks
.
get_hooks
())
for
inf
in
self
.
infs
:
for
inf
in
self
.
infs
:
inf
.
setup_graph
(
self
.
trainer
)
inf
.
setup_graph
(
self
.
trainer
)
self
.
_input_callbacks
.
setup_graph
(
self
.
trainer
)
def
_before_train
(
self
):
def
_before_train
(
self
):
self
.
_hooks
.
extend
(
self
.
_extra_hooks
)
self
.
_hooks
.
extend
(
self
.
_extra_hooks
)
self
.
_hooked_sess
=
HookedSession
(
self
.
trainer
.
sess
,
self
.
_hooks
)
self
.
_hooked_sess
=
HookedSession
(
self
.
trainer
.
sess
,
self
.
_hooks
)
self
.
_input_callbacks
.
before_train
()
def
_after_train
(
self
):
self
.
_input_callbacks
.
after_train
()
@
abstractmethod
@
abstractmethod
def
_build_hook
(
self
,
inf
):
def
_build_hook
(
self
,
inf
):
...
@@ -108,9 +118,11 @@ class InferenceRunnerBase(Callback):
...
@@ -108,9 +118,11 @@ class InferenceRunnerBase(Callback):
try
:
try
:
for
_
in
tqdm
.
trange
(
self
.
_size
,
**
get_tqdm_kwargs
()):
for
_
in
tqdm
.
trange
(
self
.
_size
,
**
get_tqdm_kwargs
()):
self
.
_hooked_sess
.
run
(
fetches
=
[])
self
.
_hooked_sess
.
run
(
fetches
=
[])
except
(
StopIteration
,
DataFlowTerminated
):
except
(
StopIteration
,
DataFlowTerminated
,
logger
.
exception
(
tf
.
errors
.
CancelledError
,
tf
.
errors
.
OutOfRangeError
):
logger
.
error
(
"[InferenceRunner] input stopped before reaching its size()! "
+
msg
)
"[InferenceRunner] input stopped before reaching its size()! "
+
msg
)
raise
for
inf
in
self
.
infs
:
for
inf
in
self
.
infs
:
inf
.
trigger_epoch
()
inf
.
trigger_epoch
()
...
@@ -130,8 +142,6 @@ class InferenceRunner(InferenceRunnerBase):
...
@@ -130,8 +142,6 @@ class InferenceRunner(InferenceRunnerBase):
if
isinstance
(
input
,
DataFlow
):
if
isinstance
(
input
,
DataFlow
):
input
=
FeedInput
(
input
,
infinite
=
False
)
input
=
FeedInput
(
input
,
infinite
=
False
)
assert
isinstance
(
input
,
InputSource
),
input
assert
isinstance
(
input
,
InputSource
),
input
if
isinstance
(
input
,
FeedfreeInput
):
# TODO support other input
assert
isinstance
(
input
,
TensorInput
),
"InferenceRunner only accepts TensorInput or FeedInput!"
super
(
InferenceRunner
,
self
)
.
__init__
(
super
(
InferenceRunner
,
self
)
.
__init__
(
input
,
infs
,
tower_name
=
tower_name
,
extra_hooks
=
extra_hooks
)
input
,
infs
,
tower_name
=
tower_name
,
extra_hooks
=
extra_hooks
)
...
...
tensorpack/graph_builder/input_source.py
View file @
21c7f94a
...
@@ -169,7 +169,7 @@ class FeedfreeInput(InputSource):
...
@@ -169,7 +169,7 @@ class FeedfreeInput(InputSource):
class
EnqueueThread
(
ShareSessionThread
):
class
EnqueueThread
(
ShareSessionThread
):
def
__init__
(
self
,
queue
,
ds
,
placehdrs
):
def
__init__
(
self
,
queue
,
ds
,
placehdrs
):
super
(
EnqueueThread
,
self
)
.
__init__
()
super
(
EnqueueThread
,
self
)
.
__init__
()
self
.
name
=
'EnqueueThread
'
self
.
name
=
'EnqueueThread
'
+
queue
.
name
self
.
daemon
=
True
self
.
daemon
=
True
self
.
dataflow
=
ds
self
.
dataflow
=
ds
...
@@ -222,7 +222,6 @@ class QueueInput(FeedfreeInput):
...
@@ -222,7 +222,6 @@ class QueueInput(FeedfreeInput):
return
self
.
ds
.
size
()
return
self
.
ds
.
size
()
def
_setup
(
self
,
inputs
):
def
_setup
(
self
,
inputs
):
logger
.
info
(
"Setting up the queue for CPU prefetching ..."
)
self
.
_input_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
self
.
_input_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
assert
len
(
self
.
_input_placehdrs
)
>
0
,
\
assert
len
(
self
.
_input_placehdrs
)
>
0
,
\
"QueueInput has to be used with some inputs!"
"QueueInput has to be used with some inputs!"
...
@@ -231,6 +230,7 @@ class QueueInput(FeedfreeInput):
...
@@ -231,6 +230,7 @@ class QueueInput(FeedfreeInput):
self
.
queue
=
tf
.
FIFOQueue
(
self
.
queue
=
tf
.
FIFOQueue
(
50
,
[
x
.
dtype
for
x
in
self
.
_input_placehdrs
],
50
,
[
x
.
dtype
for
x
in
self
.
_input_placehdrs
],
name
=
'input_queue'
)
name
=
'input_queue'
)
logger
.
info
(
"Setting up the queue '{}' for CPU prefetching ..."
.
format
(
self
.
queue
.
name
))
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
self
.
_input_placehdrs
)
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
self
.
_input_placehdrs
)
def
_create_ema_callback
(
self
):
def
_create_ema_callback
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment