Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
bbf29a18
Commit
bbf29a18
authored
Oct 06, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
assert class ids not out of bounds (#1336)
parent
17cb3554
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
1 deletion
+4
-1
examples/FasterRCNN/data.py
examples/FasterRCNN/data.py
+3
-0
examples/FasterRCNN/modeling/model_mrcnn.py
examples/FasterRCNN/modeling/model_mrcnn.py
+1
-1
No files found.
examples/FasterRCNN/data.py
View file @
bbf29a18
...
...
@@ -48,6 +48,8 @@ def print_class_histogram(roidbs):
# filter crowd?
gt_inds
=
np
.
where
((
entry
[
"class"
]
>
0
)
&
(
entry
[
"is_crowd"
]
==
0
))[
0
]
gt_classes
=
entry
[
"class"
][
gt_inds
]
if
len
(
gt_classes
):
assert
gt_classes
.
max
()
<=
len
(
class_names
)
-
1
gt_hist
+=
np
.
histogram
(
gt_classes
,
bins
=
hist_bins
)[
0
]
data
=
list
(
itertools
.
chain
(
*
[[
class_names
[
i
+
1
],
v
]
for
i
,
v
in
enumerate
(
gt_hist
[
1
:])]))
COL
=
min
(
6
,
len
(
data
))
...
...
@@ -97,6 +99,7 @@ class TrainingDataPreprocessor:
points
=
tfms
.
apply_coords
(
points
)
boxes
=
point8_to_box
(
points
)
if
len
(
boxes
):
assert
klass
.
max
()
<=
cfg
.
DATA
.
NUM_CATEGORY
,
"Invalid category {}!"
.
format
(
klass
.
max
())
assert
np
.
min
(
np_area
(
boxes
))
>
0
,
"Some boxes have zero area!"
ret
=
{
"image"
:
im
}
...
...
examples/FasterRCNN/modeling/model_mrcnn.py
View file @
bbf29a18
...
...
@@ -23,12 +23,12 @@ def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks):
if
get_tf_version_tuple
()
>=
(
1
,
14
):
mask_logits
=
tf
.
gather
(
mask_logits
,
tf
.
reshape
(
fg_labels
-
1
,
[
-
1
,
1
]),
batch_dims
=
1
)
mask_logits
=
tf
.
squeeze
(
mask_logits
,
axis
=
1
)
else
:
indices
=
tf
.
stack
([
tf
.
range
(
tf
.
size
(
fg_labels
,
out_type
=
tf
.
int64
)),
fg_labels
-
1
],
axis
=
1
)
# #fgx2
mask_logits
=
tf
.
gather_nd
(
mask_logits
,
indices
)
# #fg x h x w
mask_logits
=
tf
.
squeeze
(
mask_logits
,
axis
=
1
)
mask_probs
=
tf
.
sigmoid
(
mask_logits
)
# add some training visualizations to tensorboard
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment