Commit 9b1b5f29 authored May 01, 2019 by Yuxin Wu
[MaskRCNN] add DatasetSplit and DatasetRegistry for generic dataset handling
parent 8908e6d4
Showing 8 changed files with 296 additions and 275 deletions
examples/FasterRCNN/NOTES.md +20 -14
examples/FasterRCNN/coco.py +219 -0
examples/FasterRCNN/config.py +3 -4
examples/FasterRCNN/data.py +10 -8
examples/FasterRCNN/dataset.py +34 -240
examples/FasterRCNN/eval.py +5 -5
examples/FasterRCNN/model_frcnn.py +1 -1
examples/FasterRCNN/train.py +4 -3
examples/FasterRCNN/NOTES.md
 ### File Structure
 This is a minimal implementation that simply contains these files:
-+ dataset.py: load and evaluate COCO dataset
++ dataset.py: the dataset interface
++ coco.py: load COCO data to the dataset interface
 + data.py: prepare data for training & inference
 + common.py: common data preparation utilities
 + backbone.py: implement backbones
-+ model_box.py: implement box-related symbolic functions
 + generalized_rcnn.py: implement variants of generalized R-CNN architecture
 + model_{fpn,rpn,frcnn,mrcnn,cascade}.py: implement FPN,RPN,Fast/Mask/Cascade R-CNN models.
++ model_box.py: implement box-related symbolic functions
 + train.py: main entry script
 + utils/: third-party helper functions
 + eval.py: evaluation utilities
...
@@ -17,20 +18,25 @@ This is a minimal implementation that simply contains these files:
 Data:
-1. It's easy to train on your own data by changing `dataset.py`.
+1. It's easy to train on your own data, by calling `DatasetRegistry.register(name, lambda: YourDatasetSplit())`,
+   and modifying `cfg.DATA.*` accordingly.
+   `YourDatasetSplit` can be:
+   + `COCODetection`, if your data is already in COCO format. In this case, you need to
+     modify `COCODetection` to change the class names and the id mapping.
+   + Your own class, if your data is not in COCO format.
+     You need to write a subclass of `DatasetSplit`, similar to `COCODetection`.
+     In this class you'll implement the logic to load your dataset and evaluate predictions.
+     The documentation is in the docstring of `DatasetSplit`.
-   + If your data is in COCO format, modify `COCODetection`
-     to change the class names and the id mapping.
-   + If your data is not in COCO format, ignore `COCODetection` completely and
-     rewrite all the methods of `DetectionDataset` following its documents.
-     You'll implement the logic to load your dataset and evaluate predictions.
-   + If you load a COCO-trained model on a different dataset, you'll see error messages
-     complaining about unmatched number of categories for certain weights in the checkpoint.
-     You can either remove those weights in checkpoint, or rename them in the model.
-     See [tensorpack tutorial](https://tensorpack.readthedocs.io/tutorial/save-load.html) for more details.
+1. If you load a COCO-trained model on a different dataset, you may see error messages
+   complaining about unmatched number of categories for certain weights in the checkpoint.
+   You can either remove those weights in checkpoint, or rename them in the model.
+   See [tensorpack tutorial](https://tensorpack.readthedocs.io/tutorial/save-load.html) for more details.
-2. You can easily add more augmentations such as rotation, but be careful how a box should be
+1. You can easily add more augmentations such as rotation, but be careful how a box should be
    augmented. The code now will always use the minimal axis-aligned bounding box of the 4 corners,
    which is probably not the optimal way.
    A TODO is to generate bounding box from segmentation, so more augmentations can be naturally supported.
...
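The registration step described in the notes above can be sketched as follows. This is a minimal, self-contained illustration, not the example's actual modules: `DatasetSplit` and `DatasetRegistry` are re-declared here as small stand-ins with the method names used in this commit, and `ToyDetection`, its file paths, and the `toy_*` split names are hypothetical. The real code stores boxes and classes as NumPy arrays; plain lists are used here to keep the sketch dependency-free.

```python
class DatasetSplit:
    """Stand-in for the interface in dataset.py (assumed shape, not the real module)."""

    def training_roidbs(self):
        raise NotImplementedError()

    def inference_roidbs(self):
        raise NotImplementedError()

    def eval_inference_results(self, results, output=None):
        raise NotImplementedError()


class DatasetRegistry:
    """Stand-in registry: maps a split name to a function that builds the split."""
    _registry = {}

    @staticmethod
    def register(name, func):
        assert name not in DatasetRegistry._registry, name
        DatasetRegistry._registry[name] = func

    @staticmethod
    def get(name):
        assert name in DatasetRegistry._registry, name
        return DatasetRegistry._registry[name]()


class ToyDetection(DatasetSplit):
    """Hypothetical dataset with one image and one box per split."""

    def __init__(self, split):
        self.split = split

    def training_roidbs(self):
        return [{
            "file_name": "/hypothetical/{}/0001.jpg".format(self.split),
            "boxes": [[10.0, 20.0, 110.0, 220.0]],  # x1, y1, x2, y2
            "class": [1],                            # ids in [1, #category]
            "is_crowd": [0],
        }]

    def inference_roidbs(self):
        return [{
            "file_name": "/hypothetical/{}/0001.jpg".format(self.split),
            "image_id": "0001",
        }]

    def eval_inference_results(self, results, output=None):
        # A real implementation would compute detection metrics here.
        return {"num_results": len(results)}


for split in ["train", "val"]:
    # Bind `split` as a default argument so each factory keeps its own value.
    DatasetRegistry.register("toy_" + split, lambda s=split: ToyDetection(s))

train_roidbs = DatasetRegistry.get("toy_train").training_roidbs()
print(len(train_roidbs), train_roidbs[0]["file_name"])
```

With the real modules, the registered names would then be listed in `cfg.DATA.TRAIN` and `cfg.DATA.VAL`, as the notes describe.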
examples/FasterRCNN/coco.py (new file, 0 → 100644)
# -*- coding: utf-8 -*-

import numpy as np
import os
import tqdm
import json

from tensorpack.utils import logger
from tensorpack.utils.timer import timed_operation

from config import config as cfg
from dataset import DatasetSplit, DatasetRegistry

__all__ = ['register_coco']


class COCODetection(DatasetSplit):
    # handle the weird (but standard) split of train and val
    _INSTANCE_TO_BASEDIR = {
        'valminusminival2014': 'val2014',
        'minival2014': 'val2014',
    }

    COCO_id_to_category_id = {
        1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10,
        11: 11, 13: 12, 14: 13, 15: 14, 16: 15, 17: 16, 18: 17, 19: 18, 20: 19, 21: 20,
        22: 21, 23: 22, 24: 23, 25: 24, 27: 25, 28: 26, 31: 27, 32: 28, 33: 29, 34: 30,
        35: 31, 36: 32, 37: 33, 38: 34, 39: 35, 40: 36, 41: 37, 42: 38, 43: 39, 44: 40,
        46: 41, 47: 42, 48: 43, 49: 44, 50: 45, 51: 46, 52: 47, 53: 48, 54: 49, 55: 50,
        56: 51, 57: 52, 58: 53, 59: 54, 60: 55, 61: 56, 62: 57, 63: 58, 64: 59, 65: 60,
        67: 61, 70: 62, 72: 63, 73: 64, 74: 65, 75: 66, 76: 67, 77: 68, 78: 69, 79: 70,
        80: 71, 81: 72, 82: 73, 84: 74, 85: 75, 86: 76, 87: 77, 88: 78, 89: 79, 90: 80}  # noqa
    """
    Mapping from the incontinuous COCO category id to an id in [1, #category]
    For your own dataset, this should usually be an identity mapping.
    """

    # 80 names for COCO
    class_names = [
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
        "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
        "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra",
        "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
        "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
        "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
        "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
        "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
        "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
        "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
        "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
        "hair drier", "toothbrush"]  # noqa
    cfg.DATA.CLASS_NAMES = ["BG"] + class_names

    def __init__(self, basedir, name):
        """
        Args:
            basedir (str): root to the dataset
            name (str): the name of the split, e.g. "train2017"
        """
        basedir = os.path.expanduser(basedir)
        self.name = name
        self._imgdir = os.path.realpath(os.path.join(
            basedir, self._INSTANCE_TO_BASEDIR.get(name, name)))
        assert os.path.isdir(self._imgdir), self._imgdir
        annotation_file = os.path.join(
            basedir, 'annotations/instances_{}.json'.format(name))
        assert os.path.isfile(annotation_file), annotation_file

        from pycocotools.coco import COCO
        self.coco = COCO(annotation_file)
        logger.info("Instances loaded from {}.".format(annotation_file))

    # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
    def print_coco_metrics(self, json_file):
        """
        Args:
            json_file (str): path to the results json file in coco format
        Returns:
            dict: the evaluation metrics
        """
        from pycocotools.cocoeval import COCOeval
        ret = {}
        cocoDt = self.coco.loadRes(json_file)
        cocoEval = COCOeval(self.coco, cocoDt, 'bbox')
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()
        fields = ['IoU=0.5:0.95', 'IoU=0.5', 'IoU=0.75', 'small', 'medium', 'large']
        for k in range(6):
            ret['mAP(bbox)/' + fields[k]] = cocoEval.stats[k]

        json_obj = json.load(open(json_file))
        if len(json_obj) > 0 and 'segmentation' in json_obj[0]:
            cocoEval = COCOeval(self.coco, cocoDt, 'segm')
            cocoEval.evaluate()
            cocoEval.accumulate()
            cocoEval.summarize()
            for k in range(6):
                ret['mAP(segm)/' + fields[k]] = cocoEval.stats[k]
        return ret

    def load(self, add_gt=True, add_mask=False):
        """
        Args:
            add_gt: whether to add ground truth bounding box annotations to the dicts
            add_mask: whether to also add ground truth mask
        Returns:
            a list of dict, each has keys including:
                'image_id', 'file_name',
                and (if add_gt is True) 'boxes', 'class', 'is_crowd', and optionally
                'segmentation'.
        """
        if add_mask:
            assert add_gt
        with timed_operation('Load Groundtruth Boxes for {}'.format(self.name)):
            img_ids = self.coco.getImgIds()
            img_ids.sort()
            # list of dict, each has keys: height,width,id,file_name
            imgs = self.coco.loadImgs(img_ids)

            for img in tqdm.tqdm(imgs):
                img['image_id'] = img.pop('id')
                self._use_absolute_file_name(img)
                if add_gt:
                    self._add_detection_gt(img, add_mask)
            return imgs

    def _use_absolute_file_name(self, img):
        """
        Change relative filename to absolute file name.
        """
        img['file_name'] = os.path.join(
            self._imgdir, img['file_name'])
        assert os.path.isfile(img['file_name']), img['file_name']

    def _add_detection_gt(self, img, add_mask):
        """
        Add 'boxes', 'class', 'is_crowd' of this image to the dict, used by detection.
        If add_mask is True, also add 'segmentation' in coco poly format.
        """
        # ann_ids = self.coco.getAnnIds(imgIds=img['image_id'])
        # objs = self.coco.loadAnns(ann_ids)
        objs = self.coco.imgToAnns[img['image_id']]  # equivalent but faster than the above two lines

        # clean-up boxes
        valid_objs = []
        width = img.pop('width')
        height = img.pop('height')
        for objid, obj in enumerate(objs):
            if obj.get('ignore', 0) == 1:
                continue
            x1, y1, w, h = obj['bbox']
            # bbox is originally in float
            # x1/y1 means upper-left corner and w/h means true w/h. This can be verified by segmentation pixels.
            # But we do make an assumption here that (0.0, 0.0) is upper-left corner of the first pixel
            x1 = np.clip(float(x1), 0, width)
            y1 = np.clip(float(y1), 0, height)
            w = np.clip(float(x1 + w), 0, width) - x1
            h = np.clip(float(y1 + h), 0, height) - y1
            # Require non-zero seg area and more than 1x1 box size
            if obj['area'] > 1 and w > 0 and h > 0 and w * h >= 4:
                obj['bbox'] = [x1, y1, x1 + w, y1 + h]
                valid_objs.append(obj)

                if add_mask:
                    segs = obj['segmentation']
                    if not isinstance(segs, list):
                        assert obj['iscrowd'] == 1
                        obj['segmentation'] = None
                    else:
                        valid_segs = [np.asarray(p).reshape(-1, 2).astype('float32') for p in segs if len(p) >= 6]
                        if len(valid_segs) == 0:
                            logger.error("Object {} in image {} has no valid polygons!".format(objid, img['file_name']))
                        elif len(valid_segs) < len(segs):
                            logger.warn("Object {} in image {} has invalid polygons!".format(objid, img['file_name']))

                        obj['segmentation'] = valid_segs

        # all geometrically-valid boxes are returned
        boxes = np.asarray([obj['bbox'] for obj in valid_objs], dtype='float32')  # (n, 4)
        cls = np.asarray([
            self.COCO_id_to_category_id[obj['category_id']]
            for obj in valid_objs], dtype='int32')  # (n,)
        is_crowd = np.asarray([obj['iscrowd'] for obj in valid_objs], dtype='int8')

        # add the keys
        img['boxes'] = boxes  # nx4
        img['class'] = cls  # n, always >0
        img['is_crowd'] = is_crowd  # n,
        if add_mask:
            # also required to be float32
            img['segmentation'] = [
                obj['segmentation'] for obj in valid_objs]

    def training_roidbs(self):
        return self.load(add_gt=True, add_mask=cfg.MODE_MASK)

    def inference_roidbs(self):
        return self.load(add_gt=False)

    def eval_inference_results(self, results, output):
        continuous_id_to_COCO_id = {v: k for k, v in self.COCO_id_to_category_id.items()}
        for res in results:
            # convert to COCO's incontinuous category id
            res['category_id'] = continuous_id_to_COCO_id[res['category_id']]
            # COCO expects results in xywh format
            box = res['bbox']
            box[2] -= box[0]
            box[3] -= box[1]
            res['bbox'] = [round(float(x), 3) for x in box]

        assert output is not None, "COCO evaluation requires an output file!"
        with open(output, 'w') as f:
            json.dump(results, f)
        if len(results):
            # sometimes may crash if the results are empty?
            return self.print_coco_metrics(output)
        else:
            return {}


def register_coco(basedir):
    """
    Add COCO datasets like "coco_train201x" to the registry,
    so you can refer to them with names in `cfg.DATA.TRAIN/VAL`.
    """
    for split in ["train2017", "val2017", "train2014", "val2014",
                  "valminusminival2014", "minival2014"]:
        DatasetRegistry.register("coco_" + split, lambda x=split: COCODetection(basedir, x))


if __name__ == '__main__':
    basedir = '~/data/coco'
    c = COCODetection(basedir, 'train2014')
    roidb = c.load(add_gt=True, add_mask=True)
    print("#Images:", len(roidb))
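To make the two conversions in `coco.py` easier to follow, here is a small sketch of the round trip: COCO annotations use `[x, y, w, h]` boxes and sparse category ids, the training code uses `[x1, y1, x2, y2]` boxes and contiguous ids in `[1, #category]`, and results are converted back before evaluation. The helper names and the four-id subset are hypothetical. Note one deliberate difference from the listing: this sketch clips the far corner computed from the original origin, while the listing adds the width to the already clipped `x1`.

```python
import numpy as np


def xywh_to_clipped_xyxy(bbox, width, height):
    """Convert a COCO [x, y, w, h] box to [x1, y1, x2, y2], clipped to the image."""
    x, y, w, h = (float(v) for v in bbox)
    x1 = float(np.clip(x, 0, width))
    y1 = float(np.clip(y, 0, height))
    x2 = float(np.clip(x + w, 0, width))
    y2 = float(np.clip(y + h, 0, height))
    return [x1, y1, x2, y2]


def xyxy_to_xywh(box, ndigits=3):
    """Convert [x1, y1, x2, y2] back to COCO's [x, y, w, h], rounded for JSON output."""
    x1, y1, x2, y2 = (float(v) for v in box)
    return [round(v, ndigits) for v in (x1, y1, x2 - x1, y2 - y1)]


# A hypothetical subset of sparse dataset ids (COCO skips some numbers, e.g. 12).
sparse_ids = [1, 2, 13, 14]
coco_to_contiguous = {cid: i + 1 for i, cid in enumerate(sorted(sparse_ids))}
contiguous_to_coco = {v: k for k, v in coco_to_contiguous.items()}

box = xywh_to_clipped_xyxy([-5, 10, 20, 30], width=100, height=100)
print(box)                     # [0.0, 10.0, 15.0, 40.0]
print(xyxy_to_xywh(box))       # [0.0, 10.0, 15.0, 30.0]
print(coco_to_contiguous[13])  # 3
print(contiguous_to_coco[3])   # 13
```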
examples/FasterRCNN/config.py
...
@@ -85,11 +85,11 @@ _C.MODE_FPN = False
 # dataset -----------------------
 _C.DATA.BASEDIR = '/path/to/your/DATA/DIR'
 # All TRAIN dataset will be concatenated for training.
-_C.DATA.TRAIN = ('train2014', 'valminusminival2014')   # i.e. trainval35k, AKA train2017
+_C.DATA.TRAIN = ('coco_train2014', 'coco_valminusminival2014')   # i.e. trainval35k, AKA train2017
 # Each VAL dataset will be evaluated separately (instead of concatenated)
-_C.DATA.VAL = ('minival2014', )  # AKA val2017
+_C.DATA.VAL = ('coco_minival2014', )  # AKA val2017
 # This two config will be populated later by the dataset loader:
-_C.DATA.NUM_CATEGORY = 0  # without the background class (e.g., 80 for COCO)
+_C.DATA.NUM_CATEGORY = 80  # without the background class (e.g., 80 for COCO)
 _C.DATA.CLASS_NAMES = []  # NUM_CLASS (NUM_CATEGORY+1) strings, the first is "BG".
 # whether the coordinates in the annotations are absolute pixel values, or a relative value in [0, 1]
 _C.DATA.ABSOLUTE_COORD = True
...
@@ -216,7 +216,6 @@ def finalize_configs(is_training):
     Run some sanity checks, and populate some configs from others
     """
     _C.freeze(False)  # populate new keys now
-    _C.DATA.BASEDIR = os.path.expanduser(_C.DATA.BASEDIR)
     if isinstance(_C.DATA.VAL, six.string_types):  # support single string (the typical case) as well
         _C.DATA.VAL = (_C.DATA.VAL, )
...
examples/FasterRCNN/data.py
...
@@ -4,6 +4,7 @@
 import copy
 import numpy as np
 import cv2
+import itertools
 from tabulate import tabulate
 from termcolor import colored
...
@@ -16,7 +17,7 @@ from common import (
     CustomResize, DataFromListOfDict, box_to_point8,
     filter_boxes_inside_shape, point8_to_box, segmentation_to_mask, np_iou)
 from config import config as cfg
-from dataset import DetectionDataset
+from dataset import DatasetRegistry
 from utils.generate_anchors import generate_anchors
 from utils.np_box_ops import area as np_area, ioa as np_ioa
...
@@ -30,20 +31,20 @@ class MalformedData(BaseException):
 def print_class_histogram(roidbs):
     """
     Args:
-        roidbs (list[dict]): the same format as the output of `load_training_roidbs`.
+        roidbs (list[dict]): the same format as the output of `training_roidbs`.
     """
-    dataset = DetectionDataset()
-    hist_bins = np.arange(dataset.num_classes + 1)
+    # labels are in [1, NUM_CATEGORY], hence +2 for bins
+    hist_bins = np.arange(cfg.DATA.NUM_CATEGORY + 2)
     # Histogram of ground-truth objects
-    gt_hist = np.zeros((dataset.num_classes,), dtype=np.int)
+    gt_hist = np.zeros((cfg.DATA.NUM_CATEGORY + 1,), dtype=np.int)
     for entry in roidbs:
         # filter crowd?
         gt_inds = np.where(
             (entry['class'] > 0) & (entry['is_crowd'] == 0))[0]
         gt_classes = entry['class'][gt_inds]
         gt_hist += np.histogram(gt_classes, bins=hist_bins)[0]
-    data = [[dataset.class_names[i], v] for i, v in enumerate(gt_hist)]
+    data = [[cfg.DATA.CLASS_NAMES[i], v] for i, v in enumerate(gt_hist)]
     data.append(['total', sum(x[1] for x in data)])
     # the first line is BG
     table = tabulate(data[1:], headers=['class', '#box'], tablefmt='pipe')
...
@@ -284,7 +285,7 @@ def get_train_dataflow():
         If MODE_MASK, gt_masks: (N, h, w)
     """
-    roidbs = DetectionDataset().load_training_roidbs(cfg.DATA.TRAIN)
+    roidbs = list(itertools.chain.from_iterable(DatasetRegistry.get(x).training_roidbs() for x in cfg.DATA.TRAIN))
     print_class_histogram(roidbs)
     # Valid training images should have at least one fg box.
...
@@ -387,7 +388,8 @@ def get_eval_dataflow(name, shard=0, num_shards=1):
         name (str): name of the dataset to evaluate
         shard, num_shards: to get subset of evaluation data
     """
-    roidbs = DetectionDataset().load_inference_roidbs(name)
+    roidbs = DatasetRegistry.get(name).inference_roidbs()
+    logger.info("Found {} images for inference.".format(len(roidbs)))
     num_imgs = len(roidbs)
     img_per_shard = num_imgs // num_shards
...
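The comment "labels are in [1, NUM_CATEGORY], hence +2 for bins" in `print_class_histogram` is easier to see with small numbers. The sketch below assumes a hypothetical three-category dataset and two hand-written roidb entries: `NUM_CATEGORY + 2` bin edges give `NUM_CATEGORY + 1` bins, one for the background id 0 and one per category. It uses `np.int64` because the `np.int` alias seen in the listing was removed in NumPy 1.24.

```python
import numpy as np

NUM_CATEGORY = 3  # hypothetical; the real value comes from cfg.DATA.NUM_CATEGORY

# Edges 0, 1, 2, 3, 4 define the bins [0,1), [1,2), [2,3), [3,4].
hist_bins = np.arange(NUM_CATEGORY + 2)
gt_hist = np.zeros((NUM_CATEGORY + 1,), dtype=np.int64)

roidbs = [
    {"class": np.array([1, 3, 3], dtype=np.int32),
     "is_crowd": np.array([0, 0, 1], dtype=np.int8)},
    {"class": np.array([2, 3], dtype=np.int32),
     "is_crowd": np.array([0, 0], dtype=np.int8)},
]

for entry in roidbs:
    # Keep foreground, non-crowd boxes only, as the listing does.
    keep = np.where((entry["class"] > 0) & (entry["is_crowd"] == 0))[0]
    gt_hist += np.histogram(entry["class"][keep], bins=hist_bins)[0]

print(gt_hist.tolist())  # [0, 1, 1, 2]; index 0 is the background slot
```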
examples/FasterRCNN/dataset.py
 # -*- coding: utf-8 -*-
-# File: coco.py
+__all__ = ['DatasetRegistry', 'DatasetSplit']
-import numpy as np
-import os
-import tqdm
-import json
-from tensorpack.utils import logger
-from tensorpack.utils.timer import timed_operation
-from config import config as cfg
-__all__ = ['COCODetection', 'DetectionDataset']
-class COCODetection:
-    # handle the weird (but standard) split of train and val
-    _INSTANCE_TO_BASEDIR = {
-        'valminusminival2014': 'val2014',
-        'minival2014': 'val2014',
-    }
-    COCO_id_to_category_id = {
-        1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10,
-        11: 11, 13: 12, 14: 13, 15: 14, 16: 15, 17: 16, 18: 17, 19: 18, 20: 19, 21: 20,
-        22: 21, 23: 22, 24: 23, 25: 24, 27: 25, 28: 26, 31: 27, 32: 28, 33: 29, 34: 30,
-        35: 31, 36: 32, 37: 33, 38: 34, 39: 35, 40: 36, 41: 37, 42: 38, 43: 39, 44: 40,
-        46: 41, 47: 42, 48: 43, 49: 44, 50: 45, 51: 46, 52: 47, 53: 48, 54: 49, 55: 50,
-        56: 51, 57: 52, 58: 53, 59: 54, 60: 55, 61: 56, 62: 57, 63: 58, 64: 59, 65: 60,
-        67: 61, 70: 62, 72: 63, 73: 64, 74: 65, 75: 66, 76: 67, 77: 68, 78: 69, 79: 70,
-        80: 71, 81: 72, 82: 73, 84: 74, 85: 75, 86: 76, 87: 77, 88: 78, 89: 79, 90: 80}  # noqa
-    """
-    Mapping from the incontinuous COCO category id to an id in [1, #category]
-    For your own dataset, this should usually be an identity mapping.
+class DatasetSplit():
     """
+    A class to load datasets, evaluate results for a dataset split (e.g., "coco_train_2017")
-    # 80 names for COCO
-    class_names = [
-        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
-        "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
-        "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra",
-        "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
-        "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
-        "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
-        "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
-        "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
-        "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
-        "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
-        "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
-        "hair drier", "toothbrush"]  # noqa
-    def __init__(self, basedir, name):
-        basedir = os.path.expanduser(basedir)
-        self.name = name
-        self._imgdir = os.path.realpath(os.path.join(
-            basedir, self._INSTANCE_TO_BASEDIR.get(name, name)))
-        assert os.path.isdir(self._imgdir), self._imgdir
-        annotation_file = os.path.join(
-            basedir, 'annotations/instances_{}.json'.format(name))
-        assert os.path.isfile(annotation_file), annotation_file
-        from pycocotools.coco import COCO
-        self.coco = COCO(annotation_file)
-        logger.info("Instances loaded from {}.".format(annotation_file))
-    # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
-    def print_coco_metrics(self, json_file):
-        """
-        Args:
-            json_file (str): path to the results json file in coco format
-        Returns:
-            dict: the evaluation metrics
-        """
-        from pycocotools.cocoeval import COCOeval
-        ret = {}
-        cocoDt = self.coco.loadRes(json_file)
-        cocoEval = COCOeval(self.coco, cocoDt, 'bbox')
-        cocoEval.evaluate()
-        cocoEval.accumulate()
-        cocoEval.summarize()
-        fields = ['IoU=0.5:0.95', 'IoU=0.5', 'IoU=0.75', 'small', 'medium', 'large']
-        for k in range(6):
-            ret['mAP(bbox)/' + fields[k]] = cocoEval.stats[k]
-        json_obj = json.load(open(json_file))
-        if len(json_obj) > 0 and 'segmentation' in json_obj[0]:
-            cocoEval = COCOeval(self.coco, cocoDt, 'segm')
-            cocoEval.evaluate()
-            cocoEval.accumulate()
-            cocoEval.summarize()
-            for k in range(6):
-                ret['mAP(segm)/' + fields[k]] = cocoEval.stats[k]
-        return ret
-    def load(self, add_gt=True, add_mask=False):
-        """
-        Args:
-            add_gt: whether to add ground truth bounding box annotations to the dicts
-            add_mask: whether to also add ground truth mask
-        Returns:
-            a list of dict, each has keys including:
-                'image_id', 'file_name',
-                and (if add_gt is True) 'boxes', 'class', 'is_crowd', and optionally
-                'segmentation'.
-        """
-        if add_mask:
-            assert add_gt
-        with timed_operation('Load Groundtruth Boxes for {}'.format(self.name)):
-            img_ids = self.coco.getImgIds()
-            img_ids.sort()
-            # list of dict, each has keys: height,width,id,file_name
-            imgs = self.coco.loadImgs(img_ids)
-            for img in tqdm.tqdm(imgs):
-                img['image_id'] = img.pop('id')
-                self._use_absolute_file_name(img)
-                if add_gt:
-                    self._add_detection_gt(img, add_mask)
-            return imgs
-    def _use_absolute_file_name(self, img):
-        """
-        Change relative filename to absolute file name.
-        """
-        img['file_name'] = os.path.join(
-            self._imgdir, img['file_name'])
-        assert os.path.isfile(img['file_name']), img['file_name']
-    def _add_detection_gt(self, img, add_mask):
-        """
-        Add 'boxes', 'class', 'is_crowd' of this image to the dict, used by detection.
-        If add_mask is True, also add 'segmentation' in coco poly format.
-        """
-        # ann_ids = self.coco.getAnnIds(imgIds=img['image_id'])
-        # objs = self.coco.loadAnns(ann_ids)
-        objs = self.coco.imgToAnns[img['image_id']]  # equivalent but faster than the above two lines
-        # clean-up boxes
-        valid_objs = []
-        width = img.pop('width')
-        height = img.pop('height')
-        for objid, obj in enumerate(objs):
-            if obj.get('ignore', 0) == 1:
-                continue
-            x1, y1, w, h = obj['bbox']
-            # bbox is originally in float
-            # x1/y1 means upper-left corner and w/h means true w/h. This can be verified by segmentation pixels.
-            # But we do make an assumption here that (0.0, 0.0) is upper-left corner of the first pixel
-            x1 = np.clip(float(x1), 0, width)
-            y1 = np.clip(float(y1), 0, height)
-            w = np.clip(float(x1 + w), 0, width) - x1
-            h = np.clip(float(y1 + h), 0, height) - y1
-            # Require non-zero seg area and more than 1x1 box size
-            if obj['area'] > 1 and w > 0 and h > 0 and w * h >= 4:
-                obj['bbox'] = [x1, y1, x1 + w, y1 + h]
-                valid_objs.append(obj)
-                if add_mask:
-                    segs = obj['segmentation']
-                    if not isinstance(segs, list):
-                        assert obj['iscrowd'] == 1
-                        obj['segmentation'] = None
-                    else:
-                        valid_segs = [np.asarray(p).reshape(-1, 2).astype('float32') for p in segs if len(p) >= 6]
-                        if len(valid_segs) == 0:
-                            logger.error("Object {} in image {} has no valid polygons!".format(objid, img['file_name']))
-                        elif len(valid_segs) < len(segs):
-                            logger.warn("Object {} in image {} has invalid polygons!".format(objid, img['file_name']))
-                        obj['segmentation'] = valid_segs
-        # all geometrically-valid boxes are returned
-        boxes = np.asarray([obj['bbox'] for obj in valid_objs], dtype='float32')  # (n, 4)
-        cls = np.asarray([
-            self.COCO_id_to_category_id[obj['category_id']]
-            for obj in valid_objs], dtype='int32')  # (n,)
-        is_crowd = np.asarray([obj['iscrowd'] for obj in valid_objs], dtype='int8')
-        # add the keys
-        img['boxes'] = boxes  # nx4
-        img['class'] = cls  # n, always >0
-        img['is_crowd'] = is_crowd  # n,
-        if add_mask:
-            # also required to be float32
-            img['segmentation'] = [
-                obj['segmentation'] for obj in valid_objs]
-    @staticmethod
-    def load_many(basedir, names, add_gt=True, add_mask=False):
-        """
-        Load and merges several instance files together.
-        Returns the same format as :meth:`COCODetection.load`.
-        """
-        if not isinstance(names, (list, tuple)):
-            names = [names]
-        ret = []
-        for n in names:
-            coco = COCODetection(basedir, n)
-            ret.extend(coco.load(add_gt, add_mask=add_mask))
-        return ret
-class DetectionDataset:
-    """
-    A singleton to load datasets, evaluate results, and provide metadata.
-    To use your own dataset that's not in COCO format, rewrite all methods of this class.
+    To use your own dataset that's not in COCO format, write a subclass that
+    implements the interfaces.
     """
-    def __init__(self):
-        """
-        This function is responsible for setting the dataset-specific
-        attributes in both cfg and self.
-        """
-        self.num_category = cfg.DATA.NUM_CATEGORY = len(COCODetection.class_names)
-        self.num_classes = self.num_category + 1
-        self.class_names = cfg.DATA.CLASS_NAMES = ["BG"] + COCODetection.class_names
-    def load_training_roidbs(self, names):
+    def training_roidbs(self):
         """
-        Args:
-            names (list[str]): name of the training datasets, e.g. ['train2014', 'valminusminival2014']
         Returns:
             roidbs (list[dict]):
...
@@ -225,14 +31,10 @@ class DetectionDataset:
             Include this field only if training Mask R-CNN.
         """
-        return COCODetection.load_many(cfg.DATA.BASEDIR, names, add_gt=True, add_mask=cfg.MODE_MASK)
+        raise NotImplementedError()
-    def load_inference_roidbs(self, name):
+    def inference_roidbs(self):
         """
-        Args:
-            name (str): name of one inference dataset, e.g. 'minival2014'
         Returns:
             roidbs (list[dict]):
...
@@ -242,56 +44,48 @@ class DetectionDataset:
             file_name (str): full path to the image
             image_id (str): an id for the image. The inference results will be stored with this id.
         """
-        return COCODetection.load_many(cfg.DATA.BASEDIR, name, add_gt=False)
+        raise NotImplementedError()
-    def eval_or_save_inference_results(self, results, dataset, output=None):
+    def eval_inference_results(self, results, output=None):
         """
         Args:
             results (list[dict]): the inference results as dicts.
                 Each dict corresponds to one __instance__. It contains the following keys:
-                image_id (str): the id that matches `load_inference_roidbs`.
+                image_id (str): the id that matches `inference_roidbs`.
                 category_id (int): the category prediction, in range [1, #category]
                 bbox (list[float]): x1, y1, x2, y2
                 score (float):
                 segmentation: the segmentation mask in COCO's rle format.
-            dataset (str): the name of the dataset to evaluate.
-            output (str): the output file to optionally save the results to.
+            output (str): the output file or directory to optionally save the results to.
         Returns:
             dict: the evaluation results.
         """
-        continuous_id_to_COCO_id = {v: k for k, v in COCODetection.COCO_id_to_category_id.items()}
-        for res in results:
-            # convert to COCO's incontinuous category id
-            res['category_id'] = continuous_id_to_COCO_id[res['category_id']]
-            # COCO expects results in xywh format
-            box = res['bbox']
-            box[2] -= box[0]
-            box[3] -= box[1]
-            res['bbox'] = [round(float(x), 3) for x in box]
+        raise NotImplementedError()
-        assert output is not None, "COCO evaluation requires an output file!"
-        with open(output, 'w') as f:
-            json.dump(results, f)
-        if len(results):
-            # sometimes may crash if the results are empty?
-            return COCODetection(cfg.DATA.BASEDIR, dataset).print_coco_metrics(output)
-        else:
-            return {}
-    # code for singleton:
-    _instance = None
+class DatasetRegistry():
+    _registry = {}
-    def __new__(cls):
-        if not isinstance(cls._instance, cls):
-            cls._instance = object.__new__(cls)
-        return cls._instance
+    @staticmethod
+    def register(name, func):
+        """
+        Args:
+            name (str): the name of the dataset split, e.g. "coco_train2017"
+            func: a function which returns an instance of `DatasetSplit`
+        """
+        assert name not in DatasetRegistry._registry, "Dataset {} was registered already!".format(name)
+        DatasetRegistry._registry[name] = func
+    @staticmethod
+    def get(name):
+        """
+        Args:
+            name (str): the name of the dataset split, e.g. "coco_train2017"
-if __name__ == '__main__':
-    cfg.DATA.BASEDIR = '~/data/coco'
-    c = COCODetection(cfg.DATA.BASEDIR, 'train2014')
-    roidb = c.load(add_gt=True, add_mask=True)
-    print("#Images:", len(roidb))
+        Returns:
+            DatasetSplit
+        """
+        assert name in DatasetRegistry._registry, "Dataset {} was not registered!".format(name)
+        return DatasetRegistry._registry[name]()
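`DatasetRegistry.register` takes a function rather than an instance, so a split is only constructed when `get` is called. When such functions are created in a loop, as in `register_coco`, the loop variable has to be bound at definition time. The sketch below illustrates both points with hypothetical names (`LazySplit`, `make_factories`); it does not use the real registry.

```python
class LazySplit:
    """Hypothetical stand-in for a DatasetSplit; counts how often it is built."""
    built = 0

    def __init__(self, name):
        LazySplit.built += 1
        self.name = name


def make_factories(splits):
    late, bound = {}, {}
    for split in splits:
        # Looks up `split` when the lambda is called, i.e. after the loop ended.
        late[split] = lambda: LazySplit(split)
        # Evaluates `split` now and stores it as the default argument.
        bound[split] = lambda x=split: LazySplit(x)
    return late, bound


late, bound = make_factories(["train2017", "val2017"])
print(LazySplit.built)             # 0: creating factories constructs nothing
print(late["train2017"]().name)    # val2017 (the last loop value)
print(bound["train2017"]().name)   # train2017
```

This is why the listing writes `lambda x=split: COCODetection(basedir, x)` instead of a plain `lambda:`.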
examples/FasterRCNN/eval.py
...
@@ -21,7 +21,7 @@ from tensorpack.utils.utils import get_tqdm
 from common import CustomResize, clip_boxes
 from data import get_eval_dataflow
-from dataset import DetectionDataset
+from dataset import DatasetRegistry
 from config import config as cfg
 try:
...
@@ -116,7 +116,7 @@ def predict_dataflow(df, model_func, tqdm_bar=None):
     Returns:
         list of dict, in the format used by
-        `DetectionDataset.eval_or_save_inference_results`
+        `DatasetSplit.eval_inference_results`
     """
     df.reset_state()
     all_results = []
...
@@ -156,7 +156,7 @@ def multithread_predict_dataflow(dataflows, model_funcs):
     Returns:
         list of dict, in the format used by
-        `DetectionDataset.eval_or_save_inference_results`
+        `DatasetSplit.eval_inference_results`
     """
     num_worker = len(model_funcs)
     assert len(dataflows) == num_worker
...
@@ -248,8 +248,8 @@ class EvalCallback(Callback):
         output_file = os.path.join(
             logdir, '{}-outputs{}.json'.format(self._eval_dataset, self.global_step))
-        scores = DetectionDataset().eval_or_save_inference_results(
-            all_results, self._eval_dataset, output_file)
+        scores = DatasetRegistry.get(self._eval_dataset).eval_inference_results(
+            all_results, output_file)
         for k, v in scores.items():
             self.trainer.monitors.put_scalar(self._eval_dataset + '-' + k, v)
...
examples/FasterRCNN/model_frcnn.py
...
@@ -111,7 +111,7 @@ def fastrcnn_outputs(feature, num_categories, class_agnostic_regression=False):
     Returns:
         cls_logits: N x num_class classification logits
-        reg_logits: N x num_classx4 or Nx2x4 if class agnostic
+        reg_logits: N x num_classx4 or Nx1x4 if class agnostic
     """
     num_classes = num_categories + 1
     classification = FullyConnected(
...
examples/FasterRCNN/train.py
...
@@ -19,7 +19,8 @@ from tensorpack.tfutils import collect_env_info
 from tensorpack.tfutils.common import get_tf_version_tuple
 from generalized_rcnn import ResNetFPNModel, ResNetC4Model
-from dataset import DetectionDataset
+from dataset import DatasetRegistry
+from coco import register_coco
 from config import finalize_configs, config as cfg
 from data import get_eval_dataflow, get_train_dataflow
 from eval import DetectionResult, predict_image, multithread_predict_dataflow, EvalCallback
...
@@ -95,7 +96,7 @@ def do_evaluate(pred_config, output_file):
             for k in range(num_gpu)]
         all_results = multithread_predict_dataflow(dataflows, graph_funcs)
         output = output_file + '-' + dataset
-        DetectionDataset().eval_or_save_inference_results(all_results, dataset, output)
+        DatasetRegistry.get(dataset).eval_inference_results(all_results, output)
 def do_predict(pred_func, input_file):
...
@@ -127,9 +128,9 @@ if __name__ == '__main__':
     args = parser.parse_args()
     if args.config:
         cfg.update_args(args.config)
+    register_coco(cfg.DATA.BASEDIR)  # add COCO datasets to the registry
     MODEL = ResNetFPNModel() if cfg.MODE_FPN else ResNetC4Model()
-    DetectionDataset()  # initialize the config with information from our dataset
     if args.visualize or args.evaluate or args.predict:
         if not tf.test.is_gpu_available():
...