Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
f2ca6b1a
Commit
f2ca6b1a
authored
Nov 12, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[FasterRCNN] parameterize `get_tf_nms` and make it in a standalone graph.
parent
a042f821
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
10 deletions
+10
-10
examples/FasterRCNN/eval.py
examples/FasterRCNN/eval.py
+10
-9
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+0
-1
No files found.
examples/FasterRCNN/eval.py
View file @
f2ca6b1a
...
...
@@ -27,17 +27,18 @@ DetectionResult = namedtuple(
@
memoized
def
get_tf_nms
():
def
get_tf_nms
(
num_output
,
thresh
):
"""
Get a NMS callable.
"""
boxes
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
[
None
,
4
])
scores
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
[
None
])
indices
=
tf
.
image
.
non_max_suppression
(
boxes
,
scores
,
config
.
RESULTS_PER_IM
,
config
.
FASTRCNN_NMS_THRESH
)
sess
=
tf
.
Session
(
config
=
get_default_sess_config
())
return
sess
.
make_callable
(
indices
,
[
boxes
,
scores
])
# create a new graph for it
with
tf
.
Graph
()
.
as_default
(),
tf
.
device
(
'/cpu:0'
):
boxes
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
[
None
,
4
])
scores
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
[
None
])
indices
=
tf
.
image
.
non_max_suppression
(
boxes
,
scores
,
num_output
,
thresh
)
sess
=
tf
.
Session
(
config
=
get_default_sess_config
())
return
sess
.
make_callable
(
indices
,
[
boxes
,
scores
])
def
nms_fastrcnn_results
(
boxes
,
probs
):
...
...
@@ -53,7 +54,7 @@ def nms_fastrcnn_results(boxes, probs):
boxes
=
boxes
.
copy
()
boxes_per_class
=
{}
nms_func
=
get_tf_nms
()
nms_func
=
get_tf_nms
(
config
.
RESULTS_PER_IM
,
config
.
FASTRCNN_NMS_THRESH
)
ret
=
[]
for
klass
in
range
(
1
,
C
):
ids
=
np
.
where
(
probs
[:,
klass
]
>
config
.
RESULT_SCORE_THRESH
)[
0
]
...
...
examples/FasterRCNN/train.py
View file @
f2ca6b1a
...
...
@@ -232,7 +232,6 @@ class EvalCallback(Callback):
def
_setup_graph
(
self
):
self
.
pred
=
self
.
trainer
.
get_predictor
([
'image'
],
[
'fastrcnn_fg_probs'
,
'fastrcnn_fg_boxes'
])
self
.
df
=
PrefetchDataZMQ
(
get_eval_dataflow
(),
1
)
get_tf_nms
()
# just to make sure the nms part of graph is created
def
_before_train
(
self
):
EVAL_TIMES
=
5
# eval 5 times during training
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment