Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
bffcfc1b
Commit
bffcfc1b
authored
Jun 28, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[MaskRCNN] re-organize configs
parent
d4c5c4f4
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
79 additions
and
64 deletions
+79
-64
examples/FasterRCNN/basemodel.py
examples/FasterRCNN/basemodel.py
+3
-5
examples/FasterRCNN/coco.py
examples/FasterRCNN/coco.py
+0
-1
examples/FasterRCNN/config.py
examples/FasterRCNN/config.py
+63
-14
examples/FasterRCNN/data.py
examples/FasterRCNN/data.py
+0
-4
examples/FasterRCNN/model.py
examples/FasterRCNN/model.py
+4
-4
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+9
-36
No files found.
examples/FasterRCNN/basemodel.py
View file @
bffcfc1b
...
...
@@ -40,10 +40,9 @@ def image_preprocess(image, bgr=True):
with
tf
.
name_scope
(
'image_preprocess'
):
if
image
.
dtype
.
base_dtype
!=
tf
.
float32
:
image
=
tf
.
cast
(
image
,
tf
.
float32
)
image
=
image
*
(
1.0
/
255
)
mean
=
[
0.485
,
0.456
,
0.406
]
# rgb
std
=
[
0.229
,
0.224
,
0.225
]
mean
=
cfg
.
PREPROC
.
PIXEL_MEAN
std
=
cfg
.
PREPROC
.
PIXEL_STD
if
bgr
:
mean
=
mean
[::
-
1
]
std
=
std
[::
-
1
]
...
...
@@ -93,8 +92,7 @@ def resnet_group(name, l, block_func, features, count, stride):
with
tf
.
variable_scope
(
name
):
for
i
in
range
(
0
,
count
):
with
tf
.
variable_scope
(
'block{}'
.
format
(
i
)):
l
=
block_func
(
l
,
features
,
stride
if
i
==
0
else
1
)
l
=
block_func
(
l
,
features
,
stride
if
i
==
0
else
1
)
return
l
...
...
examples/FasterRCNN/coco.py
View file @
bffcfc1b
...
...
@@ -18,7 +18,6 @@ from config import config as cfg
__all__
=
[
'COCODetection'
,
'COCOMeta'
]
COCO_NUM_CATEGORY
=
80
cfg
.
DATA
.
NUM_CLASS
=
COCO_NUM_CATEGORY
+
1
class
_COCOMeta
(
object
):
...
...
examples/FasterRCNN/config.py
View file @
bffcfc1b
# -*- coding: utf-8 -*-
# File: config.py
import
numpy
as
np
import
os
import
pprint
from
tensorpack.utils
import
logger
from
tensorpack.utils.gpu
import
get_num_gpu
__all__
=
[
'config'
]
__all__
=
[
'config'
,
'finalize_configs'
]
class
AttrDict
():
...
...
@@ -52,7 +56,7 @@ _C.MODE_FPN = False
_C
.
DATA
.
BASEDIR
=
'/path/to/your/COCO/DIR'
_C
.
DATA
.
TRAIN
=
[
'train2014'
,
'valminusminival2014'
]
# i.e., trainval35k
_C
.
DATA
.
VAL
=
'minival2014'
# For now, only support evaluation on single dataset
_C
.
DATA
.
NUM_C
LASS
=
81
# 1 background +
80 categories
_C
.
DATA
.
NUM_C
ATEGORY
=
80
#
80 categories
_C
.
DATA
.
CLASS_NAMES
=
[]
# NUM_CLASS strings. Needs to be populated later by data loader
# basemodel ----------------------
...
...
@@ -60,9 +64,10 @@ _C.BACKBONE.RESNET_NUM_BLOCK = [3, 4, 6, 3] # for resnet50
# RESNET_NUM_BLOCK = [3, 4, 23, 3] # for resnet101
_C
.
BACKBONE
.
FREEZE_AFFINE
=
False
# do not train affine parameters inside BN
# Use a base model with TF-preferred pad mode,
# Use a base model with TF-preferred pad
ding
mode,
# which may pad more pixels on right/bottom than top/left.
# TF_PAD_MODE=False is better for performance but will require a different base model.
# TF_PAD_MODE=False is better for accuracy but will require a different base model.
# We will eventually switch to TF_PAD_MODE=False.
# See https://github.com/tensorflow/tensorflow/issues/18213
_C
.
BACKBONE
.
TF_PAD_MODE
=
True
...
...
@@ -79,15 +84,18 @@ _C.TRAIN.STEPS_PER_EPOCH = 500
_C
.
TRAIN
.
LR_SCHEDULE
=
[
240000
,
320000
,
360000
]
# "2x" schedule in detectron
# preprocessing --------------------
# Alternative old (worse & faster) setting: 600, 1024
_C
.
PREPROC
.
SHORT_EDGE_SIZE
=
800
_C
.
PREPROC
.
MAX_SIZE
=
1333
# Alternative old (worse & faster) setting: 600, 1024
# mean and std in RGB order.
# Un-scaled version: [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
_C
.
PREPROC
.
PIXEL_MEAN
=
[
123.675
,
116.28
,
103.53
]
_C
.
PREPROC
.
PIXEL_STD
=
[
58.395
,
57.12
,
57.375
]
# anchors -------------------------
_C
.
RPN
.
ANCHOR_STRIDE
=
16
_C
.
RPN
.
ANCHOR_SIZES
=
(
32
,
64
,
128
,
256
,
512
)
# sqrtarea of the anchor box
_C
.
RPN
.
ANCHOR_RATIOS
=
(
0.5
,
1.
,
2.
)
_C
.
RPN
.
NUM_ANCHOR
=
len
(
_C
.
RPN
.
ANCHOR_SIZES
)
*
len
(
_C
.
RPN
.
ANCHOR_RATIOS
)
_C
.
RPN
.
POSITIVE_ANCHOR_THRES
=
0.7
_C
.
RPN
.
NEGATIVE_ANCHOR_THRES
=
0.3
...
...
@@ -96,9 +104,17 @@ _C.RPN.FG_RATIO = 0.5 # fg ratio among selected RPN anchors
_C
.
RPN
.
BATCH_PER_IM
=
256
# total (across FPN levels) number of anchors that are marked valid
_C
.
RPN
.
MIN_SIZE
=
0
_C
.
RPN
.
PROPOSAL_NMS_THRESH
=
0.7
_C
.
RPN
.
CROWD_OVERLAP_THRES
=
0.7
# boxes overlapping crowd will be ignored.
# RPN proposal selection -------------------------------
# for C4
_C
.
RPN
.
TRAIN_PRE_NMS_TOPK
=
12000
_C
.
RPN
.
TRAIN_POST_NMS_TOPK
=
2000
_C
.
RPN
.
CROWD_OVERLAP_THRES
=
0.7
# boxes overlapping crowd will be ignored.
_C
.
RPN
.
TEST_PRE_NMS_TOPK
=
6000
_C
.
RPN
.
TEST_POST_NMS_TOPK
=
1000
# if you encounter OOM in inference, set this to a smaller number
# for FPN, pre/post are (for now) the same
_C
.
RPN
.
TRAIN_FPN_NMS_TOPK
=
2000
_C
.
RPN
.
TEST_FPN_NMS_TOPK
=
1000
# fastrcnn training ---------------------
_C
.
FRCNN
.
BATCH_PER_IM
=
512
...
...
@@ -108,8 +124,6 @@ _C.FRCNN.FG_RATIO = 0.25 # fg ratio in a ROI batch
# FPN -------------------------
_C
.
FPN
.
ANCHOR_STRIDES
=
(
4
,
8
,
16
,
32
,
64
)
# strides for each FPN level. Must be the same length as ANCHOR_SIZES
_C
.
FPN
.
RESOLUTION_REQUIREMENT
=
32
# image size into the backbone has to be multiple of this number
_C
.
FPN
.
NUM_CHANNEL
=
256
# conv head and fc head are only used in FPN.
# For C4 models, the head is C5
...
...
@@ -117,16 +131,51 @@ _C.FPN.FRCNN_HEAD_FUNC = 'fastrcnn_2fc_head' # choices: fastrcnn_2fc_head, fast
_C
.
FPN
.
FRCNN_CONV_HEAD_DIM
=
256
_C
.
FPN
.
FRCNN_FC_HEAD_DIM
=
1024
_C
.
RPN
.
TRAIN_FPN_NMS_TOPK
=
2000
_C
.
RPN
.
TEST_FPN_NMS_TOPK
=
1000
# Mask-RCNN
_C
.
MRCNN
.
HEAD_DIM
=
256
# testing -----------------------
_C
.
RPN
.
TEST_PRE_NMS_TOPK
=
6000
_C
.
RPN
.
TEST_POST_NMS_TOPK
=
1000
# if you encounter OOM in inference, set this to a smaller number
_C
.
TEST
.
FRCNN_NMS_THRESH
=
0.5
_C
.
TEST
.
RESULT_SCORE_THRESH
=
0.05
_C
.
TEST
.
RESULT_SCORE_THRESH_VIS
=
0.3
# only visualize confident results
_C
.
TEST
.
RESULTS_PER_IM
=
100
def finalize_configs(is_training):
    """
    Validate the global config `_C` and fill in the entries that are
    derived from other entries.

    Args:
        is_training: truthy when the config is being finalized for a
            training run; falsy for inference. Selects which environment
            variable is set and whether the trainer / GPU-count checks run.

    Side effects:
        Mutates `_C` in place, sets one entry in `os.environ`, and logs the
        final config through `logger.info`.
    """
    data, rpn, fpn = _C.DATA, _C.RPN, _C.FPN

    # One extra class on top of the categories: the background.
    data.NUM_CLASS = data.NUM_CATEGORY + 1

    num_sizes = len(rpn.ANCHOR_SIZES)
    rpn.NUM_ANCHOR = num_sizes * len(rpn.ANCHOR_RATIOS)
    # One FPN stride is required per anchor size.
    assert len(fpn.ANCHOR_STRIDES) == len(rpn.ANCHOR_SIZES)

    # Image size into the backbone has to be a multiple of this number.
    # Index [3] because FPN is built with features r2,r3,r4,r5.
    fpn.RESOLUTION_REQUIREMENT = fpn.ANCHOR_STRIDES[3]

    if _C.MODE_FPN:
        # Round MAX_SIZE up to the next multiple of the resolution
        # requirement (the `* 1.` forces true division).
        size_mult = fpn.RESOLUTION_REQUIREMENT * 1.
        preproc = _C.PREPROC
        preproc.MAX_SIZE = np.ceil(preproc.MAX_SIZE / size_mult) * size_mult

    if not is_training:
        # autotune is too slow for inference
        os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0'
    else:
        os.environ['TF_AUTOTUNE_THRESHOLD'] = '1'
        assert _C.TRAINER in ('horovod', 'replicated'), _C.TRAINER
        use_horovod = _C.TRAINER == 'horovod'

        # Work out how many GPUs are available to this run.
        if use_horovod:
            import horovod.tensorflow as hvd
            ngpu = hvd.size()
        else:
            # The non-horovod trainer must not see this variable.
            assert 'OMPI_COMM_WORLD_SIZE' not in os.environ
            ngpu = get_num_gpu()
        # The GPU count must be a multiple of 8 or a divisor of 8.
        assert ngpu % 8 == 0 or 8 % ngpu == 0, ngpu

        train = _C.TRAIN
        if train.NUM_GPUS is None:
            train.NUM_GPUS = ngpu
        elif use_horovod:
            # Under horovod the configured count must match exactly.
            assert train.NUM_GPUS == ngpu
        else:
            # Otherwise using a subset of the detected GPUs is allowed.
            assert train.NUM_GPUS <= ngpu

    banner = "Config: ------------------------------------------\n"
    logger.info(banner + str(_C))
examples/FasterRCNN/data.py
View file @
bffcfc1b
...
...
@@ -54,10 +54,6 @@ def get_all_anchors(
# anchors at featuremap [0,0] are centered at fpcoor (8,8) (half of stride)
max_size
=
cfg
.
PREPROC
.
MAX_SIZE
if
cfg
.
MODE_FPN
:
# TODO setting this in config is perhaps better
size_mult
=
cfg
.
FPN
.
RESOLUTION_REQUIREMENT
*
1.
max_size
=
np
.
ceil
(
max_size
/
size_mult
)
*
size_mult
field_size
=
int
(
np
.
ceil
(
max_size
/
stride
))
shifts
=
np
.
arange
(
0
,
field_size
)
*
stride
shift_x
,
shift_y
=
np
.
meshgrid
(
shifts
,
shifts
)
...
...
examples/FasterRCNN/model.py
View file @
bffcfc1b
...
...
@@ -356,7 +356,7 @@ def fastrcnn_predictions(boxes, probs):
boxes: n#catx4 floatbox in float32
probs: nx#class
"""
assert
boxes
.
shape
[
1
]
==
cfg
.
DATA
.
NUM_C
LASS
-
1
assert
boxes
.
shape
[
1
]
==
cfg
.
DATA
.
NUM_C
ATEGORY
assert
probs
.
shape
[
1
]
==
cfg
.
DATA
.
NUM_CLASS
boxes
=
tf
.
transpose
(
boxes
,
[
1
,
0
,
2
])
# #catxnx4
probs
=
tf
.
transpose
(
probs
[:,
1
:],
[
1
,
0
])
# #catxn
...
...
@@ -404,11 +404,11 @@ def fastrcnn_predictions(boxes, probs):
@
layer_register
(
log_shape
=
True
)
def
maskrcnn_upXconv_head
(
feature
,
num_c
lass
,
num_convs
):
def
maskrcnn_upXconv_head
(
feature
,
num_c
ategory
,
num_convs
):
"""
Args:
feature (NxCx s x s): size is 7 in C4 models and 14 in FPN models.
num_c
lasses(int): num_category + 1
num_c
ategory(int):
num_convs (int): number of convolution layers
Returns:
...
...
@@ -422,7 +422,7 @@ def maskrcnn_upXconv_head(feature, num_class, num_convs):
for
k
in
range
(
num_convs
):
l
=
Conv2D
(
'fcn{}'
.
format
(
k
),
l
,
cfg
.
MRCNN
.
HEAD_DIM
,
3
,
activation
=
tf
.
nn
.
relu
)
l
=
Conv2DTranspose
(
'deconv'
,
l
,
cfg
.
MRCNN
.
HEAD_DIM
,
2
,
strides
=
2
,
activation
=
tf
.
nn
.
relu
)
l
=
Conv2D
(
'conv'
,
l
,
num_c
lass
-
1
,
1
)
l
=
Conv2D
(
'conv'
,
l
,
num_c
ategory
,
1
)
return
l
...
...
examples/FasterRCNN/train.py
View file @
bffcfc1b
...
...
@@ -25,8 +25,6 @@ from tensorpack.tfutils.scope_utils import under_name_scope
from
tensorpack.tfutils
import
optimizer
from
tensorpack.tfutils.common
import
get_tf_version_number
import
tensorpack.utils.viz
as
tpviz
from
tensorpack.utils.gpu
import
get_num_gpu
from
coco
import
COCODetection
from
basemodel
import
(
...
...
@@ -49,7 +47,7 @@ from viz import (
draw_predictions
,
draw_final_outputs
)
from
eval
import
(
eval_coco
,
detect_one_image
,
print_evaluation_scores
,
DetectionResult
)
from
config
import
config
as
cfg
from
config
import
finalize_configs
,
config
as
cfg
class
DetectionModel
(
ModelDesc
):
...
...
@@ -131,9 +129,9 @@ class DetectionModel(ModelDesc):
labels (m): each >= 1
"""
rcnn_box_logits
=
rcnn_box_logits
[:,
1
:,
:]
rcnn_box_logits
.
set_shape
([
None
,
cfg
.
DATA
.
NUM_C
LASS
-
1
,
None
])
rcnn_box_logits
.
set_shape
([
None
,
cfg
.
DATA
.
NUM_C
ATEGORY
,
None
])
label_probs
=
tf
.
nn
.
softmax
(
rcnn_label_logits
,
name
=
'fastrcnn_all_probs'
)
# #proposal x #Class
anchors
=
tf
.
tile
(
tf
.
expand_dims
(
rcnn_boxes
,
1
),
[
1
,
cfg
.
DATA
.
NUM_C
LASS
-
1
,
1
])
# #proposal x #Cat x 4
anchors
=
tf
.
tile
(
tf
.
expand_dims
(
rcnn_boxes
,
1
),
[
1
,
cfg
.
DATA
.
NUM_C
ATEGORY
,
1
])
# #proposal x #Cat x 4
decoded_boxes
=
decode_bbox_target
(
rcnn_box_logits
/
tf
.
constant
(
cfg
.
FRCNN
.
BBOX_REG_WEIGHTS
,
dtype
=
tf
.
float32
),
anchors
)
...
...
@@ -237,7 +235,7 @@ class ResNetC4Model(DetectionModel):
# In training, mask branch shares the same C5 feature.
fg_feature
=
tf
.
gather
(
feature_fastrcnn
,
fg_inds_wrt_sample
)
mask_logits
=
maskrcnn_upXconv_head
(
'maskrcnn'
,
fg_feature
,
cfg
.
DATA
.
NUM_C
LASS
,
num_convs
=
0
)
# #fg x #cat x 14x14
'maskrcnn'
,
fg_feature
,
cfg
.
DATA
.
NUM_C
ATEGORY
,
num_convs
=
0
)
# #fg x #cat x 14x14
target_masks_for_fg
=
crop_and_resize
(
tf
.
expand_dims
(
gt_masks
,
1
),
...
...
@@ -269,7 +267,7 @@ class ResNetC4Model(DetectionModel):
roi_resized
=
roi_align
(
featuremap
,
final_boxes
*
(
1.0
/
cfg
.
RPN
.
ANCHOR_STRIDE
),
14
)
feature_maskrcnn
=
resnet_conv5
(
roi_resized
,
cfg
.
BACKBONE
.
RESNET_NUM_BLOCK
[
-
1
])
mask_logits
=
maskrcnn_upXconv_head
(
'maskrcnn'
,
feature_maskrcnn
,
cfg
.
DATA
.
NUM_C
LASS
,
0
)
# #result x #cat x 14x14
'maskrcnn'
,
feature_maskrcnn
,
cfg
.
DATA
.
NUM_C
ATEGORY
,
0
)
# #result x #cat x 14x14
indices
=
tf
.
stack
([
tf
.
range
(
tf
.
size
(
final_labels
)),
tf
.
to_int32
(
final_labels
)
-
1
],
axis
=
1
)
final_mask_logits
=
tf
.
gather_nd
(
mask_logits
,
indices
)
# #resultx14x14
tf
.
sigmoid
(
final_mask_logits
,
name
=
'final_masks'
)
...
...
@@ -393,7 +391,7 @@ class ResNetFPNModel(DetectionModel):
roi_feature_maskrcnn
=
multilevel_roi_align
(
p23456
[:
4
],
fg_sampled_boxes
,
14
)
mask_logits
=
maskrcnn_upXconv_head
(
'maskrcnn'
,
roi_feature_maskrcnn
,
cfg
.
DATA
.
NUM_C
LASS
,
4
)
# #fg x #cat x 28 x 28
'maskrcnn'
,
roi_feature_maskrcnn
,
cfg
.
DATA
.
NUM_C
ATEGORY
,
4
)
# #fg x #cat x 28 x 28
target_masks_for_fg
=
crop_and_resize
(
tf
.
expand_dims
(
gt_masks
,
1
),
...
...
@@ -422,7 +420,7 @@ class ResNetFPNModel(DetectionModel):
# Cascade inference needs roi transform with refined boxes.
roi_feature_maskrcnn
=
multilevel_roi_align
(
p23456
[:
4
],
final_boxes
,
14
)
mask_logits
=
maskrcnn_upXconv_head
(
'maskrcnn'
,
roi_feature_maskrcnn
,
cfg
.
DATA
.
NUM_C
LASS
,
4
)
# #fg x #cat x 28 x 28
'maskrcnn'
,
roi_feature_maskrcnn
,
cfg
.
DATA
.
NUM_C
ATEGORY
,
4
)
# #fg x #cat x 28 x 28
indices
=
tf
.
stack
([
tf
.
range
(
tf
.
size
(
final_labels
)),
tf
.
to_int32
(
final_labels
)
-
1
],
axis
=
1
)
final_mask_logits
=
tf
.
gather_nd
(
mask_logits
,
indices
)
# #resultx28x28
tf
.
sigmoid
(
final_mask_logits
,
name
=
'final_masks'
)
...
...
@@ -532,25 +530,6 @@ class EvalCallback(Callback):
self
.
_eval
()
def init_config():
    """
    Initialize config for training.

    Determines the number of GPUs for the configured trainer, checks it
    against `cfg.TRAIN.NUM_GPUS` (filling that entry in when it is None),
    and logs the resulting config.
    """
    use_horovod = cfg.TRAINER == 'horovod'
    ngpu = hvd.size() if use_horovod else get_num_gpu()
    # The GPU count must be a multiple of 8 or a divisor of 8.
    assert ngpu % 8 == 0 or 8 % ngpu == 0, ngpu

    train = cfg.TRAIN
    if train.NUM_GPUS is None:
        train.NUM_GPUS = ngpu
    elif use_horovod:
        # Under horovod the configured count must match exactly.
        assert train.NUM_GPUS == ngpu
    else:
        # Otherwise using a subset of the detected GPUs is allowed.
        assert train.NUM_GPUS <= ngpu

    banner = "Config: ------------------------------------------\n"
    logger.info(banner + str(cfg))
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--load'
,
help
=
'load a model for evaluation or training'
)
...
...
@@ -573,11 +552,8 @@ if __name__ == '__main__':
MODEL
=
ResNetFPNModel
()
if
cfg
.
MODE_FPN
else
ResNetC4Model
()
if
args
.
visualize
or
args
.
evaluate
or
args
.
predict
:
# autotune is too slow for inference
os
.
environ
[
'TF_CUDNN_USE_AUTOTUNE'
]
=
'0'
assert
args
.
load
logger
.
info
(
"Config: ------------------------------------------
\n
"
+
str
(
cfg
)
)
finalize_configs
(
is_training
=
False
)
if
args
.
predict
or
args
.
visualize
:
cfg
.
TEST
.
RESULT_SCORE_THRESH
=
cfg
.
TEST
.
RESULT_SCORE_THRESH_VIS
...
...
@@ -598,18 +574,15 @@ if __name__ == '__main__':
COCODetection
(
cfg
.
DATA
.
BASEDIR
,
'val2014'
)
# Only to load the class names into caches
predict
(
pred
,
args
.
predict
)
else
:
os
.
environ
[
'TF_AUTOTUNE_THRESHOLD'
]
=
'1'
is_horovod
=
cfg
.
TRAINER
==
'horovod'
if
is_horovod
:
hvd
.
init
()
logger
.
info
(
"Horovod Rank={}, Size={}"
.
format
(
hvd
.
rank
(),
hvd
.
size
()))
else
:
assert
'OMPI_COMM_WORLD_SIZE'
not
in
os
.
environ
if
not
is_horovod
or
hvd
.
rank
()
==
0
:
logger
.
set_logger_dir
(
args
.
logdir
,
'd'
)
init_config
(
)
finalize_configs
(
is_training
=
True
)
factor
=
8.
/
cfg
.
TRAIN
.
NUM_GPUS
stepnum
=
cfg
.
TRAIN
.
STEPS_PER_EPOCH
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment