Commit 1a12ccd1 authored by Yuxin Wu

Add Mask-RCNN implementation

parent 2e490884
......@@ -10,7 +10,7 @@ See some [examples](examples) to learn about the framework. Everything runs on m
### Vision:
+ [Train ResNet/SE-ResNet on ImageNet](examples/ResNet)
+ [Train Faster-RCNN on COCO object detection](examples/FasterRCNN)
+ [Train Faster-RCNN / Mask-RCNN on COCO object detection](examples/FasterRCNN)
+ [Generative Adversarial Network(GAN) variants](examples/GAN), including DCGAN, InfoGAN, Conditional GAN, WGAN, BEGAN, DiscoGAN, Image to Image, CycleGAN.
+ [DoReFa-Net: train binary / low-bitwidth CNN on ImageNet](examples/DoReFa-Net)
+ [Fully-convolutional Network for Holistically-Nested Edge Detection(HED)](examples/HED)
......
......@@ -30,7 +30,7 @@ Model:
2. We use ROIAlign, and because of (1), `tf.image.crop_and_resize` is __NOT__ ROIAlign.
3. We only support single image per GPU for now.
3. We only support single image per GPU.
4. Because of (3), BatchNorm statistics are not supposed to be updated during fine-tuning (see the sketch after the Speed notes below).
This specific kind of BatchNorm will need [my kernel](https://github.com/tensorflow/tensorflow/pull/12580)
......@@ -45,3 +45,5 @@ Speed:
a slow convolution algorithm, or you spend more time on autotune.
This is a general problem of TensorFlow when running against variable-sized input.
3. With a large roi batch size (e.g. >= 256), GPU utilization should stay around 90%.
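As an aside to note (4) in the Model section above, here is a minimal sketch of keeping BatchNorm frozen during fine-tuning (assuming tensorpack's `BatchNorm` supports the `use_local_stat` flag; `build_backbone` and `image` are hypothetical):
```
from tensorpack.tfutils.argscope import argscope
from tensorpack.models import BatchNorm

# Run BatchNorm with the stored moving statistics even at training time,
# so the pre-trained statistics are used but never updated.
with argscope(BatchNorm, use_local_stat=False):
    features = build_backbone(image)
```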
# Faster-RCNN on COCO
This example aims to provide a minimal (1.2k lines) multi-GPU implementation of ResNet-Faster-RCNN on COCO.
# Faster-RCNN / Mask-RCNN on COCO
This example aims to provide a minimal (1.3k lines) multi-GPU implementation of
Faster-RCNN / Mask-RCNN (without FPN) on COCO.
## Dependencies
+ TensorFlow >= 1.4.0
+ Python 3; TensorFlow >= 1.4.0
+ Install [pycocotools](https://github.com/pdollar/coco/tree/master/PythonAPI/pycocotools), OpenCV.
+ Pre-trained [ResNet50 model](https://goo.gl/6XjK9V) from tensorpack model zoo.
+ Pre-trained [ResNet model](https://goo.gl/6XjK9V) from tensorpack model zoo.
+ COCO data. It assumes the following directory structure:
```
DIR/
......@@ -23,36 +24,41 @@ DIR/
## Usage
Change `BASEDIR` in `config.py` to `/path/to/DIR` as described above.
Change config:
1. Set `BASEDIR` in `config.py` to `/path/to/DIR` as described above.
2. Set `MODE_MASK` to switch between Faster-RCNN and Mask-RCNN, as in the sketch below.
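For example (placeholder values):
```
# in config.py
MODE_MASK = True          # True: train Mask-RCNN; False: Faster-RCNN
BASEDIR = '/path/to/DIR'  # the COCO directory described above
```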
To train:
Train:
```
./train.py --load /path/to/ImageNet-ResNet50.npz
```
The code supports training with 1, 2, 4 or 8 GPUs.
Other settings will likely need different hyperparameters for the same performance.
To predict on an image (and show output in a window):
Predict on an image (and show output in a window):
```
./train.py --predict input.jpg --load /path/to/model
```
To evaluate the performance (pretrained models can be downloaded from the [model zoo](http://models.tensorpack.com/FasterRCNN)):
Evaluate the performance of a model and save results to json.
(A pretrained model can be downloaded from the [model zoo](http://models.tensorpack.com/FasterRCNN)):
```
./train.py --evaluate output.json --load /path/to/model
```
## Results
Trained on trainval35k and evaluated on minival, we got the following results:
mAP@IoU=0.50:0.95:
Models are trained on trainval35k and evaluated on minival using mAP@IoU=0.50:0.95.
MaskRCNN results contain both bbox and segm mAP.
|Backbone | `FASTRCNN_BATCH` | mAP | Time |
| - | - | - | - |
| Res50 | 256 | 34.4 | 49h on 8 TitanX |
| Res50 | 64 | 33.0 | 22h on 8 P100 |
|Backbone | `FASTRCNN_BATCH` | resolution | mAP (bbox/segm) | Time |
| - | - | - | - | - |
| Res50 | 64 | (600, 1024) | 33.0 | 22h on 8 P100 |
| Res50 | 256 | (600, 1024) | 34.4 | 49h on 8 TitanX |
| Res50 | 512 | (800, 1333) | 35.6 | 55h on 8 P100|
| Res50 | 512 | (800, 1333) | 36.9/32.3 | 59h on 8 P100|
The hyperparameters are not carefully tuned. You can probably get better performance by e.g. training longer.
Note that these models are trained with a longer learning schedule than in the paper.
## Notes
......
......@@ -4,6 +4,7 @@
import tensorflow as tf
from tensorpack.tfutils.argscope import argscope, get_arg_scope
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.models import (
Conv2D, MaxPooling, BatchNorm, BNReLU)
......@@ -88,6 +89,7 @@ def pretrained_resnet_conv4(image, num_blocks):
return l
@auto_reuse_variable_scope
def resnet_conv5(image, num_block):
with argscope([Conv2D, BatchNorm], data_format='NCHW'), \
argscope(Conv2D, nl=tf.identity, use_bias=False), \
......
......@@ -4,6 +4,9 @@
import numpy as np
# mode flags ---------------------
MODE_MASK = False
# dataset -----------------------
BASEDIR = '/path/to/your/COCO/DIR'
TRAIN_DATASET = ['train2014', 'valminusminival2014']
......@@ -38,7 +41,6 @@ RPN_MIN_SIZE = 0
RPN_PROPOSAL_NMS_THRESH = 0.7
TRAIN_PRE_NMS_TOPK = 12000
TRAIN_POST_NMS_TOPK = 2000
# boxes overlapping crowd will be ignored.
CROWD_OVERLAP_THRES = 0.7
......
......@@ -4,6 +4,7 @@
import cv2
import numpy as np
import copy
from tensorpack.utils.argtools import memoized, log_once
from tensorpack.dataflow import (
......@@ -231,8 +232,9 @@ def get_train_dataflow(add_mask=False):
ret = [im, fm_labels, fm_boxes, boxes, klass]
# masks
segmentation = img.get('segmentation', None)
if segmentation is not None:
if add_mask:
# augmentation will modify the polys in-place
segmentation = copy.deepcopy(img.get('segmentation', None))
segmentation = [segmentation[k] for k in range(len(segmentation)) if not is_crowd[k]]
assert len(segmentation) == len(boxes)
......@@ -266,7 +268,7 @@ def get_eval_dataflow():
assert im is not None, fname
return im
ds = MapDataComponent(ds, f, 0)
# ds = PrefetchDataZMQ(ds, 1)
ds = PrefetchDataZMQ(ds, 1)
return ds
......
......@@ -5,11 +5,14 @@
import tqdm
import os
from collections import namedtuple
import numpy as np
import cv2
from tensorpack.utils.utils import get_tqdm_kwargs
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
import pycocotools.mask as cocomask
from coco import COCOMeta
from common import CustomResize
......@@ -17,14 +20,41 @@ import config
DetectionResult = namedtuple(
'DetectionResult',
['class_id', 'box', 'score'])
['box', 'score', 'class_id', 'mask'])
"""
class_id: int, 1~NUM_CLASS
box: 4 float
score: float
class_id: int, 1~NUM_CLASS
mask: None, or a binary image of the original image shape
"""
def fill_full_mask(box, mask, shape):
"""
Args:
box: 4 floats (x0, y0, x1, y1)
mask: MxM floats
shape: h, w
Returns:
a binary uint8 image of shape (h, w)
"""
# int() is floor
# box fpcoor=0.0 -> intcoor=0.0
x0, y0 = list(map(int, box[:2] + 0.5))
# box fpcoor=h -> intcoor=h-1, inclusive
x1, y1 = list(map(int, box[2:] - 0.5)) # inclusive
x1 = max(x0, x1) # require at least 1x1
y1 = max(y0, y1)
w = x1 + 1 - x0
h = y1 + 1 - y0
# rounding errors could happen here, because masks were not originally computed for this shape.
# but it's hard to do better, because the network does not know the "original" scale
mask = (cv2.resize(mask, (w, h)) > 0.5).astype('uint8')
ret = np.zeros(shape, dtype='uint8')
ret[y0:y1 + 1, x0:x1 + 1] = mask
return ret
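A hypothetical call, illustrating the coordinate convention (all values made up):
```
import numpy as np

# Paste a 14x14 soft mask into a 480x640 canvas.
box = np.array([100.0, 80.0, 228.0, 208.0])        # x0, y0, x1, y1 (float coords)
small_mask = np.random.rand(14, 14).astype('float32')
full = fill_full_mask(box, small_mask, shape=(480, 640))
assert full.shape == (480, 640) and full.dtype == np.uint8
```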
def detect_one_image(img, model_func):
"""
Run detection on one image, using the TF callable.
......@@ -32,19 +62,30 @@ def detect_one_image(img, model_func):
Args:
img: an image
model_func: a callable from TF model, takes [image] and returns (probs, boxes)
model_func: a callable from TF model,
takes image and returns (boxes, probs, labels, [masks])
Returns:
[DetectionResult]
"""
orig_shape = img.shape[:2]
resizer = CustomResize(config.SHORT_EDGE_SIZE, config.MAX_SIZE)
resized_img = resizer.augment(img)
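# average the height and width resize ratios into one isotropic scale,
# used below to map predicted boxes back to original image coordinates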
scale = (resized_img.shape[0] * 1.0 / img.shape[0] + resized_img.shape[1] * 1.0 / img.shape[1]) / 2
boxes, probs, labels = model_func(resized_img)
boxes, probs, labels, *masks = model_func(resized_img)
boxes = boxes / scale
results = [DetectionResult(*args) for args in zip(labels, boxes, probs)]
if masks:
# has mask
full_masks = [fill_full_mask(box, mask, orig_shape)
for box, mask in zip(boxes, masks[0])]
masks = full_masks
else:
# fill with None
masks = [None] * len(boxes)
results = [DetectionResult(*args) for args in zip(boxes, probs, labels, masks)]
return results
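Hypothetical usage (`pred_func` stands in for a predictor callable built from the trained model, returning (boxes, probs, labels[, masks])):
```
import cv2

img = cv2.imread('input.jpg', cv2.IMREAD_COLOR)
results = detect_one_image(img, pred_func)
for r in results:
    print(r.class_id, r.score, r.box)   # r.mask is None in Faster-RCNN mode
```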
......@@ -62,16 +103,26 @@ def eval_on_dataflow(df, detect_func):
with tqdm.tqdm(total=df.size(), **get_tqdm_kwargs()) as pbar:
for img, img_id in df.get_data():
results = detect_func(img)
for classid, box, score in results:
cat_id = COCOMeta.class_id_to_category_id[classid]
for r in results:
box = r.box
cat_id = COCOMeta.class_id_to_category_id[r.class_id]
box[2] -= box[0]
box[3] -= box[1]
all_results.append({
res = {
'image_id': img_id,
'category_id': cat_id,
'bbox': list(map(lambda x: float(round(x, 1)), box)),
'score': float(round(score, 2)),
})
'score': float(round(r.score, 2)),
}
# also append segmentation to results
if r.mask is not None:
rle = cocomask.encode(
np.array(r.mask[:, :, None], order='F'))[0]
rle['counts'] = rle['counts'].decode('ascii')
res['segmentation'] = rle
all_results.append(res)
pbar.update(1)
return all_results
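As a sanity check on the segmentation handling above, the pycocotools RLE can be round-tripped; a small sketch with a made-up mask:
```
import numpy as np
import pycocotools.mask as cocomask

mask = (np.random.rand(480, 640) > 0.5).astype(np.uint8)
rle = cocomask.encode(np.asfortranarray(mask))   # expects Fortran-ordered uint8
rle['counts'] = rle['counts'].decode('ascii')    # bytes -> str so json.dumps works
rle['counts'] = rle['counts'].encode('ascii')    # decode() wants bytes again
assert np.array_equal(cocomask.decode(rle), mask)
```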
......@@ -84,9 +135,13 @@ def print_evaluation_scores(json_file):
'instances_{}.json'.format(config.VAL_DATASET))
coco = COCO(annofile)
cocoDt = coco.loadRes(json_file)
imgIds = sorted(coco.getImgIds())
cocoEval = COCOeval(coco, cocoDt, 'bbox')
cocoEval.params.imgIds = imgIds
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
if config.MODE_MASK:
cocoEval = COCOeval(coco, cocoDt, 'segm')
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
......@@ -8,7 +8,7 @@ from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.models import (
Conv2D, FullyConnected, GlobalAvgPooling, layer_register)
Conv2D, FullyConnected, GlobalAvgPooling, layer_register, Deconv2D)
from utils.box_ops import pairwise_iou
import config
......@@ -90,6 +90,7 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
precision = tf.to_float(tf.truediv(pos_prediction_corr, nr_pos_prediction))
precision = tf.where(tf.equal(nr_pos_prediction, 0), 0.0, precision, name='precision_th{}'.format(th))
summaries.append(precision)
add_moving_summary(*summaries)
label_loss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.to_float(valid_anchor_labels), logits=valid_label_logits)
......@@ -105,7 +106,7 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
box_loss,
tf.cast(nr_valid, tf.float32), name='box_loss')
add_moving_summary(*([label_loss, box_loss, nr_valid, nr_pos] + summaries))
add_moving_summary(label_loss, box_loss, nr_valid, nr_pos)
return label_loss, box_loss
......@@ -126,8 +127,8 @@ def decode_bbox_target(box_predictions, anchors):
anchors_x1y1x2y2 = tf.reshape(anchors, (-1, 2, 2))
anchors_x1y1, anchors_x2y2 = tf.split(anchors_x1y1x2y2, 2, axis=1)
waha = tf.to_float(anchors_x2y2 - anchors_x1y1)
xaya = tf.to_float(anchors_x2y2 + anchors_x1y1) * 0.5
waha = anchors_x2y2 - anchors_x1y1
xaya = (anchors_x2y2 + anchors_x1y1) * 0.5
wbhb = tf.exp(tf.minimum(
box_pred_twth, config.BBOX_DECODE_CLIP)) * waha
......@@ -150,16 +151,15 @@ def encode_bbox_target(boxes, anchors):
"""
anchors_x1y1x2y2 = tf.reshape(anchors, (-1, 2, 2))
anchors_x1y1, anchors_x2y2 = tf.split(anchors_x1y1x2y2, 2, axis=1)
waha = tf.to_float(anchors_x2y2 - anchors_x1y1)
xaya = tf.to_float(anchors_x2y2 + anchors_x1y1) * 0.5
waha = anchors_x2y2 - anchors_x1y1
xaya = (anchors_x2y2 + anchors_x1y1) * 0.5
boxes_x1y1x2y2 = tf.reshape(boxes, (-1, 2, 2))
boxes_x1y1, boxes_x2y2 = tf.split(boxes_x1y1x2y2, 2, axis=1)
wbhb = tf.to_float(boxes_x2y2 - boxes_x1y1)
xbyb = tf.to_float(boxes_x2y2 + boxes_x1y1) * 0.5
wbhb = boxes_x2y2 - boxes_x1y1
xbyb = (boxes_x2y2 + boxes_x1y1) * 0.5
# Note that here not all boxes are valid. Some may be zero
txty = (xbyb - xaya) / waha
twth = tf.log(wbhb / waha) # may contain -inf for invalid boxes
encoded = tf.concat([txty, twth], axis=1) # (-1x2x2)
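To make the parameterization concrete, a small numpy sketch of the round trip (made-up boxes; same tx = (xb - xa) / wa, tw = log(wb / wa) convention as the code above):
```
import numpy as np

anchor = np.array([10., 10., 50., 30.])        # x1, y1, x2, y2
box = np.array([12., 8., 56., 32.])
wa, ha = anchor[2] - anchor[0], anchor[3] - anchor[1]
xa, ya = (anchor[0] + anchor[2]) / 2, (anchor[1] + anchor[3]) / 2
wb, hb = box[2] - box[0], box[3] - box[1]
xb, yb = (box[0] + box[2]) / 2, (box[1] + box[3]) / 2
tx, ty = (xb - xa) / wa, (yb - ya) / ha
tw, th = np.log(wb / wa), np.log(hb / ha)
# decoding inverts the transform exactly:
assert np.isclose(tx * wa + xa, xb) and np.isclose(np.exp(tw) * wa, wb)
```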
......@@ -292,6 +292,7 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
ret_labels = tf.concat(
[tf.gather(gt_labels, fg_inds_wrt_gt),
tf.zeros_like(bg_inds, dtype=tf.int64)], axis=0, name='sampled_labels')
# stop the gradient -- they are meant to be ground-truth
return tf.stop_gradient(ret_boxes), tf.stop_gradient(ret_labels), fg_inds_wrt_gt
......@@ -487,3 +488,60 @@ def fastrcnn_predictions(boxes, probs):
filtered_selection = tf.gather(selected_indices, topk_indices)
filtered_selection = tf.reverse(filtered_selection, axis=[1], name='filtered_indices')
return filtered_selection, topk_probs
@layer_register(log_shape=True)
def maskrcnn_head(feature, num_class):
"""
Args:
feature (NxCx7x7):
num_class (int): num_category + 1
Returns:
mask_logits (N x num_category x 14 x 14):
"""
with argscope([Conv2D, Deconv2D], data_format='NCHW',
W_init=tf.variance_scaling_initializer(
scale=2.0, mode='fan_in', distribution='normal')):
l = Deconv2D('deconv', feature, 256, 2, stride=2, nl=tf.nn.relu)
l = Conv2D('conv', l, num_class - 1, 1)
return l
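A quick graph-construction shape check for this head (hypothetical placeholder batch; with num_class = 81, i.e. 80 categories plus background): the deconv doubles the 7x7 ROI feature to 14x14, and the 1x1 conv emits one mask map per category.
```
import tensorflow as tf

feat = tf.placeholder(tf.float32, [None, 2048, 7, 7])   # NxCx7x7 ROI features
logits = maskrcnn_head('maskrcnn_test', feat, 81)
print(logits.get_shape().as_list())                     # expected: [None, 80, 14, 14]
```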
@under_name_scope()
def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks):
"""
Args:
mask_logits: #fg x #category x14x14
fg_labels: #fg, in 1~#class
fg_target_masks: #fgx14x14, int
"""
num_fg = tf.size(fg_labels)
indices = tf.stack([tf.range(num_fg), tf.to_int32(fg_labels) - 1], axis=1) # #fgx2
mask_logits = tf.gather_nd(mask_logits, indices) # #fgx14x14
mask_probs = tf.sigmoid(mask_logits)
# add some training visualizations to tensorboard
with tf.name_scope('mask_viz'):
viz = tf.concat([fg_target_masks, mask_probs], axis=1)
viz = tf.expand_dims(viz, 3)
viz = tf.cast(viz * 255, tf.uint8, name='viz')
tf.summary.image('mask_truth|pred', viz, max_outputs=10)
loss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=fg_target_masks, logits=mask_logits)
loss = tf.reduce_mean(loss, name='maskrcnn_loss')
pred_label = mask_probs > 0.5
truth_label = fg_target_masks > 0.5
accuracy = tf.reduce_mean(
tf.to_float(tf.equal(pred_label, truth_label)),
name='accuracy')
pos_accuracy = tf.logical_and(
tf.equal(pred_label, truth_label),
tf.equal(truth_label, True))
pos_accuracy = tf.reduce_mean(tf.to_float(pos_accuracy), name='pos_accuracy')
fg_pixel_ratio = tf.reduce_mean(tf.to_float(truth_label), name='fg_pixel_ratio')
add_moving_summary(loss, accuracy, fg_pixel_ratio, pos_accuracy)
return loss
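The `tf.gather_nd` above picks, for each foreground ROI, only the mask map of its ground-truth category; a numpy analogue with made-up shapes:
```
import numpy as np

num_fg, num_category = 3, 80
mask_logits = np.random.randn(num_fg, num_category, 14, 14)
fg_labels = np.array([5, 1, 33])                         # 1-based class ids
picked = mask_logits[np.arange(num_fg), fg_labels - 1]   # num_fg x 14 x 14
assert picked.shape == (3, 14, 14)
```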
......@@ -28,7 +28,8 @@ from model import (
clip_boxes, decode_bbox_target, encode_bbox_target, crop_and_resize,
rpn_head, rpn_losses,
generate_rpn_proposals, sample_fast_rcnn_targets, roi_align,
fastrcnn_head, fastrcnn_losses, fastrcnn_predictions)
fastrcnn_head, fastrcnn_losses, fastrcnn_predictions,
maskrcnn_head, maskrcnn_loss)
from data import (
get_train_dataflow, get_eval_dataflow,
get_all_anchors)
......@@ -47,15 +48,26 @@ def get_batch_factor():
return 8 // nr_gpu
def get_model_output_names():
ret = ['final_boxes', 'final_probs', 'final_labels']
if config.MODE_MASK:
ret.append('final_masks')
return ret
class Model(ModelDesc):
def _get_inputs(self):
return [
ret = [
InputDesc(tf.float32, (None, None, 3), 'image'),
InputDesc(tf.int32, (None, None, config.NUM_ANCHOR), 'anchor_labels'),
InputDesc(tf.float32, (None, None, config.NUM_ANCHOR, 4), 'anchor_boxes'),
InputDesc(tf.float32, (None, 4), 'gt_boxes'),
InputDesc(tf.int64, (None,), 'gt_labels'), # all > 0
]
InputDesc(tf.int64, (None,), 'gt_labels')] # all > 0
if config.MODE_MASK:
ret.append(
InputDesc(tf.uint8, (None, None, None), 'gt_masks')
) # NR_GT x height x width
return ret
def _preprocess(self, image):
image = tf.expand_dims(image, 0)
......@@ -79,7 +91,10 @@ class Model(ModelDesc):
def _build_graph(self, inputs):
is_training = get_current_tower_context().is_training
image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
if config.MODE_MASK:
image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks = inputs
else:
image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
fm_anchors = self._get_anchors(image)
image = self._preprocess(image) # 1CHW
image_shape2d = tf.shape(image)[2:]
......@@ -104,8 +119,19 @@ class Model(ModelDesc):
boxes_on_featuremap = proposal_boxes * (1.0 / config.ANCHOR_STRIDE)
roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)
feature_fastrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1]) # nxcx7x7
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head('fastrcnn', feature_fastrcnn, config.NUM_CLASS)
# HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
def ff_true():
feature_fastrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1]) # nxcx7x7
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head('fastrcnn', feature_fastrcnn, config.NUM_CLASS)
return feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits
def ff_false():
ncls = config.NUM_CLASS
return tf.zeros([0, 2048, 7, 7]), tf.zeros([0, ncls]), tf.zeros([0, ncls - 1, 4])
feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = tf.cond(
tf.size(boxes_on_featuremap) > 0, ff_true, ff_false)
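# NB: the cond above keeps the graph valid when no proposal survives NMS;
# some ops (and their gradients) mishandle zero-size inputs, hence the
# fixed-shape empty fallback in ff_false.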
if is_training:
# rpn loss
......@@ -116,6 +142,7 @@ class Model(ModelDesc):
fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1]) # fg inds w.r.t all samples
fg_sampled_boxes = tf.gather(rcnn_sampled_boxes, fg_inds_wrt_sample)
# TODO move to models
with tf.name_scope('fg_sample_patch_viz'):
fg_sampled_patches = crop_and_resize(
image, fg_sampled_boxes,
......@@ -132,13 +159,30 @@ class Model(ModelDesc):
encoded_boxes,
tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample))
if config.MODE_MASK:
# maskrcnn loss
fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
fg_feature = tf.gather(feature_fastrcnn, fg_inds_wrt_sample)
mask_logits = maskrcnn_head('maskrcnn', fg_feature, config.NUM_CLASS) # #fg x #cat x 14x14
gt_masks_for_fg = tf.gather(gt_masks, fg_inds_wrt_gt) # nfg x H x W
target_masks_for_fg = crop_and_resize(
tf.expand_dims(gt_masks_for_fg, 1),
fg_sampled_boxes,
tf.range(tf.size(fg_inds_wrt_gt)), 14) # nfg x 1x14x14
target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels, target_masks_for_fg)
else:
mrcnn_loss = 0.0
wd_cost = regularize_cost(
'(?:group1|group2|group3|rpn|fastrcnn)/.*W',
'(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
l2_regularizer(1e-4), name='wd_cost')
self.cost = tf.add_n([
rpn_label_loss, rpn_box_loss,
fastrcnn_label_loss, fastrcnn_box_loss,
mrcnn_loss,
wd_cost], 'total_cost')
add_moving_summary(self.cost, wd_cost)
......@@ -153,8 +197,22 @@ class Model(ModelDesc):
# indices: Nx2. Each index into (#proposal, #category)
pred_indices, final_probs = fastrcnn_predictions(decoded_boxes, label_probs)
final_probs = tf.identity(final_probs, 'final_probs')
tf.gather_nd(decoded_boxes, pred_indices, name='final_boxes')
tf.add(pred_indices[:, 1], 1, name='final_labels')
final_boxes = tf.gather_nd(decoded_boxes, pred_indices, name='final_boxes')
final_labels = tf.add(pred_indices[:, 1], 1, name='final_labels')
if config.MODE_MASK:
# HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
def f1():
roi_resized = roi_align(featuremap, final_boxes * (1.0 / config.ANCHOR_STRIDE), 14)
feature_maskrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1])
mask_logits = maskrcnn_head(
'maskrcnn', feature_maskrcnn, config.NUM_CLASS) # #result x #cat x 14x14
indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1)
final_mask_logits = tf.gather_nd(mask_logits, indices) # #resultx14x14
return tf.sigmoid(final_mask_logits)
final_masks = tf.cond(tf.size(final_probs) > 0, f1, lambda: tf.zeros([0, 14, 14]))
tf.identity(final_masks, name='final_masks')
def _get_optimizer(self):
lr = tf.get_variable('learning_rate', initializer=0.003, trainable=False)
......@@ -171,6 +229,9 @@ class Model(ModelDesc):
def visualize(model_path, nr_visualize=50, output_dir='output'):
df = get_train_dataflow() # we don't visualize mask stuff
df.reset_state()
pred = OfflinePredictor(PredictConfig(
model=Model(),
session_init=get_model_loader(model_path),
......@@ -183,8 +244,6 @@ def visualize(model_path, nr_visualize=50, output_dir='output'):
'final_probs',
'final_labels',
]))
df = get_train_dataflow()
df.reset_state()
if os.path.isdir(output_dir):
shutil.rmtree(output_dir)
......@@ -237,7 +296,7 @@ class EvalCallback(Callback):
def _setup_graph(self):
self.pred = self.trainer.get_predictor(
['image'],
['final_boxes', 'final_probs', 'final_labels'])
get_model_output_names())
self.df = get_eval_dataflow()
def _before_train(self):
......@@ -288,11 +347,7 @@ if __name__ == '__main__':
model=Model(),
session_init=get_model_loader(args.load),
input_names=['image'],
output_names=[
'final_boxes',
'final_probs',
'final_labels',
]))
output_names=get_model_output_names()))
if args.evaluate:
assert args.evaluate.endswith('.json')
offline_evaluate(pred, args.evaluate)
......@@ -308,7 +363,7 @@ if __name__ == '__main__':
cfg = TrainConfig(
model=Model(),
data=QueueInput(get_train_dataflow()),
data=QueueInput(get_train_dataflow(add_mask=config.MODE_MASK)),
callbacks=[
PeriodicTrigger(ModelSaver(), every_k_epochs=5),
# linear warmup
......
......@@ -72,11 +72,16 @@ def draw_final_outputs(img, results):
return img
tags = []
for label, _, score in results:
for r in results:
tags.append(
"{},{:.2f}".format(config.CLASS_NAMES[label], score))
boxes = np.asarray([x.box for x in results])
return viz.draw_boxes(img, boxes, tags)
"{},{:.2f}".format(config.CLASS_NAMES[r.class_id], r.score))
boxes = np.asarray([r.box for r in results])
ret = viz.draw_boxes(img, boxes, tags)
for r in results:
if r.mask is not None:
ret = draw_mask(ret, r.mask)
return ret
def draw_mask(im, mask, alpha=0.5, color=None):
......
......@@ -17,7 +17,7 @@ Without a setting and performance comparable to someone else, you don't know if
| Name | Performance |
| --- | --- |
| Train [ResNet](ResNet) and [ShuffleNet](ShuffleNet) on ImageNet | reproduce paper |
| [Train Faster-RCNN on COCO](FasterRCNN) | reproduce paper |
| [Train Faster-RCNN / Mask-RCNN on COCO](FasterRCNN) | reproduce paper |
| [DoReFa-Net: training binary / low-bitwidth CNN on ImageNet](DoReFa-Net) | reproduce paper |
| [Generative Adversarial Network(GAN) variants](GAN), including DCGAN, InfoGAN, <br/> Conditional GAN, WGAN, BEGAN, DiscoGAN, Image to Image, CycleGAN | visually reproduce |
| [Inception-BN and InceptionV3](Inception) | reproduce reference code |
......