Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
1f881fcf
Commit
1f881fcf
authored
Nov 09, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[FasterRCNN] also support deeper resnet
parent
99d99e7d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
10 additions
and
7 deletions
+10
-7
examples/FasterRCNN/basemodel.py
examples/FasterRCNN/basemodel.py
+2
-2
examples/FasterRCNN/config.py
examples/FasterRCNN/config.py
+3
-0
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+4
-5
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+1
-0
No files found.
examples/FasterRCNN/basemodel.py
View file @
1f881fcf
...
@@ -91,11 +91,11 @@ def pretrained_resnet_conv4(image, num_blocks):
...
@@ -91,11 +91,11 @@ def pretrained_resnet_conv4(image, num_blocks):
return
l
return
l
def
resnet_conv5
(
image
):
def
resnet_conv5
_gap
(
image
,
num_block
):
with
argscope
([
Conv2D
,
GlobalAvgPooling
,
BatchNorm
],
data_format
=
'NCHW'
),
\
with
argscope
([
Conv2D
,
GlobalAvgPooling
,
BatchNorm
],
data_format
=
'NCHW'
),
\
argscope
(
Conv2D
,
nl
=
tf
.
identity
,
use_bias
=
False
),
\
argscope
(
Conv2D
,
nl
=
tf
.
identity
,
use_bias
=
False
),
\
argscope
(
BatchNorm
,
use_local_stat
=
False
):
argscope
(
BatchNorm
,
use_local_stat
=
False
):
# 14x14:
# 14x14:
l
=
resnet_group
(
image
,
'group3'
,
resnet_bottleneck
,
512
,
3
,
stride
=
2
)
l
=
resnet_group
(
image
,
'group3'
,
resnet_bottleneck
,
512
,
num_block
,
stride
=
2
)
l
=
GlobalAvgPooling
(
'gap'
,
l
)
l
=
GlobalAvgPooling
(
'gap'
,
l
)
return
l
return
l
examples/FasterRCNN/config.py
View file @
1f881fcf
...
@@ -10,6 +10,9 @@ TRAIN_DATASET = ['train2014', 'valminusminival2014']
...
@@ -10,6 +10,9 @@ TRAIN_DATASET = ['train2014', 'valminusminival2014']
VAL_DATASET
=
'minival2014'
# only support evaluation on one dataset
VAL_DATASET
=
'minival2014'
# only support evaluation on one dataset
NUM_CLASS
=
81
NUM_CLASS
=
81
# basemodel ----------------------
RESNET_NUM_BLOCK
=
[
3
,
4
,
6
,
3
]
# resnet50
# preprocessing --------------------
# preprocessing --------------------
SHORT_EDGE_SIZE
=
600
SHORT_EDGE_SIZE
=
600
MAX_SIZE
=
1024
MAX_SIZE
=
1024
...
...
examples/FasterRCNN/train.py
View file @
1f881fcf
...
@@ -25,7 +25,7 @@ from tensorpack.utils.gpu import get_nr_gpu
...
@@ -25,7 +25,7 @@ from tensorpack.utils.gpu import get_nr_gpu
from
coco
import
COCODetection
from
coco
import
COCODetection
from
basemodel
import
(
from
basemodel
import
(
image_preprocess
,
pretrained_resnet_conv4
,
resnet_conv5
)
image_preprocess
,
pretrained_resnet_conv4
,
resnet_conv5
_gap
)
from
model
import
(
from
model
import
(
rpn_head
,
rpn_losses
,
rpn_head
,
rpn_losses
,
decode_bbox_target
,
encode_bbox_target
,
decode_bbox_target
,
encode_bbox_target
,
...
@@ -87,8 +87,7 @@ class Model(ModelDesc):
...
@@ -87,8 +87,7 @@ class Model(ModelDesc):
fm_anchors
=
self
.
_get_anchors
(
image
)
fm_anchors
=
self
.
_get_anchors
(
image
)
anchor_boxes_encoded
=
encode_bbox_target
(
anchor_boxes
,
fm_anchors
)
anchor_boxes_encoded
=
encode_bbox_target
(
anchor_boxes
,
fm_anchors
)
# resnet50
featuremap
=
pretrained_resnet_conv4
(
image
,
config
.
RESNET_NUM_BLOCK
[:
3
])
featuremap
=
pretrained_resnet_conv4
(
image
,
[
3
,
4
,
6
])
rpn_label_logits
,
rpn_box_logits
=
rpn_head
(
featuremap
,
1024
,
config
.
NR_ANCHOR
)
rpn_label_logits
,
rpn_box_logits
=
rpn_head
(
featuremap
,
1024
,
config
.
NR_ANCHOR
)
rpn_label_loss
,
rpn_box_loss
=
rpn_losses
(
rpn_label_loss
,
rpn_box_loss
=
rpn_losses
(
anchor_labels
,
anchor_boxes_encoded
,
rpn_label_logits
,
rpn_box_logits
)
anchor_labels
,
anchor_boxes_encoded
,
rpn_label_logits
,
rpn_box_logits
)
...
@@ -104,7 +103,7 @@ class Model(ModelDesc):
...
@@ -104,7 +103,7 @@ class Model(ModelDesc):
proposal_boxes
,
gt_boxes
,
gt_labels
)
proposal_boxes
,
gt_boxes
,
gt_labels
)
boxes_on_featuremap
=
rcnn_sampled_boxes
*
(
1.0
/
config
.
ANCHOR_STRIDE
)
boxes_on_featuremap
=
rcnn_sampled_boxes
*
(
1.0
/
config
.
ANCHOR_STRIDE
)
roi_resized
=
roi_align
(
featuremap
,
boxes_on_featuremap
,
14
)
roi_resized
=
roi_align
(
featuremap
,
boxes_on_featuremap
,
14
)
feature_fastrcnn
=
resnet_conv5
(
roi_resized
)
# nxc
feature_fastrcnn
=
resnet_conv5
_gap
(
roi_resized
,
config
.
RESNET_NUM_BLOCK
[
-
1
]
)
# nxc
fastrcnn_label_logits
,
fastrcnn_box_logits
=
fastrcnn_head
(
feature_fastrcnn
,
config
.
NUM_CLASS
)
fastrcnn_label_logits
,
fastrcnn_box_logits
=
fastrcnn_head
(
feature_fastrcnn
,
config
.
NUM_CLASS
)
fastrcnn_label_loss
,
fastrcnn_box_loss
=
fastrcnn_losses
(
fastrcnn_label_loss
,
fastrcnn_box_loss
=
fastrcnn_losses
(
...
@@ -123,7 +122,7 @@ class Model(ModelDesc):
...
@@ -123,7 +122,7 @@ class Model(ModelDesc):
add_moving_summary
(
k
)
add_moving_summary
(
k
)
else
:
else
:
roi_resized
=
roi_align
(
featuremap
,
proposal_boxes
*
(
1.0
/
config
.
ANCHOR_STRIDE
),
14
)
roi_resized
=
roi_align
(
featuremap
,
proposal_boxes
*
(
1.0
/
config
.
ANCHOR_STRIDE
),
14
)
feature_fastrcnn
=
resnet_conv5
(
roi_resized
)
# nxc
feature_fastrcnn
=
resnet_conv5
_gap
(
roi_resized
,
config
.
RESNET_NUM_BLOCK
[
-
1
]
)
# nxc
label_logits
,
fastrcnn_box_logits
=
fastrcnn_head
(
feature_fastrcnn
,
config
.
NUM_CLASS
)
label_logits
,
fastrcnn_box_logits
=
fastrcnn_head
(
feature_fastrcnn
,
config
.
NUM_CLASS
)
label_probs
=
tf
.
nn
.
softmax
(
label_logits
,
name
=
'fastrcnn_all_probs'
)
# NP,
label_probs
=
tf
.
nn
.
softmax
(
label_logits
,
name
=
'fastrcnn_all_probs'
)
# NP,
labels
=
tf
.
argmax
(
label_logits
,
axis
=
1
)
labels
=
tf
.
argmax
(
label_logits
,
axis
=
1
)
...
...
tensorpack/tfutils/sessinit.py
View file @
1f881fcf
...
@@ -273,4 +273,5 @@ def TryResumeTraining():
...
@@ -273,4 +273,5 @@ def TryResumeTraining():
path
=
os
.
path
.
join
(
logger
.
get_logger_dir
(),
'checkpoint'
)
path
=
os
.
path
.
join
(
logger
.
get_logger_dir
(),
'checkpoint'
)
if
not
tf
.
gfile
.
Exists
(
path
):
if
not
tf
.
gfile
.
Exists
(
path
):
return
JustCurrentSession
()
return
JustCurrentSession
()
logger
.
info
(
"Found checkpoint at {}."
.
format
(
path
))
return
SaverRestore
(
path
)
return
SaverRestore
(
path
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment