Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
1a12ccd1
Commit
1a12ccd1
authored
Nov 26, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add Mask-RCNN implementation
parent
2e490884
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
253 additions
and
66 deletions
+253
-66
README.md
README.md
+1
-1
examples/FasterRCNN/NOTES.md
examples/FasterRCNN/NOTES.md
+3
-1
examples/FasterRCNN/README.md
examples/FasterRCNN/README.md
+21
-15
examples/FasterRCNN/basemodel.py
examples/FasterRCNN/basemodel.py
+2
-0
examples/FasterRCNN/config.py
examples/FasterRCNN/config.py
+3
-1
examples/FasterRCNN/data.py
examples/FasterRCNN/data.py
+5
-3
examples/FasterRCNN/eval.py
examples/FasterRCNN/eval.py
+67
-12
examples/FasterRCNN/model.py
examples/FasterRCNN/model.py
+67
-9
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+74
-19
examples/FasterRCNN/viz.py
examples/FasterRCNN/viz.py
+9
-4
examples/README.md
examples/README.md
+1
-1
No files found.
README.md
View file @
1a12ccd1
...
...
@@ -10,7 +10,7 @@ See some [examples](examples) to learn about the framework. Everything runs on m
### Vision:
+
[
Train ResNet/SE-ResNet on ImageNet
](
examples/ResNet
)
+
[
Train Faster-RCNN on COCO object detection
](
examples/FasterRCNN
)
+
[
Train Faster-RCNN
/ Mask-RCNN
on COCO object detection
](
examples/FasterRCNN
)
+
[
Generative Adversarial Network(GAN) variants
](
examples/GAN
)
, including DCGAN, InfoGAN, Conditional GAN, WGAN, BEGAN, DiscoGAN, Image to Image, CycleGAN.
+
[
DoReFa-Net: train binary / low-bitwidth CNN on ImageNet
](
examples/DoReFa-Net
)
+
[
Fully-convolutional Network for Holistically-Nested Edge Detection(HED)
](
examples/HED
)
...
...
examples/FasterRCNN/NOTES.md
View file @
1a12ccd1
...
...
@@ -30,7 +30,7 @@ Model:
2.
We use ROIAlign, and because of (1),
`tf.image.crop_and_resize`
is __NOT__ ROIAlign.
3.
We only support single image per GPU
for now
.
3.
We only support single image per GPU.
4.
Because of (3), BatchNorm statistics are not supposed to be updated during fine-tuning.
This specific kind of BatchNorm will need
[
my kernel
](
https://github.com/tensorflow/tensorflow/pull/12580
)
...
...
@@ -45,3 +45,5 @@ Speed:
a slow convolution algorithm, or you spend more time on autotune.
This is a general problem of TensorFlow when running against variable-sized input.
3.
With a large roi batch size (e.g. >= 256), GPU utilitization should stay around 90%.
examples/FasterRCNN/README.md
View file @
1a12ccd1
# Faster-RCNN on COCO
This example aims to provide a minimal (1.2k lines) multi-GPU implementation of ResNet-Faster-RCNN on COCO.
# Faster-RCNN / Mask-RCNN on COCO
This example aims to provide a minimal (1.3k lines) multi-GPU implementation of
Faster-RCNN / Mask-RCNN (without FPN) on COCO.
## Dependencies
+
TensorFlow >= 1.4.0
+
Python 3;
TensorFlow >= 1.4.0
+
Install
[
pycocotools
](
https://github.com/pdollar/coco/tree/master/PythonAPI/pycocotools
)
, OpenCV.
+
Pre-trained
[
ResNet
50
model
](
https://goo.gl/6XjK9V
)
from tensorpack model zoo.
+
Pre-trained
[
ResNet model
](
https://goo.gl/6XjK9V
)
from tensorpack model zoo.
+
COCO data. It assumes the following directory structure:
```
DIR/
...
...
@@ -23,36 +24,41 @@ DIR/
## Usage
Change
`BASEDIR`
in
`config.py`
to
`/path/to/DIR`
as described above.
Change config:
1.
Set
`BASEDIR`
in
`config.py`
to
`/path/to/DIR`
as described above.
2.
Set
`MODE_MASK`
to switch Faster-RCNN or Mask-RCNN.
T
o t
rain:
Train:
```
./train.py --load /path/to/ImageNet-ResNet50.npz
```
The code is only for training with 1, 2, 4 or 8 GPUs.
Otherwise, you probably need different hyperparameters for the same performance.
To p
redict on an image (and show output in a window):
P
redict on an image (and show output in a window):
```
./train.py --predict input.jpg --load /path/to/model
```
To evaluate the performance (pretrained models can be downloaded in
[
model zoo
](
http://models.tensorpack.com/FasterRCNN
)
:
Evaluate the performance of a model and save to json.
(A pretrained model can be downloaded in
[
model zoo
](
http://models.tensorpack.com/FasterRCNN
)
:
```
./train.py --evaluate output.json --load /path/to/model
```
## Results
Trained on trainval35k and evaluated on minival, got the following results:
mAP@IoU=0.50:0.95:
Models are trained on trainval35k and evaluated on minival using mAP@IoU=0.50:0.95.
MaskRCNN results contain both bbox and segm mAP.
|Backbone |
`FASTRCNN_BATCH`
| mAP | Time |
| - | - | - | - |
| Res50 | 256 | 34.4 | 49h on 8 TitanX |
| Res50 | 64 | 33.0 | 22h on 8 P100 |
|Backbone |
`FASTRCNN_BATCH`
| resolution | mAP (bbox/segm) | Time |
| - | - | - | - | - |
| Res50 | 64 | (600, 1024) | 33.0 | 22h on 8 P100 |
| Res50 | 256 | (600, 1024) | 34.4 | 49h on 8 TitanX |
| Res50 | 512 | (800, 1333) | 35.6 | 55h on 8 P100|
| Res50 | 512 | (800, 1333) | 36.9/32.3 | 59h on 8 P100|
The hyperparameters are not carefully tuned. You can probably get better performance by e.g. training long
er.
Note that these models are trained with a longer learning schedule than the pap
er.
## Notes
...
...
examples/FasterRCNN/basemodel.py
View file @
1a12ccd1
...
...
@@ -4,6 +4,7 @@
import
tensorflow
as
tf
from
tensorpack.tfutils.argscope
import
argscope
,
get_arg_scope
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
from
tensorpack.models
import
(
Conv2D
,
MaxPooling
,
BatchNorm
,
BNReLU
)
...
...
@@ -88,6 +89,7 @@ def pretrained_resnet_conv4(image, num_blocks):
return
l
@
auto_reuse_variable_scope
def
resnet_conv5
(
image
,
num_block
):
with
argscope
([
Conv2D
,
BatchNorm
],
data_format
=
'NCHW'
),
\
argscope
(
Conv2D
,
nl
=
tf
.
identity
,
use_bias
=
False
),
\
...
...
examples/FasterRCNN/config.py
View file @
1a12ccd1
...
...
@@ -4,6 +4,9 @@
import
numpy
as
np
# mode flags ---------------------
MODE_MASK
=
False
# dataset -----------------------
BASEDIR
=
'/path/to/your/COCO/DIR'
TRAIN_DATASET
=
[
'train2014'
,
'valminusminival2014'
]
...
...
@@ -38,7 +41,6 @@ RPN_MIN_SIZE = 0
RPN_PROPOSAL_NMS_THRESH
=
0.7
TRAIN_PRE_NMS_TOPK
=
12000
TRAIN_POST_NMS_TOPK
=
2000
# boxes overlapping crowd will be ignored.
CROWD_OVERLAP_THRES
=
0.7
...
...
examples/FasterRCNN/data.py
View file @
1a12ccd1
...
...
@@ -4,6 +4,7 @@
import
cv2
import
numpy
as
np
import
copy
from
tensorpack.utils.argtools
import
memoized
,
log_once
from
tensorpack.dataflow
import
(
...
...
@@ -231,8 +232,9 @@ def get_train_dataflow(add_mask=False):
ret
=
[
im
,
fm_labels
,
fm_boxes
,
boxes
,
klass
]
# masks
segmentation
=
img
.
get
(
'segmentation'
,
None
)
if
segmentation
is
not
None
:
if
add_mask
:
# augmentation will modify the polys in-place
segmentation
=
copy
.
deepcopy
(
img
.
get
(
'segmentation'
,
None
))
segmentation
=
[
segmentation
[
k
]
for
k
in
range
(
len
(
segmentation
))
if
not
is_crowd
[
k
]]
assert
len
(
segmentation
)
==
len
(
boxes
)
...
...
@@ -266,7 +268,7 @@ def get_eval_dataflow():
assert
im
is
not
None
,
fname
return
im
ds
=
MapDataComponent
(
ds
,
f
,
0
)
#
ds = PrefetchDataZMQ(ds, 1)
ds
=
PrefetchDataZMQ
(
ds
,
1
)
return
ds
...
...
examples/FasterRCNN/eval.py
View file @
1a12ccd1
...
...
@@ -5,11 +5,14 @@
import
tqdm
import
os
from
collections
import
namedtuple
import
numpy
as
np
import
cv2
from
tensorpack.utils.utils
import
get_tqdm_kwargs
from
pycocotools.coco
import
COCO
from
pycocotools.cocoeval
import
COCOeval
import
pycocotools.mask
as
cocomask
from
coco
import
COCOMeta
from
common
import
CustomResize
...
...
@@ -17,14 +20,41 @@ import config
DetectionResult
=
namedtuple
(
'DetectionResult'
,
[
'
class_id'
,
'box'
,
'score
'
])
[
'
box'
,
'score'
,
'class_id'
,
'mask
'
])
"""
class_id: int, 1~NUM_CLASS
box: 4 float
score: float
class_id: int, 1~NUM_CLASS
mask: None, or a binary image of the original image shape
"""
def
fill_full_mask
(
box
,
mask
,
shape
):
"""
Args:
box: 4 float
mask: MxM floats
shape: h,w
"""
# int() is floor
# box fpcoor=0.0 -> intcoor=0.0
x0
,
y0
=
list
(
map
(
int
,
box
[:
2
]
+
0.5
))
# box fpcoor=h -> intcoor=h-1, inclusive
x1
,
y1
=
list
(
map
(
int
,
box
[
2
:]
-
0.5
))
# inclusive
x1
=
max
(
x0
,
x1
)
# require at least 1x1
y1
=
max
(
y0
,
y1
)
w
=
x1
+
1
-
x0
h
=
y1
+
1
-
y0
# rounding errors could happen here, because masks were not originally computed for this shape.
# but it's hard to do better, because the network does not know the "original" scale
mask
=
(
cv2
.
resize
(
mask
,
(
w
,
h
))
>
0.5
)
.
astype
(
'uint8'
)
ret
=
np
.
zeros
(
shape
,
dtype
=
'uint8'
)
ret
[
y0
:
y1
+
1
,
x0
:
x1
+
1
]
=
mask
return
ret
def
detect_one_image
(
img
,
model_func
):
"""
Run detection on one image, using the TF callable.
...
...
@@ -32,19 +62,30 @@ def detect_one_image(img, model_func):
Args:
img: an image
model_func: a callable from TF model, takes [image] and returns (probs, boxes)
model_func: a callable from TF model,
takes image and returns (boxes, probs, labels, [masks])
Returns:
[DetectionResult]
"""
orig_shape
=
img
.
shape
[:
2
]
resizer
=
CustomResize
(
config
.
SHORT_EDGE_SIZE
,
config
.
MAX_SIZE
)
resized_img
=
resizer
.
augment
(
img
)
scale
=
(
resized_img
.
shape
[
0
]
*
1.0
/
img
.
shape
[
0
]
+
resized_img
.
shape
[
1
]
*
1.0
/
img
.
shape
[
1
])
/
2
boxes
,
probs
,
labels
=
model_func
(
resized_img
)
boxes
,
probs
,
labels
,
*
masks
=
model_func
(
resized_img
)
boxes
=
boxes
/
scale
results
=
[
DetectionResult
(
*
args
)
for
args
in
zip
(
labels
,
boxes
,
probs
)]
if
masks
:
# has mask
full_masks
=
[
fill_full_mask
(
box
,
mask
,
orig_shape
)
for
box
,
mask
in
zip
(
boxes
,
masks
[
0
])]
masks
=
full_masks
else
:
# fill with none
masks
=
[
None
]
*
len
(
boxes
)
results
=
[
DetectionResult
(
*
args
)
for
args
in
zip
(
boxes
,
probs
,
labels
,
masks
)]
return
results
...
...
@@ -62,16 +103,26 @@ def eval_on_dataflow(df, detect_func):
with
tqdm
.
tqdm
(
total
=
df
.
size
(),
**
get_tqdm_kwargs
())
as
pbar
:
for
img
,
img_id
in
df
.
get_data
():
results
=
detect_func
(
img
)
for
classid
,
box
,
score
in
results
:
cat_id
=
COCOMeta
.
class_id_to_category_id
[
classid
]
for
r
in
results
:
box
=
r
.
box
cat_id
=
COCOMeta
.
class_id_to_category_id
[
r
.
class_id
]
box
[
2
]
-=
box
[
0
]
box
[
3
]
-=
box
[
1
]
all_results
.
append
({
res
=
{
'image_id'
:
img_id
,
'category_id'
:
cat_id
,
'bbox'
:
list
(
map
(
lambda
x
:
float
(
round
(
x
,
1
)),
box
)),
'score'
:
float
(
round
(
score
,
2
)),
})
'score'
:
float
(
round
(
r
.
score
,
2
)),
}
# also append segmentation to results
if
r
.
mask
is
not
None
:
rle
=
cocomask
.
encode
(
np
.
array
(
r
.
mask
[:,
:,
None
],
order
=
'F'
))[
0
]
rle
[
'counts'
]
=
rle
[
'counts'
]
.
decode
(
'ascii'
)
res
[
'segmentation'
]
=
rle
all_results
.
append
(
res
)
pbar
.
update
(
1
)
return
all_results
...
...
@@ -84,9 +135,13 @@ def print_evaluation_scores(json_file):
'instances_{}.json'
.
format
(
config
.
VAL_DATASET
))
coco
=
COCO
(
annofile
)
cocoDt
=
coco
.
loadRes
(
json_file
)
imgIds
=
sorted
(
coco
.
getImgIds
())
cocoEval
=
COCOeval
(
coco
,
cocoDt
,
'bbox'
)
cocoEval
.
params
.
imgIds
=
imgIds
cocoEval
.
evaluate
()
cocoEval
.
accumulate
()
cocoEval
.
summarize
()
if
config
.
MODE_MASK
:
cocoEval
=
COCOeval
(
coco
,
cocoDt
,
'segm'
)
cocoEval
.
evaluate
()
cocoEval
.
accumulate
()
cocoEval
.
summarize
()
examples/FasterRCNN/model.py
View file @
1a12ccd1
...
...
@@ -8,7 +8,7 @@ from tensorpack.tfutils.summary import add_moving_summary
from
tensorpack.tfutils.argscope
import
argscope
from
tensorpack.tfutils.scope_utils
import
under_name_scope
from
tensorpack.models
import
(
Conv2D
,
FullyConnected
,
GlobalAvgPooling
,
layer_register
)
Conv2D
,
FullyConnected
,
GlobalAvgPooling
,
layer_register
,
Deconv2D
)
from
utils.box_ops
import
pairwise_iou
import
config
...
...
@@ -90,6 +90,7 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
precision
=
tf
.
to_float
(
tf
.
truediv
(
pos_prediction_corr
,
nr_pos_prediction
))
precision
=
tf
.
where
(
tf
.
equal
(
nr_pos_prediction
,
0
),
0.0
,
precision
,
name
=
'precision_th{}'
.
format
(
th
))
summaries
.
append
(
precision
)
add_moving_summary
(
*
summaries
)
label_loss
=
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
labels
=
tf
.
to_float
(
valid_anchor_labels
),
logits
=
valid_label_logits
)
...
...
@@ -105,7 +106,7 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
box_loss
,
tf
.
cast
(
nr_valid
,
tf
.
float32
),
name
=
'box_loss'
)
add_moving_summary
(
*
([
label_loss
,
box_loss
,
nr_valid
,
nr_pos
]
+
summaries
)
)
add_moving_summary
(
label_loss
,
box_loss
,
nr_valid
,
nr_pos
)
return
label_loss
,
box_loss
...
...
@@ -126,8 +127,8 @@ def decode_bbox_target(box_predictions, anchors):
anchors_x1y1x2y2
=
tf
.
reshape
(
anchors
,
(
-
1
,
2
,
2
))
anchors_x1y1
,
anchors_x2y2
=
tf
.
split
(
anchors_x1y1x2y2
,
2
,
axis
=
1
)
waha
=
tf
.
to_float
(
anchors_x2y2
-
anchors_x1y1
)
xaya
=
tf
.
to_float
(
anchors_x2y2
+
anchors_x1y1
)
*
0.5
waha
=
anchors_x2y2
-
anchors_x1y1
xaya
=
(
anchors_x2y2
+
anchors_x1y1
)
*
0.5
wbhb
=
tf
.
exp
(
tf
.
minimum
(
box_pred_twth
,
config
.
BBOX_DECODE_CLIP
))
*
waha
...
...
@@ -150,16 +151,15 @@ def encode_bbox_target(boxes, anchors):
"""
anchors_x1y1x2y2
=
tf
.
reshape
(
anchors
,
(
-
1
,
2
,
2
))
anchors_x1y1
,
anchors_x2y2
=
tf
.
split
(
anchors_x1y1x2y2
,
2
,
axis
=
1
)
waha
=
tf
.
to_float
(
anchors_x2y2
-
anchors_x1y1
)
xaya
=
tf
.
to_float
(
anchors_x2y2
+
anchors_x1y1
)
*
0.5
waha
=
anchors_x2y2
-
anchors_x1y1
xaya
=
(
anchors_x2y2
+
anchors_x1y1
)
*
0.5
boxes_x1y1x2y2
=
tf
.
reshape
(
boxes
,
(
-
1
,
2
,
2
))
boxes_x1y1
,
boxes_x2y2
=
tf
.
split
(
boxes_x1y1x2y2
,
2
,
axis
=
1
)
wbhb
=
tf
.
to_float
(
boxes_x2y2
-
boxes_x1y1
)
xbyb
=
tf
.
to_float
(
boxes_x2y2
+
boxes_x1y1
)
*
0.5
wbhb
=
boxes_x2y2
-
boxes_x1y1
xbyb
=
(
boxes_x2y2
+
boxes_x1y1
)
*
0.5
# Note that here not all boxes are valid. Some may be zero
txty
=
(
xbyb
-
xaya
)
/
waha
twth
=
tf
.
log
(
wbhb
/
waha
)
# may contain -inf for invalid boxes
encoded
=
tf
.
concat
([
txty
,
twth
],
axis
=
1
)
# (-1x2x2)
...
...
@@ -292,6 +292,7 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
ret_labels
=
tf
.
concat
(
[
tf
.
gather
(
gt_labels
,
fg_inds_wrt_gt
),
tf
.
zeros_like
(
bg_inds
,
dtype
=
tf
.
int64
)],
axis
=
0
,
name
=
'sampled_labels'
)
# stop the gradient -- they are meant to be ground-truth
return
tf
.
stop_gradient
(
ret_boxes
),
tf
.
stop_gradient
(
ret_labels
),
fg_inds_wrt_gt
...
...
@@ -487,3 +488,60 @@ def fastrcnn_predictions(boxes, probs):
filtered_selection
=
tf
.
gather
(
selected_indices
,
topk_indices
)
filtered_selection
=
tf
.
reverse
(
filtered_selection
,
axis
=
[
1
],
name
=
'filtered_indices'
)
return
filtered_selection
,
topk_probs
@
layer_register
(
log_shape
=
True
)
def
maskrcnn_head
(
feature
,
num_class
):
"""
Args:
feature (NxCx7x7):
num_classes(int): num_category + 1
Returns:
mask_logits (N x num_category x 14 x 14):
"""
with
argscope
([
Conv2D
,
Deconv2D
],
data_format
=
'NCHW'
,
W_init
=
tf
.
variance_scaling_initializer
(
scale
=
2.0
,
mode
=
'fan_in'
,
distribution
=
'normal'
)):
l
=
Deconv2D
(
'deconv'
,
feature
,
256
,
2
,
stride
=
2
,
nl
=
tf
.
nn
.
relu
)
l
=
Conv2D
(
'conv'
,
l
,
num_class
-
1
,
1
)
return
l
@
under_name_scope
()
def
maskrcnn_loss
(
mask_logits
,
fg_labels
,
fg_target_masks
):
"""
Args:
mask_logits: #fg x #category x14x14
fg_labels: #fg, in 1~#class
fg_target_masks: #fgx14x14, int
"""
num_fg
=
tf
.
size
(
fg_labels
)
indices
=
tf
.
stack
([
tf
.
range
(
num_fg
),
tf
.
to_int32
(
fg_labels
)
-
1
],
axis
=
1
)
# #fgx2
mask_logits
=
tf
.
gather_nd
(
mask_logits
,
indices
)
# #fgx14x14
mask_probs
=
tf
.
sigmoid
(
mask_logits
)
# add some training visualizations to tensorboard
with
tf
.
name_scope
(
'mask_viz'
):
viz
=
tf
.
concat
([
fg_target_masks
,
mask_probs
],
axis
=
1
)
viz
=
tf
.
expand_dims
(
viz
,
3
)
viz
=
tf
.
cast
(
viz
*
255
,
tf
.
uint8
,
name
=
'viz'
)
tf
.
summary
.
image
(
'mask_truth|pred'
,
viz
,
max_outputs
=
10
)
loss
=
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
labels
=
fg_target_masks
,
logits
=
mask_logits
)
loss
=
tf
.
reduce_mean
(
loss
,
name
=
'maskrcnn_loss'
)
pred_label
=
mask_probs
>
0.5
truth_label
=
fg_target_masks
>
0.5
accuracy
=
tf
.
reduce_mean
(
tf
.
to_float
(
tf
.
equal
(
pred_label
,
truth_label
)),
name
=
'accuracy'
)
pos_accuracy
=
tf
.
logical_and
(
tf
.
equal
(
pred_label
,
truth_label
),
tf
.
equal
(
truth_label
,
True
))
pos_accuracy
=
tf
.
reduce_mean
(
tf
.
to_float
(
pos_accuracy
),
name
=
'pos_accuracy'
)
fg_pixel_ratio
=
tf
.
reduce_mean
(
tf
.
to_float
(
truth_label
),
name
=
'fg_pixel_ratio'
)
add_moving_summary
(
loss
,
accuracy
,
fg_pixel_ratio
,
pos_accuracy
)
return
loss
examples/FasterRCNN/train.py
View file @
1a12ccd1
...
...
@@ -28,7 +28,8 @@ from model import (
clip_boxes
,
decode_bbox_target
,
encode_bbox_target
,
crop_and_resize
,
rpn_head
,
rpn_losses
,
generate_rpn_proposals
,
sample_fast_rcnn_targets
,
roi_align
,
fastrcnn_head
,
fastrcnn_losses
,
fastrcnn_predictions
)
fastrcnn_head
,
fastrcnn_losses
,
fastrcnn_predictions
,
maskrcnn_head
,
maskrcnn_loss
)
from
data
import
(
get_train_dataflow
,
get_eval_dataflow
,
get_all_anchors
)
...
...
@@ -47,15 +48,26 @@ def get_batch_factor():
return
8
//
nr_gpu
def
get_model_output_names
():
ret
=
[
'final_boxes'
,
'final_probs'
,
'final_labels'
]
if
config
.
MODE_MASK
:
ret
.
append
(
'final_masks'
)
return
ret
class
Model
(
ModelDesc
):
def
_get_inputs
(
self
):
ret
urn
[
ret
=
[
InputDesc
(
tf
.
float32
,
(
None
,
None
,
3
),
'image'
),
InputDesc
(
tf
.
int32
,
(
None
,
None
,
config
.
NUM_ANCHOR
),
'anchor_labels'
),
InputDesc
(
tf
.
float32
,
(
None
,
None
,
config
.
NUM_ANCHOR
,
4
),
'anchor_boxes'
),
InputDesc
(
tf
.
float32
,
(
None
,
4
),
'gt_boxes'
),
InputDesc
(
tf
.
int64
,
(
None
,),
'gt_labels'
),
# all > 0
]
InputDesc
(
tf
.
int64
,
(
None
,),
'gt_labels'
)]
# all > 0
if
config
.
MODE_MASK
:
ret
.
append
(
InputDesc
(
tf
.
uint8
,
(
None
,
None
,
None
),
'gt_masks'
)
)
# NR_GT x height x width
return
ret
def
_preprocess
(
self
,
image
):
image
=
tf
.
expand_dims
(
image
,
0
)
...
...
@@ -79,7 +91,10 @@ class Model(ModelDesc):
def
_build_graph
(
self
,
inputs
):
is_training
=
get_current_tower_context
()
.
is_training
image
,
anchor_labels
,
anchor_boxes
,
gt_boxes
,
gt_labels
=
inputs
if
config
.
MODE_MASK
:
image
,
anchor_labels
,
anchor_boxes
,
gt_boxes
,
gt_labels
,
gt_masks
=
inputs
else
:
image
,
anchor_labels
,
anchor_boxes
,
gt_boxes
,
gt_labels
=
inputs
fm_anchors
=
self
.
_get_anchors
(
image
)
image
=
self
.
_preprocess
(
image
)
# 1CHW
image_shape2d
=
tf
.
shape
(
image
)[
2
:]
...
...
@@ -104,8 +119,19 @@ class Model(ModelDesc):
boxes_on_featuremap
=
proposal_boxes
*
(
1.0
/
config
.
ANCHOR_STRIDE
)
roi_resized
=
roi_align
(
featuremap
,
boxes_on_featuremap
,
14
)
feature_fastrcnn
=
resnet_conv5
(
roi_resized
,
config
.
RESNET_NUM_BLOCK
[
-
1
])
# nxcx7x7
fastrcnn_label_logits
,
fastrcnn_box_logits
=
fastrcnn_head
(
'fastrcnn'
,
feature_fastrcnn
,
config
.
NUM_CLASS
)
# HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
def
ff_true
():
feature_fastrcnn
=
resnet_conv5
(
roi_resized
,
config
.
RESNET_NUM_BLOCK
[
-
1
])
# nxcx7x7
fastrcnn_label_logits
,
fastrcnn_box_logits
=
fastrcnn_head
(
'fastrcnn'
,
feature_fastrcnn
,
config
.
NUM_CLASS
)
return
feature_fastrcnn
,
fastrcnn_label_logits
,
fastrcnn_box_logits
def
ff_false
():
ncls
=
config
.
NUM_CLASS
return
tf
.
zeros
([
0
,
2048
,
7
,
7
]),
tf
.
zeros
([
0
,
ncls
]),
tf
.
zeros
([
0
,
ncls
-
1
,
4
])
feature_fastrcnn
,
fastrcnn_label_logits
,
fastrcnn_box_logits
=
tf
.
cond
(
tf
.
size
(
boxes_on_featuremap
)
>
0
,
ff_true
,
ff_false
)
if
is_training
:
# rpn loss
...
...
@@ -116,6 +142,7 @@ class Model(ModelDesc):
fg_inds_wrt_sample
=
tf
.
reshape
(
tf
.
where
(
rcnn_labels
>
0
),
[
-
1
])
# fg inds w.r.t all samples
fg_sampled_boxes
=
tf
.
gather
(
rcnn_sampled_boxes
,
fg_inds_wrt_sample
)
# TODO move to models
with
tf
.
name_scope
(
'fg_sample_patch_viz'
):
fg_sampled_patches
=
crop_and_resize
(
image
,
fg_sampled_boxes
,
...
...
@@ -132,13 +159,30 @@ class Model(ModelDesc):
encoded_boxes
,
tf
.
gather
(
fastrcnn_box_logits
,
fg_inds_wrt_sample
))
if
config
.
MODE_MASK
:
# maskrcnn loss
fg_labels
=
tf
.
gather
(
rcnn_labels
,
fg_inds_wrt_sample
)
fg_feature
=
tf
.
gather
(
feature_fastrcnn
,
fg_inds_wrt_sample
)
mask_logits
=
maskrcnn_head
(
'maskrcnn'
,
fg_feature
,
config
.
NUM_CLASS
)
# #fg x #cat x 14x14
gt_masks_for_fg
=
tf
.
gather
(
gt_masks
,
fg_inds_wrt_gt
)
# nfg x H x W
target_masks_for_fg
=
crop_and_resize
(
tf
.
expand_dims
(
gt_masks_for_fg
,
1
),
fg_sampled_boxes
,
tf
.
range
(
tf
.
size
(
fg_inds_wrt_gt
)),
14
)
# nfg x 1x14x14
target_masks_for_fg
=
tf
.
squeeze
(
target_masks_for_fg
,
1
,
'sampled_fg_mask_targets'
)
mrcnn_loss
=
maskrcnn_loss
(
mask_logits
,
fg_labels
,
target_masks_for_fg
)
else
:
mrcnn_loss
=
0.0
wd_cost
=
regularize_cost
(
'(?:group1|group2|group3|rpn|fastrcnn)/.*W'
,
'(?:group1|group2|group3|rpn|fastrcnn
|maskrcnn
)/.*W'
,
l2_regularizer
(
1e-4
),
name
=
'wd_cost'
)
self
.
cost
=
tf
.
add_n
([
rpn_label_loss
,
rpn_box_loss
,
fastrcnn_label_loss
,
fastrcnn_box_loss
,
mrcnn_loss
,
wd_cost
],
'total_cost'
)
add_moving_summary
(
self
.
cost
,
wd_cost
)
...
...
@@ -153,8 +197,22 @@ class Model(ModelDesc):
# indices: Nx2. Each index into (#proposal, #category)
pred_indices
,
final_probs
=
fastrcnn_predictions
(
decoded_boxes
,
label_probs
)
final_probs
=
tf
.
identity
(
final_probs
,
'final_probs'
)
tf
.
gather_nd
(
decoded_boxes
,
pred_indices
,
name
=
'final_boxes'
)
tf
.
add
(
pred_indices
[:,
1
],
1
,
name
=
'final_labels'
)
final_boxes
=
tf
.
gather_nd
(
decoded_boxes
,
pred_indices
,
name
=
'final_boxes'
)
final_labels
=
tf
.
add
(
pred_indices
[:,
1
],
1
,
name
=
'final_labels'
)
if
config
.
MODE_MASK
:
# HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
def
f1
():
roi_resized
=
roi_align
(
featuremap
,
final_boxes
*
(
1.0
/
config
.
ANCHOR_STRIDE
),
14
)
feature_maskrcnn
=
resnet_conv5
(
roi_resized
,
config
.
RESNET_NUM_BLOCK
[
-
1
])
mask_logits
=
maskrcnn_head
(
'maskrcnn'
,
feature_maskrcnn
,
config
.
NUM_CLASS
)
# #result x #cat x 14x14
indices
=
tf
.
stack
([
tf
.
range
(
tf
.
size
(
final_labels
)),
tf
.
to_int32
(
final_labels
)
-
1
],
axis
=
1
)
final_mask_logits
=
tf
.
gather_nd
(
mask_logits
,
indices
)
# #resultx14x14
return
tf
.
sigmoid
(
final_mask_logits
)
final_masks
=
tf
.
cond
(
tf
.
size
(
final_probs
)
>
0
,
f1
,
lambda
:
tf
.
zeros
([
0
,
14
,
14
]))
tf
.
identity
(
final_masks
,
name
=
'final_masks'
)
def
_get_optimizer
(
self
):
lr
=
tf
.
get_variable
(
'learning_rate'
,
initializer
=
0.003
,
trainable
=
False
)
...
...
@@ -171,6 +229,9 @@ class Model(ModelDesc):
def
visualize
(
model_path
,
nr_visualize
=
50
,
output_dir
=
'output'
):
df
=
get_train_dataflow
()
# we don't visualize mask stuff
df
.
reset_state
()
pred
=
OfflinePredictor
(
PredictConfig
(
model
=
Model
(),
session_init
=
get_model_loader
(
model_path
),
...
...
@@ -183,8 +244,6 @@ def visualize(model_path, nr_visualize=50, output_dir='output'):
'final_probs'
,
'final_labels'
,
]))
df
=
get_train_dataflow
()
df
.
reset_state
()
if
os
.
path
.
isdir
(
output_dir
):
shutil
.
rmtree
(
output_dir
)
...
...
@@ -237,7 +296,7 @@ class EvalCallback(Callback):
def
_setup_graph
(
self
):
self
.
pred
=
self
.
trainer
.
get_predictor
(
[
'image'
],
[
'final_boxes'
,
'final_probs'
,
'final_labels'
]
)
get_model_output_names
()
)
self
.
df
=
get_eval_dataflow
()
def
_before_train
(
self
):
...
...
@@ -288,11 +347,7 @@ if __name__ == '__main__':
model
=
Model
(),
session_init
=
get_model_loader
(
args
.
load
),
input_names
=
[
'image'
],
output_names
=
[
'final_boxes'
,
'final_probs'
,
'final_labels'
,
]))
output_names
=
get_model_output_names
()))
if
args
.
evaluate
:
assert
args
.
evaluate
.
endswith
(
'.json'
)
offline_evaluate
(
pred
,
args
.
evaluate
)
...
...
@@ -308,7 +363,7 @@ if __name__ == '__main__':
cfg
=
TrainConfig
(
model
=
Model
(),
data
=
QueueInput
(
get_train_dataflow
()),
data
=
QueueInput
(
get_train_dataflow
(
add_mask
=
config
.
MODE_MASK
)),
callbacks
=
[
PeriodicTrigger
(
ModelSaver
(),
every_k_epochs
=
5
),
# linear warmup
...
...
examples/FasterRCNN/viz.py
View file @
1a12ccd1
...
...
@@ -72,11 +72,16 @@ def draw_final_outputs(img, results):
return
img
tags
=
[]
for
label
,
_
,
score
in
results
:
for
r
in
results
:
tags
.
append
(
"{},{:.2f}"
.
format
(
config
.
CLASS_NAMES
[
label
],
score
))
boxes
=
np
.
asarray
([
x
.
box
for
x
in
results
])
return
viz
.
draw_boxes
(
img
,
boxes
,
tags
)
"{},{:.2f}"
.
format
(
config
.
CLASS_NAMES
[
r
.
class_id
],
r
.
score
))
boxes
=
np
.
asarray
([
r
.
box
for
r
in
results
])
ret
=
viz
.
draw_boxes
(
img
,
boxes
,
tags
)
for
r
in
results
:
if
r
.
mask
is
not
None
:
ret
=
draw_mask
(
ret
,
r
.
mask
)
return
ret
def
draw_mask
(
im
,
mask
,
alpha
=
0.5
,
color
=
None
):
...
...
examples/README.md
View file @
1a12ccd1
...
...
@@ -17,7 +17,7 @@ Without a setting and performance comparable to someone else, you don't know if
| Name | Performance |
| --- | --- |
| Train
[
ResNet
](
ResNet
)
and
[
ShuffleNet
](
ShuffleNet
)
on ImageNet | reproduce paper |
|
[
Train Faster-RCNN on COCO
](
FasterRCNN
)
| reproduce paper |
|
[
Train Faster-RCNN
/ Mask-RCNN
on COCO
](
FasterRCNN
)
| reproduce paper |
|
[
DoReFa-Net: training binary / low-bitwidth CNN on ImageNet
](
DoReFa-Net
)
| reproduce paper |
|
[
Generative Adversarial Network(GAN) variants
](
GAN
)
, including DCGAN, InfoGAN,
<br/>
Conditional GAN, WGAN, BEGAN, DiscoGAN, Image to Image, CycleGAN | visually reproduce |
|
[
Inception-BN and InceptionV3
](
Inception
)
| reproduce reference code |
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment