Commit 23446308 authored by Yuxin Wu

initial commit of FasterRCNN

parent 7c3b404a
# Faster-RCNN on COCO
This example aims to provide a minimal Multi-GPU implementation (<1000 lines) of ResNet50-Faster-RCNN on COCO.
## Dependencies
+ TensorFlow nightly.
+ Install [pycocotools](https://github.com/pdollar/coco/tree/master/PythonAPI/pycocotools), OpenCV.
+ Pre-trained [ResNet50 model](https://goo.gl/6XjK9V) from tensorpack model zoo.
+ COCO data. It assumes the following directory structure:
```
DIR/
  annotations/
    instances_train2014.json
    instances_val2014.json
    instances_minival2014.json
    instances_valminusminival2014.json
  train2014/
    COCO_train2014_*.jpg
  val2014/
    COCO_val2014_*.jpg
```
`minival` and `valminusminival` are optional. You can download them
[here](https://github.com/rbgirshick/py-faster-rcnn/blob/master/data/README.md).
## Usage
Change `BASEDIR` in `config.py` to `/path/to/DIR` as described above.
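For example:
```
# in config.py
BASEDIR = '/path/to/DIR'
```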
To train:
```
./train.py --load /path/to/ImageNet-ResNet50.npz
```
The code is written for training with __8 GPUs__. Otherwise the performance won't be as good.
To predict on an image (and show output in a window):
```
./train.py --predict input.jpg
```
## Results
+ trainval35k/minival, FASTRCNN_BATCH_PER_IM=256: 32.9
+ trainval35k/minival, FASTRCNN_BATCH_PER_IM=64: 31.7. Takes less than one day on 8 Maxwell TitanX.
The hyperparameters are not carefully tuned. You can probably get better performance by e.g. training longer.
## Files
This is a minimal implementation that consists of these files:
+ coco.py: load COCO data
+ data.py: prepare data for training
+ common.py: some common data preparation utilities
+ basemodel.py: implement ResNet
+ model.py: implement Faster R-CNN
+ viz.py: visualization utilities
+ utils/: third-party helper functions
+ train.py: main training script
+ eval.py: utilities for evaluation
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: basemodel.py
import tensorflow as tf
from tensorflow.contrib.layers import variance_scaling_initializer
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary, add_activation_summary
from tensorpack.tfutils.argscope import argscope, get_arg_scope
from tensorpack.models import (
Conv2D, MaxPooling, BatchNorm, BNReLU, GlobalAvgPooling, FullyConnected)
def image_preprocess(image, bgr=True):
with tf.name_scope('image_preprocess'):
if image.dtype.base_dtype != tf.float32:
image = tf.cast(image, tf.float32)
image = image * (1.0 / 255)
mean = [0.485, 0.456, 0.406] # rgb
std = [0.229, 0.224, 0.225]
if bgr:
mean = mean[::-1]
std = std[::-1]
image_mean = tf.constant(mean, dtype=tf.float32)
image_std = tf.constant(std, dtype=tf.float32)
image = (image - image_mean) / image_std
return image
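# Usage sketch (an assumption about the calling code, which is not part of this file):
# the mean/std constants above broadcast over the last axis, so the input is expected in
# channels-last (HxWx3) BGR order as read by OpenCV, and is transposed to NCHW only
# afterwards for the NCHW convolutions below, e.g.:
#   image = image_preprocess(image, bgr=True)    # 1xHxWx3 float tensor
#   image = tf.transpose(image, [0, 3, 1, 2])    # to NCHW before pretrained_resnet_conv4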
def get_bn(zero_init=False):
if zero_init:
return lambda x, name: BatchNorm('bn', x, gamma_init=tf.zeros_initializer())
else:
return lambda x, name: BatchNorm('bn', x)
def resnet_shortcut(l, n_out, stride, nl=tf.identity):
data_format = get_arg_scope()['Conv2D']['data_format']
n_in = l.get_shape().as_list()[1 if data_format == 'NCHW' else 3]
if n_in != n_out: # change dimension when channel is not the same
if stride == 2 and 'group3' not in tf.get_variable_scope().name:
l = l[:,:,:-1,:-1]
return Conv2D('convshortcut', l, n_out, 1,
stride=stride, padding='VALID', nl=nl)
else:
return Conv2D('convshortcut', l, n_out, 1,
stride=stride, nl=nl)
else:
return l
def resnet_bottleneck(l, ch_out, stride):
l, shortcut = l, l
l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
if stride == 2 and 'group3' not in tf.get_variable_scope().name:
l = tf.pad(l, [[0,0],[0,0],[0,1],[0,1]])
l = Conv2D('conv2', l, ch_out, 3, stride=2, nl=BNReLU, padding='VALID')
else:
l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
l = Conv2D('conv3', l, ch_out * 4, 1, nl=get_bn(zero_init=True))
return l + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False))
def resnet_group(l, name, block_func, features, count, stride):
with tf.variable_scope(name):
for i in range(0, count):
with tf.variable_scope('block{}'.format(i)):
l = block_func(l, features,
stride if i == 0 else 1)
# end of each block needs an activation
l = tf.nn.relu(l)
return l
def pretrained_resnet_conv4(image, num_blocks):
assert len(num_blocks) == 3
with argscope([Conv2D, MaxPooling, BatchNorm], data_format='NCHW'), \
argscope(Conv2D, nl=tf.identity, use_bias=False), \
argscope(BatchNorm, use_local_stat=False):
l = tf.pad(image, [[0,0],[0,0],[2,3],[2,3]])
l = Conv2D('conv0', l, 64, 7, stride=2, nl=BNReLU, padding='VALID')
l = tf.pad(l, [[0,0],[0,0],[0,1],[0,1]])
l = MaxPooling('pool0', l, shape=3, stride=2, padding='VALID')
l = resnet_group(l, 'group0', resnet_bottleneck, 64, num_blocks[0], 1)
# TODO replace var by const to enable folding
l = tf.stop_gradient(l)
l = resnet_group(l, 'group1', resnet_bottleneck, 128, num_blocks[1], 2)
l = resnet_group(l, 'group2', resnet_bottleneck, 256, num_blocks[2], 2)
# 16x downsampling up to now
return l
def resnet_conv5(image):
with argscope([Conv2D, GlobalAvgPooling, BatchNorm], data_format='NCHW'), \
argscope(Conv2D, nl=tf.identity, use_bias=False), \
argscope(BatchNorm, use_local_stat=False):
# 14x14:
l = resnet_group(image, 'group3', resnet_bottleneck, 512, 3, stride=2)
l = GlobalAvgPooling('gap', l)
return l
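# Composition sketch (hedged; model.py/train.py are not shown in this file): for ResNet50
# the backbone uses num_blocks=[3, 4, 6], i.e. the conv2_x-conv4_x stages, giving the
# 16x-downsampled feature map, and resnet_conv5 (group3) is applied on RoI-pooled features
# as the Fast-RCNN head:
#   featuremap = pretrained_resnet_conv4(image, [3, 4, 6])
#   head = resnet_conv5(roi_pooled)   # 'roi_pooled' is a hypothetical 14x14 RoI feature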
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: coco.py
import numpy as np
import os
import six
from termcolor import colored
from tabulate import tabulate
from tensorpack.dataflow import DataFromList
from tensorpack.utils import logger
from tensorpack.utils.rect import FloatBox
from tensorpack.utils.timer import timed_operation
from pycocotools.coco import COCO
__all__ = ['COCODetection', 'COCOMeta']
COCO_NUM_CATEGORY = 80
class _COCOMeta(object):
INSTANCE_TO_BASEDIR = {
'train2014': 'train2014',
'val2014': 'val2014',
'valminusminival2014': 'val2014',
'minival2014': 'val2014',
'test2014': 'test2014'
}
def valid(self):
return hasattr(self, 'cat_names')
def create(self, cat_ids, cat_names):
"""
cat_ids: list of ids
cat_names: list of names
"""
assert not self.valid()
assert len(cat_ids) == COCO_NUM_CATEGORY and len(cat_names) == COCO_NUM_CATEGORY
self.cat_names = cat_names
self.class_names = ['BG'] + self.cat_names
# background has class id of 0
self.category_id_to_class_id = {
v: i + 1 for i, v in enumerate(cat_ids)}
self.class_id_to_category_id = {
v: k for k, v in self.category_id_to_class_id.items()}
COCOMeta = _COCOMeta()
class COCODetection(object):
def __init__(self, basedir, name):
assert name in COCOMeta.INSTANCE_TO_BASEDIR.keys(), name
self.name = name
self._imgdir = os.path.join(basedir, COCOMeta.INSTANCE_TO_BASEDIR[name])
assert os.path.isdir(self._imgdir), self._imgdir
annotation_file = os.path.join(
basedir, 'annotations/instances_{}.json'.format(name))
assert os.path.isfile(annotation_file), annotation_file
self.coco = COCO(annotation_file)
# initialize the meta
cat_ids = self.coco.getCatIds()
cat_names = [c['name'] for c in self.coco.loadCats(cat_ids)]
if not COCOMeta.valid():
COCOMeta.create(cat_ids, cat_names)
else:
assert COCOMeta.cat_names == cat_names
logger.info("Instances loaded from {}.".format(annotation_file))
def load(self, add_gt=True):
"""
Args:
add_gt: whether to add ground truth annotations to the dicts
Returns:
a list of dict, each has keys including:
height, width, id, file_name,
and (if add_gt is True) boxes, class, is_crowd
"""
with timed_operation('Load Groundtruth Boxes for {}'.format(self.name)):
img_ids = self.coco.getImgIds()
img_ids.sort()
# list of dict, each has keys: height,width,id,file_name
imgs = self.coco.loadImgs(img_ids)
for img in imgs:
self._use_absolute_file_name(img)
if add_gt:
self._add_detection_gt(img)
return imgs
def _use_absolute_file_name(self, img):
"""
Change relative filename to absolute file name.
"""
img['file_name'] = os.path.join(
self._imgdir, img['file_name'])
assert os.path.isfile(img['file_name']), img['file_name']
def _add_detection_gt(self, img):
"""
Add 'boxes', 'class', 'is_crowd' of this image to the dict, used by detection.
"""
ann_ids = self.coco.getAnnIds(imgIds=img['id'], iscrowd=None)
objs = self.coco.loadAnns(ann_ids)
# clean-up boxes
valid_objs = []
width = img['width']
height = img['height']
for obj in objs:
if obj.get('ignore', 0) == 1:
continue
x1, y1, w, h = obj['bbox']
# bbox is originally in float
# NOTE: assume in data that x1/y1 means upper-left corner and w/h means true w/h
# assume that (0.0, 0.0) is upper-left corner of the first pixel
box = FloatBox(float(x1), float(y1),
float(x1 + w), float(y1 + h))
box.clip_by_shape([height, width])
# Require non-zero seg area and more than 1x1 box size
if obj['area'] > 0 and box.is_box() and box.area() >= 4:
obj['bbox'] = [box.x1, box.y1, box.x2, box.y2]
valid_objs.append(obj)
# all geometrically-valid boxes are returned
boxes = np.asarray([obj['bbox'] for obj in valid_objs], dtype='float32') # (n, 4)
cls = np.asarray([
COCOMeta.category_id_to_class_id[obj['category_id']]
for obj in valid_objs], dtype='int32') # (n,)
is_crowd = np.asarray([obj['iscrowd'] for obj in valid_objs], dtype='int8')
# add the keys
img['boxes'] = boxes # nx4
img['class'] = cls # n, always >0
img['is_crowd'] = is_crowd # n,
def print_class_histogram(self, imgs):
nr_class = len(COCOMeta.class_names)
hist_bins = np.arange(nr_class + 1)
# Histogram of ground-truth objects
gt_hist = np.zeros((nr_class,), dtype=np.int)
for entry in imgs:
# filter crowd?
gt_inds = np.where(
(entry['class'] > 0) & (entry['is_crowd'] == 0))[0]
gt_classes = entry['class'][gt_inds]
gt_hist += np.histogram(gt_classes, bins=hist_bins)[0]
data = [[COCOMeta.class_names[i], v] for i, v in enumerate(gt_hist)]
data.append(['total', sum([x[1] for x in data])])
table = tabulate(data, headers=['class', '#box'], tablefmt='pipe')
logger.info("Ground-Truth Boxes:\n" + colored(table, 'cyan'))
@staticmethod
def load_many(basedir, names, add_gt=True):
"""
Load and merges several instance files together.
"""
if not isinstance(names, (list, tuple)):
names = [names]
ret = []
for n in names:
coco = COCODetection(basedir, n)
ret.extend(coco.load(add_gt))
return ret
if __name__ == '__main__':
    # quick sanity check; point the first argument at your COCO BASEDIR
    c = COCODetection('/path/to/your/COCO/DIR', 'train2014')
    imgs = c.load()
    print("#Images:", len(imgs))
    c.print_class_histogram(imgs)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: common.py
import numpy as np
import cv2
from tensorpack.dataflow import RNGDataFlow
from tensorpack.dataflow.imgaug import transform
from tensorpack.utils import logger
import config
class DataFromListOfDict(RNGDataFlow):
def __init__(self, lst, keys, shuffle=False):
self._lst = lst
self._keys = keys
self._shuffle = shuffle
self._size = len(lst)
def size(self):
return self._size
def get_data(self):
if self._shuffle:
self.rng.shuffle(self._lst)
for dic in self._lst:
dp = [dic[k] for k in self._keys]
yield dp
class CustomResize(transform.TransformAugmentorBase):
"""
Try resizing the shortest edge to a certain number
while making sure the longest edge does not exceed max_size.
"""
def __init__(self, size, max_size, interp=cv2.INTER_LINEAR):
"""
Args:
size (int): the size to resize the shortest edge to.
max_size (int): maximum allowed longest edge.
"""
self._init(locals())
def _get_augment_params(self, img):
h, w = img.shape[:2]
scale = self.size * 1.0 / min(h, w)
if h < w:
newh, neww = self.size, scale * w
else:
newh, neww = scale * h, self.size
if max(newh, neww) > self.max_size:
scale = self.max_size * 1.0 / max(newh, neww)
newh = newh * scale
neww = neww * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)
return transform.ResizeTransform(h, w, newh, neww, self.interp)
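# Worked example (illustrative numbers only): with size=600 and max_size=1024, a 480x640
# image is scaled by 600/480=1.25 to 600x800, which fits; a 480x1280 image would first
# become 600x1600, exceed max_size, and be rescaled by 1024/1600=0.64 to 384x1024.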
def box_to_point8(boxes):
"""
Args:
boxes: nx4
Returns:
(nx4)x2
"""
b = boxes[:,[0,1,2,3,0,3,2,1]]
b = b.reshape((-1, 2))
return b
def point8_to_box(points):
"""
Args:
points: (nx4)x2
Returns:
nx4 boxes (x1y1x2y2)
"""
p = points.reshape((-1, 4, 2))
minxy = p.min(axis=1) #nx2
maxxy = p.max(axis=1) #nx2
return np.concatenate((minxy, maxxy), axis=1)
def clip_boxes(boxes, shape):
"""
Args:
boxes: nx4, float
shape: h, w
"""
h, w = shape
boxes[:,[0,1]] = np.maximum(boxes[:,[0,1]], 0)
boxes[:,2] = np.minimum(boxes[:,2], w)
boxes[:,3] = np.minimum(boxes[:,3], h)
return boxes
def print_config():
logger.info("Config: ------------------------------------------")
for k in dir(config):
if k == k.upper():
logger.info("{} = {}".format(k, getattr(config, k)))
logger.info("--------------------------------------------------")
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: config.py
import numpy as np
# dataset -----------------------
BASEDIR = '/path/to/your/COCO/DIR'
TRAIN_DATASET = ['train2014', 'valminusminival2014']
VAL_DATASET = 'minival2014' # only support evaluation on one dataset
NUM_CLASS = 81
# preprocessing --------------------
SHORT_EDGE_SIZE = 600
MAX_SIZE = 1024
# anchors -------------------------
ANCHOR_STRIDE = 16
# sqrt(area) of each anchor box
ANCHOR_SIZES = (32, 64, 128, 256, 512)
ANCHOR_RATIOS = (0.5, 1., 2.)
NR_ANCHOR = len(ANCHOR_SIZES) * len(ANCHOR_RATIOS)
POSITIVE_ANCHOR_THRES = 0.7
NEGATIVE_ANCHOR_THRES = 0.3
# rpn training -------------------------
# fg ratio among the RPN_BATCH_PER_IM sampled anchors is kept at most this value
RPN_FG_RATIO = 0.5
RPN_BATCH_PER_IM = 256
RPN_MIN_SIZE = 0
RPN_PROPOSAL_NMS_THRESH = 0.7
TRAIN_PRE_NMS_TOPK = 12000
TRAIN_POST_NMS_TOPK = 2000
# boxes overlapping crowd will be ignored.
CROWD_OVERLAP_THRES = 0.7
# fastrcnn training ---------------------
FASTRCNN_BATCH_PER_IM = 64
FASTRCNN_BBOX_REG_WEIGHTS = np.array([10, 10, 5, 5], dtype='float32')
FASTRCNN_FG_THRESH = 0.5
# keep fg ratio in a batch in this range
FASTRCNN_FG_RATIO = (0.1, 0.25)
# testing -----------------------
TEST_PRE_NMS_TOPK = 6000
TEST_POST_NMS_TOPK = 1000
FASTRCNN_NMS_THRESH = 0.5
RESULT_SCORE_THRESH = 0.05
RESULTS_PER_IM = 100
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: data.py
import cv2
import os
import numpy as np
import logging
from tensorpack.utils import logger
from tensorpack.utils.argtools import memoized, log_once
from tensorpack.dataflow import (
ProxyDataFlow, MapData, imgaug, TestDataSpeed,
AugmentImageComponents, MapDataComponent)
import tensorpack.utils.viz as tpviz
from tensorpack.utils.viz import interactive_imshow
from coco import COCODetection
from utils.generate_anchors import generate_anchors
from utils.box_ops import get_iou_callable
from common import (
DataFromListOfDict, CustomResize,
box_to_point8, point8_to_box)
import config
class MalformedData(BaseException):
pass
@memoized
def get_all_anchors():
"""
Get all anchors in the largest possible image, shifted, floatbox
Returns:
anchors: SxSxNR_ANCHORx4, where S == MAX_SIZE//STRIDE, floatbox
"""
# Generates a NAx4 matrix of anchor boxes in (x1, y1, x2, y2) format. Anchors
# are centered on stride / 2, have (approximate) sqrt areas of the specified
# sizes, and aspect ratios as given.
cell_anchors = generate_anchors(
config.ANCHOR_STRIDE,
scales=np.array(config.ANCHOR_SIZES, dtype=np.float) / config.ANCHOR_STRIDE,
ratios=np.array(config.ANCHOR_RATIOS, dtype=np.float))
# anchors are intbox here.
# anchors at featuremap [0,0] are centered at fpcoor (8,8) (half of stride)
field_size = config.MAX_SIZE // config.ANCHOR_STRIDE
shifts = np.arange(0, field_size) * config.ANCHOR_STRIDE
shift_x, shift_y = np.meshgrid(shifts, shifts)
shift_x = shift_x.flatten()
shift_y = shift_y.flatten()
shifts = np.vstack((shift_x, shift_y, shift_x, shift_y)).transpose()
# Kx4, K = field_size * field_size
K = shifts.shape[0]
A = cell_anchors.shape[0]
field_of_anchors = (
cell_anchors.reshape((1, A, 4)) +
shifts.reshape((1, K, 4)).transpose((1, 0, 2)))
field_of_anchors = field_of_anchors.reshape((field_size, field_size, A, 4))
# FSxFSxAx4
assert np.all(field_of_anchors == field_of_anchors.astype('int32'))
field_of_anchors = field_of_anchors.astype('float32')
field_of_anchors[:,:,:,[2,3]] += 1
return field_of_anchors
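# Shape check (using the default config values): field_size = MAX_SIZE // ANCHOR_STRIDE
# = 1024 // 16 = 64 and NR_ANCHOR = 5 sizes * 3 ratios = 15, so the returned array has
# shape (64, 64, 15, 4), in floatbox (x1, y1, x2, y2) coordinates of the input image.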
def get_anchor_labels(anchors, gt_boxes, crowd_boxes):
"""
Label each anchor as fg/bg/ignore.
Args:
anchors: Ax4 float
gt_boxes: Bx4 float
crowd_boxes: Cx4 float
Returns:
anchor_labels: (A,) int. Each element is {-1, 0, 1}
anchor_boxes: Ax4. Contains the target gt_box for each anchor when the anchor is fg.
"""
# This function will modify labels and return the filtered inds
def filter_box_label(labels, value, max_num):
curr_inds = np.where(labels == value)[0]
if len(curr_inds) > max_num:
disable_inds = np.random.choice(
curr_inds, size=(len(curr_inds) - max_num),
replace=False)
labels[disable_inds] = -1 # ignore them
curr_inds = np.where(labels == value)[0]
return curr_inds
bbox_iou_float = get_iou_callable()
NA, NB = len(anchors), len(gt_boxes)
assert NB > 0 # empty images should have been filtered already
box_ious = bbox_iou_float(anchors, gt_boxes) # NA x NB
ious_argmax_per_anchor = box_ious.argmax(axis=1) # NA,
ious_max_per_anchor = box_ious.max(axis=1)
ious_max_per_gt = np.amax(box_ious, axis=0, keepdims=True) # 1xNB
# for each gt, find all anchors (including ties) that have the max iou with it
anchors_with_max_iou_per_gt = np.where(box_ious == ious_max_per_gt)[0]
# Setting NA labels: 1--fg 0--bg -1--ignore
anchor_labels = -np.ones((NA,), dtype='int32') # NA,
# the order of setting neg/pos labels matter
anchor_labels[anchors_with_max_iou_per_gt] = 1
anchor_labels[ious_max_per_anchor >= config.POSITIVE_ANCHOR_THRES] = 1
anchor_labels[ious_max_per_anchor < config.NEGATIVE_ANCHOR_THRES] = 0
# First label all non-ignore candidate boxes which overlap crowd as ignore
if crowd_boxes.size > 0:
cand_inds = np.where(anchor_labels >= 0)[0]
cand_anchors = anchors[cand_inds]
ious = bbox_iou_float(cand_anchors, crowd_boxes)
overlap_with_crowd = cand_inds[ious.max(axis=1) > config.CROWD_OVERLAP_THRES]
anchor_labels[overlap_with_crowd] = -1
# Filter fg labels: ignore some fg if fg is too many
old_num_fg = np.sum(anchor_labels == 1)
target_num_fg = int(config.RPN_BATCH_PER_IM * config.RPN_FG_RATIO)
fg_inds = filter_box_label(anchor_labels, 1, target_num_fg)
# Note that fg could be fewer than the target ratio
# filter bg labels. num_bg is not allowed to be too many
old_num_bg = np.sum(anchor_labels == 0)
if old_num_bg == 0 or len(fg_inds) == 0:
# No valid bg/fg in this image, skip.
# This can happen if, e.g. the image has large crowd.
raise MalformedData("No valid foreground/background for RPN!")
target_num_bg = config.RPN_BATCH_PER_IM - len(fg_inds)
bg_inds = filter_box_label(anchor_labels, 0, target_num_bg)
# Set anchor boxes: the best gt_box for each fg anchor
anchor_boxes = np.zeros((NA, 4), dtype='float32')
fg_boxes = gt_boxes[ious_argmax_per_anchor[fg_inds],:]
anchor_boxes[fg_inds, :] = fg_boxes
return anchor_labels, anchor_boxes
def get_rpn_anchor_input(im, boxes, klass, is_crowd):
"""
Args:
im: an image
boxes: nx4, floatbox, gt. shouldn't be changed
klass: n,
is_crowd: n,
Returns:
The anchor labels and target boxes for each pixel in the featuremap.
fm_labels: fHxfWxNA
fm_boxes: fHxfWxNAx4
"""
boxes = boxes.copy()
ALL_ANCHORS = get_all_anchors()
H, W = im.shape[:2]
featureH, featureW = H // config.ANCHOR_STRIDE, W // config.ANCHOR_STRIDE
def filter_box_inside(im, boxes):
h, w = im.shape[:2]
indices = np.where(
(boxes[:,0] >= 0) &
(boxes[:,1] >= 0) &
(boxes[:,2] <= w) &
(boxes[:,3] <= h))[0]
return indices
crowd_boxes = boxes[is_crowd == 1]
non_crowd_boxes = boxes[is_crowd == 0]
# fHxfWxAx4
featuremap_anchors = ALL_ANCHORS[:featureH,:featureW,:,:]
featuremap_anchors_flatten = featuremap_anchors.reshape((-1, 4))
# only use anchors inside the image
inside_ind = filter_box_inside(im, featuremap_anchors_flatten)
inside_anchors = featuremap_anchors_flatten[inside_ind,:]
anchor_labels, anchor_boxes = get_anchor_labels(inside_anchors, non_crowd_boxes, crowd_boxes)
# Fill them back to original size: fHxfWx1, fHxfWx4
featuremap_labels = -np.ones((featureH * featureW * config.NR_ANCHOR, ), dtype='int32')
featuremap_labels[inside_ind] = anchor_labels
featuremap_labels = featuremap_labels.reshape((featureH, featureW, config.NR_ANCHOR))
featuremap_boxes = np.zeros((featureH * featureW * config.NR_ANCHOR, 4), dtype='float32')
featuremap_boxes[inside_ind, :] = anchor_boxes
featuremap_boxes = featuremap_boxes.reshape((featureH, featureW, config.NR_ANCHOR, 4))
return featuremap_labels, featuremap_boxes
def read_and_augment_images(ds):
def mapf(dp):
fname = dp[0]
im = cv2.imread(fname, cv2.IMREAD_COLOR).astype('float32')
assert im is not None, dp[0]
dp[0] = im
# assume floatbox as input
assert dp[1].dtype == np.float32
dp[1] = box_to_point8(dp[1])
dp.append(fname)
return dp
ds = MapData(ds, mapf)
augs = [CustomResize(config.SHORT_EDGE_SIZE, config.MAX_SIZE),
imgaug.Flip(horiz=True)]
ds = AugmentImageComponents(ds, augs, index=(0,), coords_index=(1,))
def unmapf(points):
boxes = point8_to_box(points)
return boxes
ds = MapDataComponent(ds, unmapf, 1)
return ds
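# Note (explanatory sketch): gt boxes are expanded to 8 corner points (box_to_point8)
# so that coordinate-aware augmentors such as the horizontal flip move every corner with
# the image; point8_to_box then takes the per-box min/max to recover valid x1y1x2y2 boxes.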
def get_train_dataflow():
imgs = COCODetection.load_many(config.BASEDIR, config.TRAIN_DATASET)
# Valid training images should have at least one fg box.
# But this filter shall not be applied for testing.
imgs = list(filter(lambda img: len(img['boxes']) > 0, imgs)) # log invalid training
ds = DataFromListOfDict(
imgs,
['file_name', 'boxes', 'class', 'is_crowd'], # we only need these four keys
shuffle=True)
ds = read_and_augment_images(ds)
def add_anchor_to_dp(dp):
im, boxes, klass, is_crowd, fname = dp
try:
fm_labels, fm_boxes = get_rpn_anchor_input(im, boxes, klass, is_crowd)
boxes = boxes[is_crowd == 0] # skip crowd boxes in training target
klass = klass[is_crowd == 0]
if not len(boxes):
raise MalformedData("No valid gt_boxes!")
except MalformedData as e:
log_once("Input {} is invalid for training: {}".format(fname, str(e)), 'warn')
return None
return [im, fm_labels, fm_boxes, boxes, klass]
ds = MapData(ds, add_anchor_to_dp)
return ds
def get_eval_dataflow():
imgs = COCODetection.load_many(config.BASEDIR, config.VAL_DATASET, add_gt=False)
# unlike the training dataflow, no filtering is applied here
ds = DataFromListOfDict(imgs, ['file_name', 'id'])
def f(fname):
im = cv2.imread(fname, cv2.IMREAD_COLOR)
assert im is not None, fname
return im
ds = MapDataComponent(ds, f, 0)
return ds
if __name__ == '__main__':
    from tensorpack.dataflow import PrintData
    ds = get_train_dataflow()
    ds = PrintData(ds, 100)
    TestDataSpeed(ds, 50000).start()
    ds.reset_state()
    for k in ds.get_data():
        pass
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: eval.py
import numpy as np
import tqdm
import cv2
import os
from collections import namedtuple
import tensorflow as tf
from tensorpack.dataflow import MapDataComponent, TestDataSpeed
from tensorpack.tfutils import get_default_sess_config
from tensorpack.utils.argtools import memoized
from tensorpack.utils.utils import get_tqdm_kwargs
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from coco import COCODetection, COCOMeta
from common import clip_boxes, DataFromListOfDict, CustomResize
import config
DetectionResult = namedtuple(
'DetectionResult',
['class_id', 'boxes', 'scores'])
@memoized
def get_tf_nms():
"""
Get a NMS callable.
"""
boxes = tf.placeholder(tf.float32, shape=[None, 4])
scores = tf.placeholder(tf.float32, shape=[None])
indices = tf.image.non_max_suppression(
boxes, scores,
config.RESULTS_PER_IM, config.FASTRCNN_NMS_THRESH)
sess = tf.Session(config=get_default_sess_config())
return sess.make_callable(indices, [boxes, scores])
def nms_fastrcnn_results(boxes, probs):
"""
Args:
boxes: nx4 floatbox in float32
probs: nxC
Returns:
[DetectionResult]
"""
C = probs.shape[1]
boxes = boxes.copy()
boxes_per_class = {}
nms_func = get_tf_nms()
ret = []
for klass in range(1, C):
ids = np.where(probs[:, klass] > config.RESULT_SCORE_THRESH)[0]
if ids.size == 0:
continue
probs_k = probs[ids, klass].flatten()
boxes_k = boxes[ids,:]
selected_ids = nms_func(boxes_k[:,[1,0,3,2]], probs_k)
selected_boxes = boxes_k[selected_ids, :].copy()
ret.append(DetectionResult(klass, selected_boxes, probs_k[selected_ids]))
if len(ret):
newret = []
all_scores = np.hstack([x.scores for x in ret])
if len(all_scores) > config.RESULTS_PER_IM:
score_thresh = np.sort(all_scores)[-config.RESULTS_PER_IM]
for klass, boxes, scores in ret:
keep_ids = np.where(scores >= score_thresh)[0]
if len(keep_ids):
newret.append(DetectionResult(
klass, boxes[keep_ids,:], scores[keep_ids]))
ret = newret
return ret
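# Note (explanatory sketch): detections are first thresholded with RESULT_SCORE_THRESH and
# NMS-ed per class; then, if the image still has more than RESULTS_PER_IM (100) detections
# in total, only the globally highest-scoring ones across all classes are kept.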
def detect_one_image(img, model_func):
"""
Run detection on one image, using the TF callable.
This function should handle the preprocessing internally.
Args:
img: an image
model_func: a callable from TF model, takes [image] and returns (probs, boxes)
Returns:
[DetectionResult]
"""
resizer = CustomResize(config.SHORT_EDGE_SIZE, config.MAX_SIZE)
resized_img = resizer.augment(img)
scale = (resized_img.shape[0] * 1.0 / img.shape[0] + resized_img.shape[1] * 1.0 / img.shape[1]) / 2
fg_probs, fg_boxes = model_func([resized_img])
fg_boxes = fg_boxes / scale
fg_boxes = clip_boxes(fg_boxes, img.shape[:2])
return nms_fastrcnn_results(fg_boxes, fg_probs)
def eval_on_dataflow(df, detect_func):
"""
Args:
df: a DataFlow which produces (image, image_id)
detect_func: a callable, takes [image] and returns a dict
Returns:
list of dict, to be dumped to COCO json format
"""
df.reset_state()
all_results = []
with tqdm.tqdm(total=df.size(), **get_tqdm_kwargs()) as pbar:
for img, img_id in df.get_data():
results = detect_func(img)
for classid, boxes, scores in results:
cat_id = COCOMeta.class_id_to_category_id[classid]
boxes[:,2] -= boxes[:,0]
boxes[:,3] -= boxes[:,1]
for box, score in zip(boxes, scores):
all_results.append({
'image_id': img_id,
'category_id': cat_id,
'bbox': list(map(lambda x: float(round(x, 1)), box)),
'score': float(round(score, 2)),
})
pbar.update(1)
return all_results
# https://github.com/pdollar/coco/blob/master/PythonAPI/pycocoEvalDemo.ipynb
def print_evaluation_scores(json_file):
assert config.BASEDIR and os.path.isdir(config.BASEDIR)
annofile = os.path.join(
config.BASEDIR, 'annotations',
'instances_{}.json'.format(config.VAL_DATASET))
coco = COCO(annofile)
cocoDt = coco.loadRes(json_file)
imgIds = sorted(coco.getImgIds())
cocoEval = COCOeval(coco, cocoDt, 'bbox')
cocoEval.params.imgIds = imgIds
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
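# End-to-end sketch (an assumption about how train.py, which is not shown in this file,
# chains these pieces together):
#   results = eval_on_dataflow(get_eval_dataflow(), lambda img: detect_one_image(img, pred))
#   with open('output.json', 'w') as f:
#       json.dump(results, f)
#   print_evaluation_scores('output.json')
# where get_eval_dataflow comes from data.py and 'pred' is a trained-model callable.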
if __name__ == '__main__':
    from data import get_eval_dataflow
    ds = get_eval_dataflow()
    print("Size: ", ds.size())
    TestDataSpeed(ds, 1000).start()
# Some third-party helper functions
+ generate_anchors.py: copied from [py-faster-rcnn](https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/rpn/generate_anchors.py).
+ box_ops.py: modified from [TF object detection API](https://github.com/tensorflow/models/blob/master/object_detection/core/box_list_ops.py).
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: box_ops.py
import tensorflow as tf
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.tfutils import get_default_sess_config
from tensorpack.utils.argtools import memoized
"""
This file is modified from
https://github.com/tensorflow/models/blob/master/object_detection/core/box_list_ops.py
"""
@under_name_scope()
def area(boxes):
"""
Args:
boxes: nx4 floatbox
Returns:
n
"""
x_min, y_min, x_max, y_max = tf.split(boxes, 4, axis=1)
return tf.squeeze((y_max - y_min) * (x_max - x_min), [1])
@under_name_scope()
def pairwise_intersection(boxlist1, boxlist2):
"""Compute pairwise intersection areas between boxes.
Args:
boxlist1: Nx4 floatbox
boxlist2: Mx4
Returns:
a tensor with shape [N, M] representing pairwise intersections
"""
x_min1, y_min1, x_max1, y_max1 = tf.split(boxlist1, 4, axis=1)
x_min2, y_min2, x_max2, y_max2 = tf.split(boxlist2, 4, axis=1)
all_pairs_min_ymax = tf.minimum(y_max1, tf.transpose(y_max2))
all_pairs_max_ymin = tf.maximum(y_min1, tf.transpose(y_min2))
intersect_heights = tf.maximum(0.0, all_pairs_min_ymax - all_pairs_max_ymin)
all_pairs_min_xmax = tf.minimum(x_max1, tf.transpose(x_max2))
all_pairs_max_xmin = tf.maximum(x_min1, tf.transpose(x_min2))
intersect_widths = tf.maximum(0.0, all_pairs_min_xmax - all_pairs_max_xmin)
return intersect_heights * intersect_widths
@under_name_scope()
def pairwise_iou(boxlist1, boxlist2):
"""Computes pairwise intersection-over-union between box collections.
Args:
boxlist1: Nx4 floatbox
boxlist2: Mx4
Returns:
a tensor with shape [N, M] representing pairwise iou scores.
"""
intersections = pairwise_intersection(boxlist1, boxlist2)
areas1 = area(boxlist1)
areas2 = area(boxlist2)
unions = (
tf.expand_dims(areas1, 1) + tf.expand_dims(areas2, 0) - intersections)
return tf.where(
tf.equal(intersections, 0.0),
tf.zeros_like(intersections), tf.truediv(intersections, unions))
@memoized
def get_iou_callable():
"""
Get a pairwise box iou callable.
"""
with tf.device('/cpu:0'):
A = tf.placeholder(tf.float32, shape=[None, 4])
B = tf.placeholder(tf.float32, shape=[None, 4])
iou = pairwise_iou(A, B)
sess = tf.Session(config=get_default_sess_config())
return sess.make_callable(iou, [A, B])
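# Usage sketch (mirrors how data.py uses it): the callable takes two float32 numpy arrays
# of boxes and returns their pairwise IoU matrix:
#   bbox_iou = get_iou_callable()
#   ious = bbox_iou(anchors, gt_boxes)   # shape (len(anchors), len(gt_boxes))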
# https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/rpn/generate_anchors.py
# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick and Sean Bell
# --------------------------------------------------------
from six.moves import range
import numpy as np
# Verify that we compute the same anchors as Shaoqing's matlab implementation:
#
# >> load output/rpn_cachedir/faster_rcnn_VOC2007_ZF_stage1_rpn/anchors.mat
# >> anchors
#
# anchors =
#
# -83 -39 100 56
# -175 -87 192 104
# -359 -183 376 200
# -55 -55 72 72
# -119 -119 136 136
# -247 -247 264 264
# -35 -79 52 96
# -79 -167 96 184
# -167 -343 184 360
#array([[ -83., -39., 100., 56.],
# [-175., -87., 192., 104.],
# [-359., -183., 376., 200.],
# [ -55., -55., 72., 72.],
# [-119., -119., 136., 136.],
# [-247., -247., 264., 264.],
# [ -35., -79., 52., 96.],
# [ -79., -167., 96., 184.],
# [-167., -343., 184., 360.]])
def generate_anchors(base_size=16, ratios=[0.5, 1, 2],
scales=2**np.arange(3, 6)):
"""
Generate anchor (reference) windows by enumerating aspect ratios X
scales wrt a reference (0, 0, 15, 15) window.
"""
base_anchor = np.array([1, 1, base_size, base_size], dtype='float32') - 1
ratio_anchors = _ratio_enum(base_anchor, ratios)
anchors = np.vstack([_scale_enum(ratio_anchors[i, :], scales)
for i in range(ratio_anchors.shape[0])])
return anchors
def _whctrs(anchor):
"""
Return width, height, x center, and y center for an anchor (window).
"""
w = anchor[2] - anchor[0] + 1
h = anchor[3] - anchor[1] + 1
x_ctr = anchor[0] + 0.5 * (w - 1)
y_ctr = anchor[1] + 0.5 * (h - 1)
return w, h, x_ctr, y_ctr
def _mkanchors(ws, hs, x_ctr, y_ctr):
"""
Given a vector of widths (ws) and heights (hs) around a center
(x_ctr, y_ctr), output a set of anchors (windows).
"""
ws = ws[:, np.newaxis]
hs = hs[:, np.newaxis]
anchors = np.hstack((x_ctr - 0.5 * (ws - 1),
y_ctr - 0.5 * (hs - 1),
x_ctr + 0.5 * (ws - 1),
y_ctr + 0.5 * (hs - 1)))
return anchors
def _ratio_enum(anchor, ratios):
"""
Enumerate a set of anchors for each aspect ratio wrt an anchor.
"""
w, h, x_ctr, y_ctr = _whctrs(anchor)
size = w * h
size_ratios = size / ratios
ws = np.round(np.sqrt(size_ratios))
hs = np.round(ws * ratios)
anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
return anchors
def _scale_enum(anchor, scales):
"""
Enumerate a set of anchors for each scale wrt an anchor.
"""
w, h, x_ctr, y_ctr = _whctrs(anchor)
ws = w * scales
hs = h * scales
anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
return anchors
if __name__ == '__main__':
#import time
#t = time.time()
#a = generate_anchors()
#print(time.time() - t)
#print(a)
#from IPython import embed; embed()
print(generate_anchors(
16, scales=np.asarray((2, 4, 8, 16, 32), 'float32'),
ratios=[0.5,1,2]))
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: viz.py
from six.moves import zip
import numpy as np
from tensorpack.utils import viz
from coco import COCOMeta
from utils.box_ops import get_iou_callable
def draw_annotation(img, boxes, klass, is_crowd=None):
labels = []
assert len(boxes) == len(klass)
if is_crowd is not None:
assert len(boxes) == len(is_crowd)
for cls, crd in zip(klass, is_crowd):
clsname = COCOMeta.class_names[cls]
if crd == 1:
clsname += ';Crowd'
labels.append(clsname)
else:
for cls in klass:
labels.append(COCOMeta.class_names[cls])
img = viz.draw_boxes(img, boxes, labels)
return img
def draw_proposal_recall(img, proposals, proposal_scores, gt_boxes):
"""
Draw top3 proposals for each gt.
Args:
proposals: NPx4
proposal_scores: NP
gt_boxes: NGx4
"""
bbox_iou_float = get_iou_callable()
box_ious = bbox_iou_float(gt_boxes, proposals) #ng x np
box_ious_argsort = np.argsort(-box_ious, axis=1)
good_proposals_ind = box_ious_argsort[:,:3] # for each gt, find 3 best proposals
good_proposals_ind = np.unique(good_proposals_ind.ravel())
proposals = proposals[good_proposals_ind,:]
tags = list(map(str, proposal_scores[good_proposals_ind]))
img = viz.draw_boxes(img, proposals, tags)
return img, good_proposals_ind
def draw_predictions(img, boxes, scores):
"""
Args:
boxes: kx4
scores: kxC
"""
if len(boxes) == 0:
return img
labels = scores.argmax(axis=1)
scores = scores.max(axis=1)
tags = ["{},{:.2f}".format(COCOMeta.class_names[lb], score) for lb, score in zip(labels, scores)]
return viz.draw_boxes(img, boxes, tags)
def draw_final_outputs(img, results):
"""
Args:
results: [DetectionResult]
"""
all_boxes = []
all_tags = []
for class_id, boxes, scores in results:
all_boxes.extend(boxes)
all_tags.extend(
["{},{:.2f}".format(COCOMeta.class_names[class_id], sc) for sc in scores])
all_boxes = np.asarray(all_boxes)
if all_boxes.shape[0] == 0:
return img
return viz.draw_boxes(img, all_boxes, all_tags)