Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
acfb57c2
Commit
acfb57c2
authored
Nov 12, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[FasterRCNN] make rpn_head and fastrcnn_head layers
parent
f3c50d39
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
30 additions
and
23 deletions
+30
-23
examples/FasterRCNN/model.py
examples/FasterRCNN/model.py
+14
-13
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+12
-9
tensorpack/callbacks/prof.py
tensorpack/callbacks/prof.py
+4
-1
No files found.
examples/FasterRCNN/model.py
View file @
acfb57c2
...
...
@@ -8,21 +8,22 @@ from tensorpack.tfutils import get_current_tower_context
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.argscope
import
argscope
from
tensorpack.tfutils.scope_utils
import
under_name_scope
from
tensorpack.models
import
Conv2D
,
FullyConnected
,
GlobalAvgPooling
from
tensorpack.models
import
(
Conv2D
,
FullyConnected
,
GlobalAvgPooling
,
layer_register
)
from
utils.box_ops
import
pairwise_iou
import
config
@
layer_register
(
log_shape
=
True
)
def
rpn_head
(
featuremap
,
channel
,
num_anchors
):
"""
Returns:
label_logits: fHxfWxNA
box_logits: fHxfWxNAx4
"""
with
tf
.
variable_scope
(
'rpn'
),
\
argscope
(
Conv2D
,
data_format
=
'NCHW'
,
W_init
=
tf
.
random_normal_initializer
(
stddev
=
0.01
)):
with
argscope
(
Conv2D
,
data_format
=
'NCHW'
,
W_init
=
tf
.
random_normal_initializer
(
stddev
=
0.01
)):
hidden
=
Conv2D
(
'conv0'
,
featuremap
,
channel
,
3
,
nl
=
tf
.
nn
.
relu
)
label_logits
=
Conv2D
(
'class'
,
hidden
,
num_anchors
,
1
)
...
...
@@ -371,6 +372,7 @@ def roi_align(featuremap, boxes, output_shape):
return
ret
@
layer_register
(
log_shape
=
True
)
def
fastrcnn_head
(
feature
,
num_classes
):
"""
Args:
...
...
@@ -381,15 +383,14 @@ def fastrcnn_head(feature, num_classes):
cls_logits (Nxnum_class), reg_logits (Nx num_class-1 x 4)
"""
feature
=
GlobalAvgPooling
(
'gap'
,
feature
,
data_format
=
'NCHW'
)
with
tf
.
variable_scope
(
'fastrcnn'
):
classification
=
FullyConnected
(
'class'
,
feature
,
num_classes
,
W_init
=
tf
.
random_normal_initializer
(
stddev
=
0.01
))
box_regression
=
FullyConnected
(
'box'
,
feature
,
(
num_classes
-
1
)
*
4
,
W_init
=
tf
.
random_normal_initializer
(
stddev
=
0.001
))
box_regression
=
tf
.
reshape
(
box_regression
,
(
-
1
,
num_classes
-
1
,
4
))
return
classification
,
box_regression
classification
=
FullyConnected
(
'class'
,
feature
,
num_classes
,
W_init
=
tf
.
random_normal_initializer
(
stddev
=
0.01
))
box_regression
=
FullyConnected
(
'box'
,
feature
,
(
num_classes
-
1
)
*
4
,
W_init
=
tf
.
random_normal_initializer
(
stddev
=
0.001
))
box_regression
=
tf
.
reshape
(
box_regression
,
(
-
1
,
num_classes
-
1
,
4
))
return
classification
,
box_regression
@
under_name_scope
()
...
...
examples/FasterRCNN/train.py
View file @
acfb57c2
...
...
@@ -88,7 +88,7 @@ class Model(ModelDesc):
anchor_boxes_encoded
=
encode_bbox_target
(
anchor_boxes
,
fm_anchors
)
featuremap
=
pretrained_resnet_conv4
(
image
,
config
.
RESNET_NUM_BLOCK
[:
3
])
rpn_label_logits
,
rpn_box_logits
=
rpn_head
(
featuremap
,
1024
,
config
.
NUM_ANCHOR
)
rpn_label_logits
,
rpn_box_logits
=
rpn_head
(
'rpn'
,
featuremap
,
1024
,
config
.
NUM_ANCHOR
)
rpn_label_loss
,
rpn_box_loss
=
rpn_losses
(
anchor_labels
,
anchor_boxes_encoded
,
rpn_label_logits
,
rpn_box_logits
)
...
...
@@ -99,13 +99,19 @@ class Model(ModelDesc):
tf
.
shape
(
image
)[
2
:])
if
is_training
:
# sample proposal boxes in training
rcnn_sampled_boxes
,
rcnn_encoded_boxes
,
rcnn_labels
=
sample_fast_rcnn_targets
(
proposal_boxes
,
gt_boxes
,
gt_labels
)
boxes_on_featuremap
=
rcnn_sampled_boxes
*
(
1.0
/
config
.
ANCHOR_STRIDE
)
roi_resized
=
roi_align
(
featuremap
,
boxes_on_featuremap
,
14
)
feature_fastrcnn
=
resnet_conv5
(
roi_resized
,
config
.
RESNET_NUM_BLOCK
[
-
1
])
# nxc
fastrcnn_label_logits
,
fastrcnn_box_logits
=
fastrcnn_head
(
feature_fastrcnn
,
config
.
NUM_CLASS
)
else
:
# use all proposal boxes in inference
boxes_on_featuremap
=
proposal_boxes
*
(
1.0
/
config
.
ANCHOR_STRIDE
)
roi_resized
=
roi_align
(
featuremap
,
boxes_on_featuremap
,
14
)
feature_fastrcnn
=
resnet_conv5
(
roi_resized
,
config
.
RESNET_NUM_BLOCK
[
-
1
])
# nxc
fastrcnn_label_logits
,
fastrcnn_box_logits
=
fastrcnn_head
(
'fastrcnn'
,
feature_fastrcnn
,
config
.
NUM_CLASS
)
if
is_training
:
fastrcnn_label_loss
,
fastrcnn_box_loss
=
fastrcnn_losses
(
rcnn_labels
,
rcnn_encoded_boxes
,
fastrcnn_label_logits
,
fastrcnn_box_logits
)
...
...
@@ -121,11 +127,8 @@ class Model(ModelDesc):
for
k
in
self
.
cost
,
wd_cost
:
add_moving_summary
(
k
)
else
:
roi_resized
=
roi_align
(
featuremap
,
proposal_boxes
*
(
1.0
/
config
.
ANCHOR_STRIDE
),
14
)
feature_fastrcnn
=
resnet_conv5
(
roi_resized
,
config
.
RESNET_NUM_BLOCK
[
-
1
])
# nxc
label_logits
,
fastrcnn_box_logits
=
fastrcnn_head
(
feature_fastrcnn
,
config
.
NUM_CLASS
)
label_probs
=
tf
.
nn
.
softmax
(
label_logits
,
name
=
'fastrcnn_all_probs'
)
# NP,
labels
=
tf
.
argmax
(
label_logits
,
axis
=
1
)
label_probs
=
tf
.
nn
.
softmax
(
fastrcnn_label_logits
,
name
=
'fastrcnn_all_probs'
)
# NP,
labels
=
tf
.
argmax
(
fastrcnn_label_logits
,
axis
=
1
)
fg_ind
,
fg_box_logits
=
fastrcnn_predict_boxes
(
labels
,
fastrcnn_box_logits
)
fg_label_probs
=
tf
.
gather
(
label_probs
,
fg_ind
,
name
=
'fastrcnn_fg_probs'
)
fg_boxes
=
tf
.
gather
(
proposal_boxes
,
fg_ind
)
...
...
tensorpack/callbacks/prof.py
View file @
acfb57c2
...
...
@@ -39,7 +39,10 @@ class GPUUtilizationTracker(Callback):
"Will monitor all visible GPUs!"
)
self
.
_devices
=
list
(
map
(
str
,
range
(
get_nr_gpu
())))
else
:
self
.
_devices
=
env
.
split
(
','
)
if
len
(
env
):
self
.
_devices
=
env
.
split
(
','
)
else
:
self
.
_devices
=
[]
else
:
self
.
_devices
=
list
(
map
(
str
,
devices
))
assert
len
(
self
.
_devices
),
"[GPUUtilizationTracker] No GPU device given!"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment