Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
cf97218c
Commit
cf97218c
authored
Aug 25, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[MaskRCNN] move fastrcnn outputs out of head. support class-agnostic regression
parent
0d36de5f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
65 additions
and
30 deletions
+65
-30
examples/FasterRCNN/model_frcnn.py
examples/FasterRCNN/model_frcnn.py
+52
-19
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+3
-2
tensorpack/models/batch_norm.py
tensorpack/models/batch_norm.py
+10
-9
No files found.
examples/FasterRCNN/model_frcnn.py
View file @
cf97218c
...
...
@@ -100,22 +100,25 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
@
layer_register
(
log_shape
=
True
)
def
fastrcnn_outputs
(
feature
,
num_classes
):
def
fastrcnn_outputs
(
feature
,
num_classes
,
class_agnostic_regression
=
False
):
"""
Args:
feature (any shape):
num_classes(int): num_category + 1
class_agnostic_regression (bool): if True, regression to N x 1 x 4
Returns:
cls_logits (Nxnum_class), reg_logits (Nx num_class x 4)
cls_logits: N x num_class classification logits
reg_logits: N x num_class x 4, or N x 1 x 4 if class agnostic
"""
classification
=
FullyConnected
(
'class'
,
feature
,
num_classes
,
kernel_initializer
=
tf
.
random_normal_initializer
(
stddev
=
0.01
))
num_classes_for_box
=
1
if
class_agnostic_regression
else
num_classes
box_regression
=
FullyConnected
(
'box'
,
feature
,
num_classes
*
4
,
'box'
,
feature
,
num_classes
_for_box
*
4
,
kernel_initializer
=
tf
.
random_normal_initializer
(
stddev
=
0.001
))
box_regression
=
tf
.
reshape
(
box_regression
,
(
-
1
,
num_classes
,
4
),
name
=
'output_box'
)
box_regression
=
tf
.
reshape
(
box_regression
,
(
-
1
,
num_classes
_for_box
,
4
),
name
=
'output_box'
)
return
classification
,
box_regression
...
...
@@ -126,7 +129,7 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
labels: n,
label_logits: nxC
fg_boxes: nfgx4, encoded
fg_box_logits: nfgxCx4
fg_box_logits: nfgxCx4
or nfgx1x4 if class agnostic
Returns:
label_loss, box_loss
...
...
@@ -138,9 +141,12 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
fg_inds
=
tf
.
where
(
labels
>
0
)[:,
0
]
fg_labels
=
tf
.
gather
(
labels
,
fg_inds
)
num_fg
=
tf
.
size
(
fg_inds
,
out_type
=
tf
.
int64
)
if
int
(
fg_box_logits
.
shape
[
1
])
>
1
:
indices
=
tf
.
stack
(
[
tf
.
range
(
num_fg
),
fg_labels
],
axis
=
1
)
# #fgx2
fg_box_logits
=
tf
.
gather_nd
(
fg_box_logits
,
indices
)
else
:
fg_box_logits
=
tf
.
reshape
(
fg_box_logits
,
[
-
1
,
4
])
with
tf
.
name_scope
(
'label_metrics'
),
tf
.
device
(
'/cpu:0'
):
prediction
=
tf
.
argmax
(
label_logits
,
axis
=
1
,
name
=
'label_prediction'
)
...
...
@@ -229,24 +235,23 @@ FastRCNN heads for FPN:
@
layer_register
(
log_shape
=
True
)
def
fastrcnn_2fc_head
(
feature
,
num_classes
):
def
fastrcnn_2fc_head
(
feature
):
"""
Args:
feature (any shape):
num_classes(int): num_category + 1
Returns:
outputs of `fastrcnn_outputs()`
2D head feature
"""
dim
=
cfg
.
FPN
.
FRCNN_FC_HEAD_DIM
init
=
tf
.
variance_scaling_initializer
()
hidden
=
FullyConnected
(
'fc6'
,
feature
,
dim
,
kernel_initializer
=
init
,
activation
=
tf
.
nn
.
relu
)
hidden
=
FullyConnected
(
'fc7'
,
hidden
,
dim
,
kernel_initializer
=
init
,
activation
=
tf
.
nn
.
relu
)
return
fastrcnn_outputs
(
'outputs'
,
hidden
,
num_classes
)
return
hidden
@
layer_register
(
log_shape
=
True
)
def
fastrcnn_Xconv1fc_head
(
feature
,
num_c
lasses
,
num_c
onvs
,
norm
=
None
):
def
fastrcnn_Xconv1fc_head
(
feature
,
num_convs
,
norm
=
None
):
"""
Args:
feature (NCHW):
...
...
@@ -255,7 +260,7 @@ def fastrcnn_Xconv1fc_head(feature, num_classes, num_convs, norm=None):
norm (str or None): either None or 'GN'
Returns:
outputs of `fastrcnn_outputs()`
2D head feature
"""
assert
norm
in
[
None
,
'GN'
],
norm
l
=
feature
...
...
@@ -268,7 +273,7 @@ def fastrcnn_Xconv1fc_head(feature, num_classes, num_convs, norm=None):
l
=
GroupNorm
(
'gn{}'
.
format
(
k
),
l
)
l
=
FullyConnected
(
'fc'
,
l
,
cfg
.
FPN
.
FRCNN_FC_HEAD_DIM
,
kernel_initializer
=
tf
.
variance_scaling_initializer
(),
activation
=
tf
.
nn
.
relu
)
return
fastrcnn_outputs
(
'outputs'
,
l
,
num_classes
)
return
l
def
fastrcnn_4conv1fc_head
(
*
args
,
**
kwargs
):
...
...
@@ -288,10 +293,10 @@ class FastRCNNHead(object):
"""
Args:
input_boxes: Nx4, inputs to the head
box_logits: Nx#classx4, the output of the head
box_logits: Nx#classx4
or Nx1x4
, the output of the head
label_logits: Nx#class, the output of the head
bbox_regression_weights: a 4 element tensor
labels: N, each in [0, #class
-1]
, the true label for each input box
labels: N, each in [0, #class
)
, the true label for each input box
matched_gt_boxes_per_fg: #fgx4, the matching gt boxes for each fg input box
The last two arguments could be None when not training.
...
...
@@ -299,10 +304,12 @@ class FastRCNNHead(object):
for
k
,
v
in
locals
()
.
items
():
if
k
!=
'self'
:
setattr
(
self
,
k
,
v
)
self
.
_bbox_class_agnostic
=
int
(
box_logits
.
shape
[
1
])
==
1
@
memoized
def
fg_inds_in_inputs
(
self
):
""" Returns: #fg indices in [0, N-1] """
assert
self
.
labels
is
not
None
return
tf
.
reshape
(
tf
.
where
(
self
.
labels
>
0
),
[
-
1
],
name
=
'fg_inds_in_inputs'
)
@
memoized
...
...
@@ -312,7 +319,7 @@ class FastRCNNHead(object):
@
memoized
def
fg_box_logits
(
self
):
""" Returns: #fg x
#class
x 4 """
""" Returns: #fg x
?
x 4 """
return
tf
.
gather
(
self
.
box_logits
,
self
.
fg_inds_in_inputs
(),
name
=
'fg_box_logits'
)
@
memoized
...
...
@@ -344,9 +351,20 @@ class FastRCNNHead(object):
@
memoized
def
decoded_output_boxes_for_true_label
(
self
):
""" Returns: Nx4 decoded boxes """
assert
self
.
labels
is
not
None
return
self
.
_decoded_output_boxes_for_label
(
self
.
labels
)
@
memoized
def
decoded_output_boxes_for_predicted_label
(
self
):
""" Returns: Nx4 decoded boxes """
return
self
.
_decoded_output_boxes_for_label
(
self
.
predicted_labels
())
@
memoized
def
decoded_output_boxes_for_label
(
self
,
labels
):
assert
not
self
.
_bbox_class_agnostic
indices
=
tf
.
stack
([
tf
.
range
(
tf
.
size
(
self
.
labels
,
out_type
=
tf
.
int64
)),
self
.
labels
labels
])
needed_logits
=
tf
.
gather_nd
(
self
.
box_logits
,
indices
)
decoded
=
decode_bbox_target
(
...
...
@@ -355,7 +373,22 @@ class FastRCNNHead(object):
)
return
decoded
@
memoized
def
decoded_output_boxes_class_agnostic
(
self
):
assert
self
.
_bbox_class_agnostic
box_logits
=
tf
.
reshape
(
self
.
box_logits
,
[
-
1
,
4
])
decoded
=
decode_bbox_target
(
box_logits
/
self
.
bbox_regression_weights
,
self
.
input_boxes
)
return
decoded
@
memoized
def
output_scores
(
self
,
name
=
None
):
""" Returns: N x #class scores, summed to one for each box."""
return
tf
.
nn
.
softmax
(
self
.
label_logits
,
name
=
name
)
@
memoized
def
predicted_labels
(
self
):
""" Returns: N ints """
return
tf
.
argmax
(
self
.
label_logits
,
axis
=
1
,
name
=
'predicted_labels'
)
examples/FasterRCNN/train.py
View file @
cf97218c
...
...
@@ -290,8 +290,9 @@ class ResNetFPNModel(DetectionModel):
roi_feature_fastrcnn
=
multilevel_roi_align
(
p23456
[:
4
],
rcnn_boxes
,
7
)
fastrcnn_head_func
=
getattr
(
model_frcnn
,
cfg
.
FPN
.
FRCNN_HEAD_FUNC
)
fastrcnn_label_logits
,
fastrcnn_box_logits
=
fastrcnn_head_func
(
'fastrcnn'
,
roi_feature_fastrcnn
,
cfg
.
DATA
.
NUM_CLASS
)
head_feature
=
fastrcnn_head_func
(
'fastrcnn'
,
roi_feature_fastrcnn
)
fastrcnn_label_logits
,
fastrcnn_box_logits
=
fastrcnn_outputs
(
'fastrcnn/outputs'
,
head_feature
,
cfg
.
DATA
.
NUM_CLASS
)
fastrcnn_head
=
FastRCNNHead
(
rcnn_boxes
,
fastrcnn_box_logits
,
fastrcnn_label_logits
,
tf
.
constant
(
cfg
.
FRCNN
.
BBOX_REG_WEIGHTS
,
dtype
=
tf
.
float32
),
rcnn_labels
,
matched_gt_boxes
)
...
...
tensorpack/models/batch_norm.py
View file @
cf97218c
...
...
@@ -221,16 +221,16 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
batch_mean_square
=
tf
.
reduce_mean
(
tf
.
square
(
inputs
),
axis
=
red_axis
)
if
sync_statistics
==
'nccl'
:
if
six
.
PY3
and
TF_version
<=
(
1
,
9
)
and
ctx
.
is_main_training_tower
:
logger
.
warn
(
"A bug in TensorFlow<=1.9 will cause cross-GPU BatchNorm to fail. "
"Upgrade or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360"
)
from
tensorflow.contrib.nccl.ops
import
gen_nccl_ops
shared_name
=
re
.
sub
(
'tower[0-9]+/'
,
''
,
tf
.
get_variable_scope
()
.
name
)
num_dev
=
ctx
.
total
if
num_dev
==
1
:
logger
.
warn
(
"BatchNorm(sync_statistics='nccl') is used with only one tower!"
)
else
:
assert
six
.
PY2
or
TF_version
>=
(
1
,
10
),
\
"Cross-GPU BatchNorm is only supported in TF>=1.10 ."
\
"Upgrade TF or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360"
from
tensorflow.contrib.nccl.ops
import
gen_nccl_ops
shared_name
=
re
.
sub
(
'tower[0-9]+/'
,
''
,
tf
.
get_variable_scope
()
.
name
)
batch_mean
=
gen_nccl_ops
.
nccl_all_reduce
(
input
=
batch_mean
,
reduction
=
'sum'
,
...
...
@@ -243,13 +243,14 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
shared_name
=
shared_name
+
'_NCCL_mean_square'
)
*
(
1.0
/
num_dev
)
elif
sync_statistics
==
'horovod'
:
# Require https://github.com/uber/horovod/pull/331
import
horovod
hvd_version
=
tuple
(
map
(
int
,
horovod
.
__version__
.
split
(
'.'
)))
assert
hvd_version
>=
(
0
,
13
,
6
),
"sync_statistics needs horovod>=0.13.6 !"
import
horovod.tensorflow
as
hvd
if
hvd
.
size
()
==
1
:
logger
.
warn
(
"BatchNorm(sync_statistics='horovod') is used with only one process!"
)
else
:
import
horovod
hvd_version
=
tuple
(
map
(
int
,
horovod
.
__version__
.
split
(
'.'
)))
assert
hvd_version
>=
(
0
,
13
,
6
),
"sync_statistics=horovod needs horovod>=0.13.6 !"
batch_mean
=
hvd
.
allreduce
(
batch_mean
,
average
=
True
)
batch_mean_square
=
hvd
.
allreduce
(
batch_mean_square
,
average
=
True
)
batch_var
=
batch_mean_square
-
tf
.
square
(
batch_mean
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment