Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
8f8fe80d
Commit
8f8fe80d
authored
Sep 26, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
FeedfreePredictor and example on ImageNet eval (fix #772)
parent
7f505225
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
106 additions
and
14 deletions
+106
-14
examples/ImageNetModels/imagenet_utils.py
examples/ImageNetModels/imagenet_utils.py
+12
-5
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+3
-3
tensorpack/input_source/input_source_base.py
tensorpack/input_source/input_source_base.py
+7
-0
tensorpack/predict/feedfree.py
tensorpack/predict/feedfree.py
+72
-0
tensorpack/tfutils/dependency.py
tensorpack/tfutils/dependency.py
+2
-1
tensorpack/tfutils/sesscreate.py
tensorpack/tfutils/sesscreate.py
+1
-1
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+9
-4
No files found.
examples/ImageNetModels/imagenet_utils.py
View file @
8f8fe80d
...
...
@@ -4,15 +4,17 @@
import
cv2
import
numpy
as
np
import
tqdm
import
multiprocessing
import
tensorflow
as
tf
from
abc
import
abstractmethod
from
tensorpack
import
imgaug
,
dataset
,
ModelDesc
from
tensorpack
import
ModelDesc
from
tensorpack.input_source
import
QueueInput
,
StagingInput
from
tensorpack.dataflow
import
(
AugmentImageComponent
,
PrefetchDataZMQ
,
imgaug
,
dataset
,
AugmentImageComponent
,
PrefetchDataZMQ
,
BatchData
,
MultiThreadMapData
)
from
tensorpack.predict
import
PredictConfig
,
SimpleDataset
Predictor
from
tensorpack.predict
import
PredictConfig
,
Feedfree
Predictor
from
tensorpack.utils.stats
import
RatioCounter
from
tensorpack.models
import
regularize_cost
from
tensorpack.tfutils.summary
import
add_moving_summary
...
...
@@ -126,12 +128,17 @@ def eval_on_ILSVRC12(model, sessinit, dataflow):
input_names
=
[
'input'
,
'label'
],
output_names
=
[
'wrong-top1'
,
'wrong-top5'
]
)
pred
=
SimpleDatasetPredictor
(
pred_config
,
dataflow
)
acc1
,
acc5
=
RatioCounter
(),
RatioCounter
()
for
top1
,
top5
in
pred
.
get_result
():
# This does not have a visible improvement over naive predictor,
# but will have an improvement if image_dtype is set to float32.
pred
=
FeedfreePredictor
(
pred_config
,
StagingInput
(
QueueInput
(
dataflow
),
device
=
'/gpu:0'
))
for
_
in
tqdm
.
trange
(
dataflow
.
size
()):
top1
,
top5
=
pred
()
batch_size
=
top1
.
shape
[
0
]
acc1
.
feed
(
top1
.
sum
(),
batch_size
)
acc5
.
feed
(
top5
.
sum
(),
batch_size
)
print
(
"Top1 Error: {}"
.
format
(
acc1
.
ratio
))
print
(
"Top5 Error: {}"
.
format
(
acc5
.
ratio
))
...
...
tensorpack/input_source/input_source.py
View file @
8f8fe80d
...
...
@@ -547,10 +547,10 @@ class StagingInput(FeedfreeInput):
self
.
fetches
=
tf
.
train
.
SessionRunArgs
(
fetches
=
[
self
.
stage_op
,
unstage_op
])
def
_prefill
(
self
):
def
_prefill
(
self
,
sess
):
logger
.
info
(
"Pre-filling StagingArea ..."
)
for
k
in
range
(
self
.
nr_stage
):
self
.
stage_op
.
run
()
self
.
stage_op
.
run
(
session
=
sess
)
logger
.
info
(
"{} element{} put into StagingArea on each tower."
.
format
(
self
.
nr_stage
,
"s were"
if
self
.
nr_stage
>
1
else
" was"
))
...
...
@@ -559,7 +559,7 @@ class StagingInput(FeedfreeInput):
# doing it in `before_train` may not work because QueueInput happens in before_train.
if
not
self
.
_initialized
:
self
.
_initialized
=
True
self
.
_prefill
()
self
.
_prefill
(
ctx
.
session
)
# Only step the stagingarea when the input is evaluated in this sess.run
fetches
=
ctx
.
original_args
.
fetches
if
dependency_of_fetches
(
fetches
,
self
.
_check_dependency_op
):
...
...
tensorpack/input_source/input_source_base.py
View file @
8f8fe80d
...
...
@@ -118,6 +118,13 @@ class InputSource(object):
All callbacks will be automatically marked as `chief_only=False`,
so they will run on all nodes.
Callbacks returned by :class:`InputSource` only supports a subset of callback's functionalities:
1. It cannot access the trainer, because an :class:`InputSource` can be used in pure inference.
2. It cannot use the following methods: `trigger_{step,epoch}, {before,after}_epoch`.
In other words, these callbacks should only have the basic functionality of `tf.train.SessionRunHooks`.
Returns:
list[Callback]: extra callbacks needed by this InputSource.
"""
...
...
tensorpack/predict/feedfree.py
0 → 100644
View file @
8f8fe80d
#!/usr/bin/env python
from
tensorflow.python.training.monitored_session
\
import
_HookedSession
as
HookedSession
from
.base
import
PredictorBase
from
..tfutils.tower
import
PredictTowerContext
from
..callbacks
import
Callbacks
__all__
=
[
'FeedfreePredictor'
]
class
FeedfreePredictor
(
PredictorBase
):
"""
Create a predictor that takes inputs from an :class:`InputSource`, instead of from feeds.
An instance `pred` of :class:`FeedfreePredictor` can be called only by `pred()`, which returns
a list of output values as defined in config.output_names.
"""
def
__init__
(
self
,
config
,
input_source
):
"""
Args:
config (PredictConfig): the config to use.
input_source (InputSource): the feedfree InputSource to use.
Must match the inputs_desc in config.
"""
self
.
_config
=
config
self
.
_input_source
=
input_source
assert
config
.
return_input
is
False
,
\
"return_input is not supported in FeedfreePredictor! "
\
"If you need to fetch inputs, add the names to the output_names!"
self
.
_hooks
=
[]
self
.
graph
=
config
.
_maybe_create_graph
()
with
self
.
graph
.
as_default
():
self
.
_input_callbacks
=
Callbacks
(
self
.
_input_source
.
setup
(
config
.
inputs_desc
))
with
PredictTowerContext
(
''
):
self
.
_input_tensors
=
self
.
_input_source
.
get_input_tensors
()
config
.
tower_func
(
*
self
.
_input_tensors
)
self
.
_tower_handle
=
config
.
tower_func
.
towers
[
-
1
]
self
.
_output_tensors
=
self
.
_tower_handle
.
get_tensors
(
config
.
output_names
)
self
.
_input_callbacks
.
setup_graph
(
None
)
for
h
in
self
.
_input_callbacks
.
get_hooks
():
self
.
_register_hook
(
h
)
self
.
_initialize_session
()
def
_register_hook
(
self
,
hook
):
"""
Args:
hook (tf.train.SessionRunHook):
"""
self
.
_hooks
.
append
(
hook
)
def
_initialize_session
(
self
):
# init the session
self
.
_config
.
session_init
.
_setup_graph
()
self
.
_sess
=
self
.
_config
.
session_creator
.
create_session
()
self
.
_config
.
session_init
.
_run_init
(
self
.
_sess
)
with
self
.
_sess
.
as_default
():
self
.
_input_callbacks
.
before_train
()
self
.
_hooked_sess
=
HookedSession
(
self
.
_sess
,
self
.
_hooks
)
def
__call__
(
self
):
return
self
.
_hooked_sess
.
run
(
self
.
_output_tensors
)
def
_do_call
(
self
):
raise
NotImplementedError
(
"You're calling the wrong function!"
)
tensorpack/tfutils/dependency.py
View file @
8f8fe80d
...
...
@@ -51,7 +51,8 @@ def dependency_of_fetches(fetches, op):
"""
try
:
from
tensorflow.python.client.session
import
_FetchHandler
as
FetchHandler
handler
=
FetchHandler
(
tf
.
get_default_graph
(),
fetches
,
{})
# use the graph of the op, so that this function can be called without being under a default graph
handler
=
FetchHandler
(
op
.
graph
,
fetches
,
{})
targets
=
tuple
(
handler
.
fetches
()
+
handler
.
targets
())
except
ImportError
:
if
isinstance
(
fetches
,
list
):
...
...
tensorpack/tfutils/sesscreate.py
View file @
8f8fe80d
...
...
@@ -22,7 +22,7 @@ class NewSessionCreator(tf.train.ChiefSessionCreator):
"""
Args:
target, graph, config: same as :meth:`Session.__init__()`.
config: defaults to :func:`tfutils.get_default_sess_config()`
config:
a :class:`tf.ConfigProto` instance,
defaults to :func:`tfutils.get_default_sess_config()`
"""
assert
graph
is
None
...
...
tensorpack/tfutils/tower.py
View file @
8f8fe80d
...
...
@@ -377,6 +377,9 @@ class TowerTensorHandle(object):
1. The name of the tensor without any tower prefix.
2. The name of an :class:`InputDesc`, if it is used when building the tower.
In the second case, this method will return the tensor that's used as the corresponding
input to the tower. Note that this tensor may have a different name (e.g. may be an output of a queue).
"""
name
=
get_op_tensor_name
(
name
)[
1
]
if
len
(
self
.
ns_name
):
...
...
@@ -392,10 +395,12 @@ class TowerTensorHandle(object):
raise
else
:
if
name
in
self
.
_extra_tensor_names
:
logger
.
warn
(
"'{}' may refer to both the tensor '{}' or the input '{}'."
.
format
(
name
,
ret
.
name
,
self
.
_extra_tensor_names
[
name
]
.
name
)
+
"Assuming it is the tensor '{}'."
.
format
(
ret
.
name
))
mapped_tensor
=
self
.
_extra_tensor_names
[
name
]
logger
.
info
(
"'{}' may refer to both the Tensor/Placeholder '{}' or the input to the tower '{}'."
.
format
(
name
,
ret
.
name
,
mapped_tensor
.
name
)
+
" Assuming it is the input '{}'."
.
format
(
mapped_tensor
.
name
))
return
mapped_tensor
return
ret
def
get_tensors
(
self
,
names
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment