Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
438aef79
Commit
438aef79
authored
Aug 25, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[MaskRCNN] handle empty forground in frcnn head.
parent
6041a1a4
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
27 deletions
+22
-27
examples/FasterRCNN/model_frcnn.py
examples/FasterRCNN/model_frcnn.py
+7
-4
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+14
-22
tensorpack/tfutils/collection.py
tensorpack/tfutils/collection.py
+1
-1
No files found.
examples/FasterRCNN/model_frcnn.py
View file @
438aef79
...
@@ -144,6 +144,7 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
...
@@ -144,6 +144,7 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
fg_inds
=
tf
.
where
(
labels
>
0
)[:,
0
]
fg_inds
=
tf
.
where
(
labels
>
0
)[:,
0
]
fg_labels
=
tf
.
gather
(
labels
,
fg_inds
)
fg_labels
=
tf
.
gather
(
labels
,
fg_inds
)
num_fg
=
tf
.
size
(
fg_inds
,
out_type
=
tf
.
int64
)
num_fg
=
tf
.
size
(
fg_inds
,
out_type
=
tf
.
int64
)
empty_fg
=
tf
.
equal
(
num_fg
,
0
)
if
int
(
fg_box_logits
.
shape
[
1
])
>
1
:
if
int
(
fg_box_logits
.
shape
[
1
])
>
1
:
indices
=
tf
.
stack
(
indices
=
tf
.
stack
(
[
tf
.
range
(
num_fg
),
fg_labels
],
axis
=
1
)
# #fgx2
[
tf
.
range
(
num_fg
),
fg_labels
],
axis
=
1
)
# #fgx2
...
@@ -157,16 +158,18 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
...
@@ -157,16 +158,18 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
accuracy
=
tf
.
reduce_mean
(
correct
,
name
=
'accuracy'
)
accuracy
=
tf
.
reduce_mean
(
correct
,
name
=
'accuracy'
)
fg_label_pred
=
tf
.
argmax
(
tf
.
gather
(
label_logits
,
fg_inds
),
axis
=
1
)
fg_label_pred
=
tf
.
argmax
(
tf
.
gather
(
label_logits
,
fg_inds
),
axis
=
1
)
num_zero
=
tf
.
reduce_sum
(
tf
.
to_int64
(
tf
.
equal
(
fg_label_pred
,
0
)),
name
=
'num_zero'
)
num_zero
=
tf
.
reduce_sum
(
tf
.
to_int64
(
tf
.
equal
(
fg_label_pred
,
0
)),
name
=
'num_zero'
)
false_negative
=
tf
.
truediv
(
num_zero
,
num_fg
,
name
=
'false_negative'
)
false_negative
=
tf
.
where
(
fg_accuracy
=
tf
.
reduce_mean
(
empty_fg
,
0.
,
tf
.
truediv
(
num_zero
,
num_fg
),
name
=
'false_negative'
)
tf
.
gather
(
correct
,
fg_inds
),
name
=
'fg_accuracy'
)
fg_accuracy
=
tf
.
where
(
empty_fg
,
0.
,
tf
.
reduce_mean
(
tf
.
gather
(
correct
,
fg_inds
)),
name
=
'fg_accuracy'
)
box_loss
=
tf
.
losses
.
huber_loss
(
box_loss
=
tf
.
losses
.
huber_loss
(
fg_boxes
,
fg_box_logits
,
reduction
=
tf
.
losses
.
Reduction
.
SUM
)
fg_boxes
,
fg_box_logits
,
reduction
=
tf
.
losses
.
Reduction
.
SUM
)
box_loss
=
tf
.
truediv
(
box_loss
=
tf
.
truediv
(
box_loss
,
tf
.
to_float
(
tf
.
shape
(
labels
)[
0
]),
name
=
'box_loss'
)
box_loss
,
tf
.
to_float
(
tf
.
shape
(
labels
)[
0
]),
name
=
'box_loss'
)
add_moving_summary
(
label_loss
,
box_loss
,
accuracy
,
fg_accuracy
,
false_negative
)
add_moving_summary
(
label_loss
,
box_loss
,
accuracy
,
fg_accuracy
,
false_negative
,
tf
.
to_float
(
num_fg
,
name
=
'num_fg_label'
))
return
label_loss
,
box_loss
return
label_loss
,
box_loss
...
...
examples/FasterRCNN/train.py
View file @
438aef79
...
@@ -141,12 +141,13 @@ class ResNetC4Model(DetectionModel):
...
@@ -141,12 +141,13 @@ class ResNetC4Model(DetectionModel):
tf
.
constant
(
cfg
.
FRCNN
.
BBOX_REG_WEIGHTS
,
dtype
=
tf
.
float32
))
tf
.
constant
(
cfg
.
FRCNN
.
BBOX_REG_WEIGHTS
,
dtype
=
tf
.
float32
))
if
is_training
:
if
is_training
:
all_losses
=
[]
# rpn loss
# rpn loss
rpn_label_loss
,
rpn_box_loss
=
rpn_losses
(
all_losses
.
extend
(
rpn_losses
(
anchors
.
gt_labels
,
anchors
.
encoded_gt_boxes
(),
rpn_label_logits
,
rpn_box_logits
)
anchors
.
gt_labels
,
anchors
.
encoded_gt_boxes
(),
rpn_label_logits
,
rpn_box_logits
)
)
# fastrcnn loss
# fastrcnn loss
fastrcnn_label_loss
,
fastrcnn_box_loss
=
fastrcnn_head
.
losses
(
)
all_losses
.
extend
(
fastrcnn_head
.
losses
()
)
if
cfg
.
MODE_MASK
:
if
cfg
.
MODE_MASK
:
# maskrcnn loss
# maskrcnn loss
...
@@ -161,18 +162,13 @@ class ResNetC4Model(DetectionModel):
...
@@ -161,18 +162,13 @@ class ResNetC4Model(DetectionModel):
proposals
.
fg_inds_wrt_gt
,
14
,
proposals
.
fg_inds_wrt_gt
,
14
,
pad_border
=
False
)
# nfg x 1x14x14
pad_border
=
False
)
# nfg x 1x14x14
target_masks_for_fg
=
tf
.
squeeze
(
target_masks_for_fg
,
1
,
'sampled_fg_mask_targets'
)
target_masks_for_fg
=
tf
.
squeeze
(
target_masks_for_fg
,
1
,
'sampled_fg_mask_targets'
)
mrcnn_loss
=
maskrcnn_loss
(
mask_logits
,
proposals
.
fg_labels
(),
target_masks_for_fg
)
all_losses
.
append
(
maskrcnn_loss
(
mask_logits
,
proposals
.
fg_labels
(),
target_masks_for_fg
))
else
:
mrcnn_loss
=
0.0
wd_cost
=
regularize_cost
(
wd_cost
=
regularize_cost
(
'.*/W'
,
l2_regularizer
(
cfg
.
TRAIN
.
WEIGHT_DECAY
),
name
=
'wd_cost'
)
'.*/W'
,
l2_regularizer
(
cfg
.
TRAIN
.
WEIGHT_DECAY
),
name
=
'wd_cost'
)
all_losses
.
append
(
wd_cost
)
total_cost
=
tf
.
add_n
([
total_cost
=
tf
.
add_n
(
all_losses
,
'total_cost'
)
rpn_label_loss
,
rpn_box_loss
,
fastrcnn_label_loss
,
fastrcnn_box_loss
,
mrcnn_loss
,
wd_cost
],
'total_cost'
)
add_moving_summary
(
total_cost
,
wd_cost
)
add_moving_summary
(
total_cost
,
wd_cost
)
return
total_cost
return
total_cost
else
:
else
:
...
@@ -272,11 +268,11 @@ class ResNetFPNModel(DetectionModel):
...
@@ -272,11 +268,11 @@ class ResNetFPNModel(DetectionModel):
tf
.
constant
(
cfg
.
FRCNN
.
BBOX_REG_WEIGHTS
,
dtype
=
tf
.
float32
))
tf
.
constant
(
cfg
.
FRCNN
.
BBOX_REG_WEIGHTS
,
dtype
=
tf
.
float32
))
if
is_training
:
if
is_training
:
# rpn loss:
all_losses
=
[]
rpn_label_loss
,
rpn_box_loss
=
multilevel_rpn_losses
(
all_losses
.
extend
(
multilevel_rpn_losses
(
multilevel_anchors
,
multilevel_label_logits
,
multilevel_box_logits
)
multilevel_anchors
,
multilevel_label_logits
,
multilevel_box_logits
)
)
fastrcnn_label_loss
,
fastrcnn_box_loss
=
fastrcnn_head
.
losses
(
)
all_losses
.
extend
(
fastrcnn_head
.
losses
()
)
if
cfg
.
MODE_MASK
:
if
cfg
.
MODE_MASK
:
# maskrcnn loss
# maskrcnn loss
...
@@ -293,17 +289,13 @@ class ResNetFPNModel(DetectionModel):
...
@@ -293,17 +289,13 @@ class ResNetFPNModel(DetectionModel):
proposals
.
fg_inds_wrt_gt
,
28
,
proposals
.
fg_inds_wrt_gt
,
28
,
pad_border
=
False
)
# fg x 1x28x28
pad_border
=
False
)
# fg x 1x28x28
target_masks_for_fg
=
tf
.
squeeze
(
target_masks_for_fg
,
1
,
'sampled_fg_mask_targets'
)
target_masks_for_fg
=
tf
.
squeeze
(
target_masks_for_fg
,
1
,
'sampled_fg_mask_targets'
)
mrcnn_loss
=
maskrcnn_loss
(
mask_logits
,
proposals
.
fg_labels
(),
target_masks_for_fg
)
all_losses
.
append
(
maskrcnn_loss
(
mask_logits
,
proposals
.
fg_labels
(),
target_masks_for_fg
))
else
:
mrcnn_loss
=
0.0
wd_cost
=
regularize_cost
(
wd_cost
=
regularize_cost
(
'.*/W'
,
l2_regularizer
(
cfg
.
TRAIN
.
WEIGHT_DECAY
),
name
=
'wd_cost'
)
'.*/W'
,
l2_regularizer
(
cfg
.
TRAIN
.
WEIGHT_DECAY
),
name
=
'wd_cost'
)
all_losses
.
append
(
wd_cost
)
total_cost
=
tf
.
add_n
([
rpn_label_loss
,
rpn_box_loss
,
total_cost
=
tf
.
add_n
(
all_losses
,
'total_cost'
)
fastrcnn_label_loss
,
fastrcnn_box_loss
,
mrcnn_loss
,
wd_cost
],
'total_cost'
)
add_moving_summary
(
total_cost
,
wd_cost
)
add_moving_summary
(
total_cost
,
wd_cost
)
return
total_cost
return
total_cost
else
:
else
:
...
...
tensorpack/tfutils/collection.py
View file @
438aef79
...
@@ -141,7 +141,7 @@ class CollectionGuard(object):
...
@@ -141,7 +141,7 @@ class CollectionGuard(object):
size_change
.
append
((
self
.
_key_name
(
k
),
len
(
old_v
),
len
(
v
)))
size_change
.
append
((
self
.
_key_name
(
k
),
len
(
old_v
),
len
(
v
)))
if
newly_created
:
if
newly_created
:
logger
.
info
(
logger
.
info
(
"New collections created in {}: {}"
.
format
(
"New collections created in
tower
{}: {}"
.
format
(
self
.
_name
,
', '
.
join
(
newly_created
)))
self
.
_name
,
', '
.
join
(
newly_created
)))
if
size_change
:
if
size_change
:
logger
.
info
(
logger
.
info
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment