Commit 16581e74 authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] refactor; split functions to model_box.py

parent 77cee2b2
......@@ -4,8 +4,9 @@ This is a minimal implementation that simply contains these files:
+ coco.py: load COCO data
+ data.py: prepare data for training
+ common.py: common data preparation utilities
+ basemodel.py: implement resnet
+ model.py: implement RPN/Faster-RCNN/RPN/Mask-RCNN
+ basemodel.py: implement backbones
+ model_box.py: implement box-related symbolic functions
+ model.py: implement RPN/Faster-RCNN/FPN/Mask-RCNN
+ train.py: main training script
+ utils/: third-party helper functions
+ eval.py: evaluation utilities
......
......@@ -6,7 +6,7 @@ This example provides a minimal (only 1.6k lines) and faithful implementation of
+ [Mask R-CNN](https://arxiv.org/abs/1703.06870)
## Dependencies
+ Python 3; TensorFlow >= 1.4.0 (>=1.6.0 recommended due to a TF bug);
+ Python 3; TensorFlow >= 1.6 (1.4 or 1.5 can run but may crash due to a TF bug);
+ [pycocotools](https://github.com/pdollar/coco/tree/master/PythonAPI/pycocotools), OpenCV.
+ Pre-trained [ImageNet ResNet model](http://models.tensorpack.com/ResNet/) from tensorpack model zoo.
+ COCO data. It needs to have the following directory structure:
......
......@@ -69,6 +69,10 @@ FASTRCNN_FG_RATIO = 0.25 # fg ratio in a ROI batch
# modeling -------------------------
FPN_NUM_CHANNEL = 256
# conv head and fc head are only used in FPN.
# For C4 models, the head is C5
FPN_FASTRCNN_HEAD_FUNC = 'fastrcnn_2fc_head' # choices: fastrcnn_2fc_head, fastrcnn_4conv1fc_head
FASTRCNN_CONV_HEAD_DIM = 256
FASTRCNN_FC_HEAD_DIM = 1024
MASKRCNN_HEAD_DIM = 256
......
......@@ -14,22 +14,10 @@ from tensorpack.models import (
from utils.box_ops import pairwise_iou
from utils.box_ops import area as tf_area
from model_box import roi_align, clip_boxes
import config
@under_name_scope()
def clip_boxes(boxes, window, name=None):
"""
Args:
boxes: nx4, xyxy
window: [h, w]
"""
boxes = tf.maximum(boxes, 0.0)
m = tf.tile(tf.reverse(window, [0]), [2]) # (4,)
boxes = tf.minimum(boxes, tf.to_float(m), name=name)
return boxes
@layer_register(log_shape=True)
@auto_reuse_variable_scope
def rpn_head(featuremap, channel, num_anchors):
......@@ -119,62 +107,6 @@ def rpn_losses(anchor_labels, anchor_boxes, label_logits, box_logits):
return label_loss, box_loss
@under_name_scope()
def decode_bbox_target(box_predictions, anchors):
"""
Args:
box_predictions: (..., 4), logits
anchors: (..., 4), floatbox. Must have the same shape
Returns:
box_decoded: (..., 4), float32. With the same shape.
"""
orig_shape = tf.shape(anchors)
box_pred_txtytwth = tf.reshape(box_predictions, (-1, 2, 2))
box_pred_txty, box_pred_twth = tf.split(box_pred_txtytwth, 2, axis=1)
# each is (...)x1x2
anchors_x1y1x2y2 = tf.reshape(anchors, (-1, 2, 2))
anchors_x1y1, anchors_x2y2 = tf.split(anchors_x1y1x2y2, 2, axis=1)
waha = anchors_x2y2 - anchors_x1y1
xaya = (anchors_x2y2 + anchors_x1y1) * 0.5
wbhb = tf.exp(tf.minimum(
box_pred_twth, config.BBOX_DECODE_CLIP)) * waha
xbyb = box_pred_txty * waha + xaya
x1y1 = xbyb - wbhb * 0.5
x2y2 = xbyb + wbhb * 0.5 # (...)x1x2
out = tf.concat([x1y1, x2y2], axis=-2)
return tf.reshape(out, orig_shape)
@under_name_scope()
def encode_bbox_target(boxes, anchors):
"""
Args:
boxes: (..., 4), float32
anchors: (..., 4), float32
Returns:
box_encoded: (..., 4), float32 with the same shape.
"""
anchors_x1y1x2y2 = tf.reshape(anchors, (-1, 2, 2))
anchors_x1y1, anchors_x2y2 = tf.split(anchors_x1y1x2y2, 2, axis=1)
waha = anchors_x2y2 - anchors_x1y1
xaya = (anchors_x2y2 + anchors_x1y1) * 0.5
boxes_x1y1x2y2 = tf.reshape(boxes, (-1, 2, 2))
boxes_x1y1, boxes_x2y2 = tf.split(boxes_x1y1x2y2, 2, axis=1)
wbhb = boxes_x2y2 - boxes_x1y1
xbyb = (boxes_x2y2 + boxes_x1y1) * 0.5
# Note that here not all boxes are valid. Some may be zero
txty = (xbyb - xaya) / waha
twth = tf.log(wbhb / waha) # may contain -inf for invalid boxes
encoded = tf.concat([txty, twth], axis=1) # (-1x2x2)
return tf.reshape(encoded, tf.shape(boxes))
@under_name_scope()
def generate_rpn_proposals(boxes, scores, img_shape,
pre_nms_topk, post_nms_topk=None):
......@@ -312,98 +244,6 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
tf.stop_gradient(fg_inds_wrt_gt)
@under_name_scope()
def crop_and_resize(image, boxes, box_ind, crop_size, pad_border=True):
"""
Aligned version of tf.image.crop_and_resize, following our definition of floating point boxes.
Args:
image: NCHW
boxes: nx4, x1y1x2y2
box_ind: (n,)
crop_size (int):
Returns:
n,C,size,size
"""
assert isinstance(crop_size, int), crop_size
# TF's crop_and_resize produces zeros on border
if pad_border:
# this can be quite slow
image = tf.pad(image, [[0, 0], [0, 0], [1, 1], [1, 1]], mode='SYMMETRIC')
boxes = boxes + 1
@under_name_scope()
def transform_fpcoor_for_tf(boxes, image_shape, crop_shape):
"""
The way tf.image.crop_and_resize works (with normalized box):
Initial point (the value of output[0]): x0_box * (W_img - 1)
Spacing: w_box * (W_img - 1) / (W_crop - 1)
Use the above grid to bilinear sample.
However, what we want is (with fpcoor box):
Spacing: w_box / W_crop
Initial point: x0_box + spacing/2 - 0.5
(-0.5 because bilinear sample assumes floating point coordinate (0.0, 0.0) is the same as pixel value (0, 0))
This function transform fpcoor boxes to a format to be used by tf.image.crop_and_resize
Returns:
y1x1y2x2
"""
x0, y0, x1, y1 = tf.split(boxes, 4, axis=1)
spacing_w = (x1 - x0) / tf.to_float(crop_shape[1])
spacing_h = (y1 - y0) / tf.to_float(crop_shape[0])
nx0 = (x0 + spacing_w / 2 - 0.5) / tf.to_float(image_shape[1] - 1)
ny0 = (y0 + spacing_h / 2 - 0.5) / tf.to_float(image_shape[0] - 1)
nw = spacing_w * tf.to_float(crop_shape[1] - 1) / tf.to_float(image_shape[1] - 1)
nh = spacing_h * tf.to_float(crop_shape[0] - 1) / tf.to_float(image_shape[0] - 1)
return tf.concat([ny0, nx0, ny0 + nh, nx0 + nw], axis=1)
# Expand bbox to a minium size of 1
# boxes_x1y1, boxes_x2y2 = tf.split(boxes, 2, axis=1)
# boxes_wh = boxes_x2y2 - boxes_x1y1
# boxes_center = tf.reshape((boxes_x2y2 + boxes_x1y1) * 0.5, [-1, 2])
# boxes_newwh = tf.maximum(boxes_wh, 1.)
# boxes_x1y1new = boxes_center - boxes_newwh * 0.5
# boxes_x2y2new = boxes_center + boxes_newwh * 0.5
# boxes = tf.concat([boxes_x1y1new, boxes_x2y2new], axis=1)
image_shape = tf.shape(image)[2:]
boxes = transform_fpcoor_for_tf(boxes, image_shape, [crop_size, crop_size])
image = tf.transpose(image, [0, 2, 3, 1]) # nhwc
ret = tf.image.crop_and_resize(
image, boxes, tf.to_int32(box_ind),
crop_size=[crop_size, crop_size])
ret = tf.transpose(ret, [0, 3, 1, 2]) # ncss
return ret
@under_name_scope()
def roi_align(featuremap, boxes, resolution):
"""
Args:
featuremap: 1xCxHxW
boxes: Nx4 floatbox
resolution: output spatial resolution
Returns:
NxCx res x res
"""
boxes = tf.stop_gradient(boxes) # TODO
# sample 4 locations per roi bin
ret = crop_and_resize(
featuremap, boxes,
tf.zeros([tf.shape(boxes)[0]], dtype=tf.int32),
resolution * 2)
ret = tf.nn.avg_pool(ret, [1, 1, 2, 2], [1, 1, 2, 2], padding='SAME', data_format='NCHW')
return ret
@layer_register(log_shape=True)
def fastrcnn_outputs(feature, num_classes):
"""
......@@ -436,11 +276,37 @@ def fastrcnn_2fc_head(feature, num_classes):
"""
dim = config.FASTRCNN_FC_HEAD_DIM
init = tf.variance_scaling_initializer()
hidden = FullyConnected('fc6', feature, dim, kernel_initializer=init, nl=tf.nn.relu)
hidden = FullyConnected('fc7', hidden, dim, kernel_initializer=init, nl=tf.nn.relu)
hidden = FullyConnected('fc6', feature, dim, kernel_initializer=init, activation=tf.nn.relu)
hidden = FullyConnected('fc7', hidden, dim, kernel_initializer=init, activation=tf.nn.relu)
return fastrcnn_outputs('outputs', hidden, num_classes)
@layer_register(log_shape=True)
def fastrcnn_Xconv1fc_head(feature, num_classes, num_convs):
"""
Args:
feature (any shape):
num_classes(int): num_category + 1
num_convs (int): number of conv layers
Returns:
cls_logits (Nxnum_class), reg_logits (Nx num_class-1 x 4)
"""
l = feature
with argscope(Conv2D, data_format='channels_first',
kernel_initializer=tf.variance_scaling_initializer(
scale=2.0, mode='fan_out', distribution='normal')):
for k in range(num_convs):
l = Conv2D('conv{}'.format(k), l, config.FASTRCNN_CONV_HEAD_DIM, 3, activation=tf.nn.relu)
l = FullyConnected('fc', l, config.FASTRCNN_FC_HEAD_DIM,
kernel_initializer=tf.variance_scaling_initializer(), activation=tf.nn.relu)
return fastrcnn_outputs('outputs', l, num_classes)
def fastrcnn_4conv1fc_head(*args, **kwargs):
return fastrcnn_Xconv1fc_head(*args, num_convs=4, **kwargs)
@under_name_scope()
def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
"""
......@@ -625,7 +491,7 @@ def fpn_model(features):
# return x
with argscope(Conv2D, data_format='channels_first',
nl=tf.identity, use_bias=True,
activation=tf.identity, use_bias=True,
kernel_initializer=tf.variance_scaling_initializer(scale=1.)):
lat_2345 = [Conv2D('lateral_1x1_c{}'.format(i + 2), c, num_channel, 1)
for i, c in enumerate(features)]
......@@ -703,32 +569,3 @@ def multilevel_roi_align(features, rcnn_boxes, resolution):
level_id_invert_perm = tf.invert_permutation(level_id_perm)
all_rois = tf.gather(all_rois, level_id_invert_perm)
return all_rois
if __name__ == '__main__':
"""
Demonstrate what's wrong with tf.image.crop_and_resize:
"""
import tensorflow.contrib.eager as tfe
tfe.enable_eager_execution()
# want to crop 2x2 out of a 5x5 image, and resize to 4x4
image = np.arange(25).astype('float32').reshape(5, 5)
boxes = np.asarray([[1, 1, 3, 3]], dtype='float32')
target = 4
print(crop_and_resize(
image[None, None, :, :], boxes, [0], target)[0][0])
"""
Expected values:
4.5 5 5.5 6
7 7.5 8 8.5
9.5 10 10.5 11
12 12.5 13 13.5
You cannot easily get the above results with tf.image.crop_and_resize.
Try out yourself here:
"""
print(tf.image.crop_and_resize(
image[None, :, :, None],
np.asarray([[1, 1, 2, 2]]) / 4.0, [0], [target, target])[0][:, :, 0])
# -*- coding: utf-8 -*-
# File: model_box.py
import tensorflow as tf
from tensorpack.tfutils.scope_utils import under_name_scope
import config
@under_name_scope()
def clip_boxes(boxes, window, name=None):
"""
Args:
boxes: nx4, xyxy
window: [h, w]
"""
boxes = tf.maximum(boxes, 0.0)
m = tf.tile(tf.reverse(window, [0]), [2]) # (4,)
boxes = tf.minimum(boxes, tf.to_float(m), name=name)
return boxes
@under_name_scope()
def decode_bbox_target(box_predictions, anchors):
"""
Args:
box_predictions: (..., 4), logits
anchors: (..., 4), floatbox. Must have the same shape
Returns:
box_decoded: (..., 4), float32. With the same shape.
"""
orig_shape = tf.shape(anchors)
box_pred_txtytwth = tf.reshape(box_predictions, (-1, 2, 2))
box_pred_txty, box_pred_twth = tf.split(box_pred_txtytwth, 2, axis=1)
# each is (...)x1x2
anchors_x1y1x2y2 = tf.reshape(anchors, (-1, 2, 2))
anchors_x1y1, anchors_x2y2 = tf.split(anchors_x1y1x2y2, 2, axis=1)
waha = anchors_x2y2 - anchors_x1y1
xaya = (anchors_x2y2 + anchors_x1y1) * 0.5
wbhb = tf.exp(tf.minimum(
box_pred_twth, config.BBOX_DECODE_CLIP)) * waha
xbyb = box_pred_txty * waha + xaya
x1y1 = xbyb - wbhb * 0.5
x2y2 = xbyb + wbhb * 0.5 # (...)x1x2
out = tf.concat([x1y1, x2y2], axis=-2)
return tf.reshape(out, orig_shape)
@under_name_scope()
def encode_bbox_target(boxes, anchors):
"""
Args:
boxes: (..., 4), float32
anchors: (..., 4), float32
Returns:
box_encoded: (..., 4), float32 with the same shape.
"""
anchors_x1y1x2y2 = tf.reshape(anchors, (-1, 2, 2))
anchors_x1y1, anchors_x2y2 = tf.split(anchors_x1y1x2y2, 2, axis=1)
waha = anchors_x2y2 - anchors_x1y1
xaya = (anchors_x2y2 + anchors_x1y1) * 0.5
boxes_x1y1x2y2 = tf.reshape(boxes, (-1, 2, 2))
boxes_x1y1, boxes_x2y2 = tf.split(boxes_x1y1x2y2, 2, axis=1)
wbhb = boxes_x2y2 - boxes_x1y1
xbyb = (boxes_x2y2 + boxes_x1y1) * 0.5
# Note that here not all boxes are valid. Some may be zero
txty = (xbyb - xaya) / waha
twth = tf.log(wbhb / waha) # may contain -inf for invalid boxes
encoded = tf.concat([txty, twth], axis=1) # (-1x2x2)
return tf.reshape(encoded, tf.shape(boxes))
@under_name_scope()
def crop_and_resize(image, boxes, box_ind, crop_size, pad_border=True):
"""
Aligned version of tf.image.crop_and_resize, following our definition of floating point boxes.
Args:
image: NCHW
boxes: nx4, x1y1x2y2
box_ind: (n,)
crop_size (int):
Returns:
n,C,size,size
"""
assert isinstance(crop_size, int), crop_size
# TF's crop_and_resize produces zeros on border
if pad_border:
# this can be quite slow
image = tf.pad(image, [[0, 0], [0, 0], [1, 1], [1, 1]], mode='SYMMETRIC')
boxes = boxes + 1
@under_name_scope()
def transform_fpcoor_for_tf(boxes, image_shape, crop_shape):
"""
The way tf.image.crop_and_resize works (with normalized box):
Initial point (the value of output[0]): x0_box * (W_img - 1)
Spacing: w_box * (W_img - 1) / (W_crop - 1)
Use the above grid to bilinear sample.
However, what we want is (with fpcoor box):
Spacing: w_box / W_crop
Initial point: x0_box + spacing/2 - 0.5
(-0.5 because bilinear sample assumes floating point coordinate (0.0, 0.0) is the same as pixel value (0, 0))
This function transform fpcoor boxes to a format to be used by tf.image.crop_and_resize
Returns:
y1x1y2x2
"""
x0, y0, x1, y1 = tf.split(boxes, 4, axis=1)
spacing_w = (x1 - x0) / tf.to_float(crop_shape[1])
spacing_h = (y1 - y0) / tf.to_float(crop_shape[0])
nx0 = (x0 + spacing_w / 2 - 0.5) / tf.to_float(image_shape[1] - 1)
ny0 = (y0 + spacing_h / 2 - 0.5) / tf.to_float(image_shape[0] - 1)
nw = spacing_w * tf.to_float(crop_shape[1] - 1) / tf.to_float(image_shape[1] - 1)
nh = spacing_h * tf.to_float(crop_shape[0] - 1) / tf.to_float(image_shape[0] - 1)
return tf.concat([ny0, nx0, ny0 + nh, nx0 + nw], axis=1)
# Expand bbox to a minium size of 1
# boxes_x1y1, boxes_x2y2 = tf.split(boxes, 2, axis=1)
# boxes_wh = boxes_x2y2 - boxes_x1y1
# boxes_center = tf.reshape((boxes_x2y2 + boxes_x1y1) * 0.5, [-1, 2])
# boxes_newwh = tf.maximum(boxes_wh, 1.)
# boxes_x1y1new = boxes_center - boxes_newwh * 0.5
# boxes_x2y2new = boxes_center + boxes_newwh * 0.5
# boxes = tf.concat([boxes_x1y1new, boxes_x2y2new], axis=1)
image_shape = tf.shape(image)[2:]
boxes = transform_fpcoor_for_tf(boxes, image_shape, [crop_size, crop_size])
image = tf.transpose(image, [0, 2, 3, 1]) # nhwc
ret = tf.image.crop_and_resize(
image, boxes, tf.to_int32(box_ind),
crop_size=[crop_size, crop_size])
ret = tf.transpose(ret, [0, 3, 1, 2]) # ncss
return ret
@under_name_scope()
def roi_align(featuremap, boxes, resolution):
"""
Args:
featuremap: 1xCxHxW
boxes: Nx4 floatbox
resolution: output spatial resolution
Returns:
NxCx res x res
"""
boxes = tf.stop_gradient(boxes) # TODO
# sample 4 locations per roi bin
ret = crop_and_resize(
featuremap, boxes,
tf.zeros([tf.shape(boxes)[0]], dtype=tf.int32),
resolution * 2)
ret = tf.nn.avg_pool(ret, [1, 1, 2, 2], [1, 1, 2, 2], padding='SAME', data_format='NCHW')
return ret
if __name__ == '__main__':
"""
Demonstrate what's wrong with tf.image.crop_and_resize:
"""
import tensorflow.contrib.eager as tfe
import numpy as np
tfe.enable_eager_execution()
# want to crop 2x2 out of a 5x5 image, and resize to 4x4
image = np.arange(25).astype('float32').reshape(5, 5)
boxes = np.asarray([[1, 1, 3, 3]], dtype='float32')
target = 4
print(crop_and_resize(
image[None, None, :, :], boxes, [0], target)[0][0])
"""
Expected values:
4.5 5 5.5 6
7 7.5 8 8.5
9.5 10 10.5 11
12 12.5 13 13.5
You cannot easily get the above results with tf.image.crop_and_resize.
Try out yourself here:
"""
print(tf.image.crop_and_resize(
image[None, :, :, None],
np.asarray([[1, 1, 2, 2]]) / 4.0, [0], [target, target])[0][:, :, 0])
......@@ -32,13 +32,15 @@ from coco import COCODetection
from basemodel import (
image_preprocess, resnet_c4_backbone, resnet_conv5,
resnet_fpn_backbone)
import model
from model import (
clip_boxes, decode_bbox_target, encode_bbox_target, crop_and_resize,
rpn_head, rpn_losses,
generate_rpn_proposals, sample_fast_rcnn_targets, roi_align,
generate_rpn_proposals, sample_fast_rcnn_targets,
fastrcnn_outputs, fastrcnn_losses, fastrcnn_predictions,
maskrcnn_upXconv_head, maskrcnn_loss,
fpn_model, fastrcnn_2fc_head, multilevel_roi_align)
fpn_model, multilevel_roi_align)
from model_box import (
clip_boxes, decode_bbox_target, encode_bbox_target, crop_and_resize, roi_align)
from data import (
get_train_dataflow, get_eval_dataflow,
get_all_anchors, get_all_anchors_fpn)
......@@ -51,22 +53,6 @@ from eval import (
import config
def get_model_output_names():
ret = ['final_boxes', 'final_probs', 'final_labels']
if config.MODE_MASK:
ret.append('final_masks')
return ret
def get_model():
if config.MODE_FPN:
if get_tf_version_number() < 1.6:
logger.warn("FPN has chances to crash in TF<1.6, due to a TF issue.")
return ResNetFPNModel()
else:
return ResNetC4Model()
class DetectionModel(ModelDesc):
def preprocess(self, image):
image = tf.expand_dims(image, 0)
......@@ -159,6 +145,19 @@ class DetectionModel(ModelDesc):
final_labels = tf.add(pred_indices[:, 1], 1, name='final_labels')
return final_boxes, final_labels
def get_inference_tensor_names(self):
"""
Returns two lists of tensor names to be used to create an inference callable.
Returns:
[str]: input names
[str]: output names
"""
out = ['final_boxes', 'final_probs', 'final_labels']
if config.MODE_MASK:
out.append('final_masks')
return ['image'], out
class ResNetC4Model(DetectionModel):
def inputs(self):
......@@ -210,25 +209,10 @@ class ResNetC4Model(DetectionModel):
boxes_on_featuremap = rcnn_boxes * (1.0 / config.ANCHOR_STRIDE)
roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)
# HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
# which was fixed in TF 1.6
def ff_true():
feature_fastrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1]) # nxcx7x7
# Keep C5 feature to be shared with mask branch
feature_gap = GlobalAvgPooling('gap', feature_fastrcnn, data_format='channels_first')
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs('fastrcnn', feature_gap, config.NUM_CLASS)
# Return C5 feature to be shared with mask branch
return feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits
def ff_false():
ncls = config.NUM_CLASS
return tf.zeros([0, 2048, 7, 7]), tf.zeros([0, ncls]), tf.zeros([0, ncls - 1, 4])
if get_tf_version_number() >= 1.6:
feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = ff_true()
else:
logger.warn("This example may drop support for TF < 1.6 soon.")
feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = tf.cond(
tf.size(boxes_on_featuremap) > 0, ff_true, ff_false)
if is_training:
# rpn loss
......@@ -281,18 +265,13 @@ class ResNetC4Model(DetectionModel):
image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)
if config.MODE_MASK:
# HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
def f1():
roi_resized = roi_align(featuremap, final_boxes * (1.0 / config.ANCHOR_STRIDE), 14)
feature_maskrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1])
mask_logits = maskrcnn_upXconv_head(
'maskrcnn', feature_maskrcnn, config.NUM_CLASS, 0) # #result x #cat x 14x14
indices = tf.stack([tf.range(tf.size(final_labels)), tf.to_int32(final_labels) - 1], axis=1)
final_mask_logits = tf.gather_nd(mask_logits, indices) # #resultx14x14
return tf.sigmoid(final_mask_logits)
final_masks = tf.cond(tf.size(final_labels) > 0, f1, lambda: tf.zeros([0, 14, 14]))
tf.identity(final_masks, name='final_masks')
tf.sigmoid(final_mask_logits, name='final_masks')
class ResNetFPNModel(DetectionModel):
......@@ -385,7 +364,8 @@ class ResNetFPNModel(DetectionModel):
roi_feature_fastrcnn = multilevel_roi_align(p23456[:4], rcnn_boxes, 7)
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_2fc_head(
fastrcnn_head_func = getattr(model, config.FPN_FASTRCNN_HEAD_FUNC)
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head_func(
'fastrcnn', roi_feature_fastrcnn, config.NUM_CLASS)
if is_training:
......@@ -518,9 +498,11 @@ def predict(pred_func, input_file):
class EvalCallback(Callback):
def __init__(self, in_names, out_names):
self._in_names, self._out_names = in_names, out_names
def _setup_graph(self):
self.pred = self.trainer.get_predictor(
['image'], get_model_output_names())
self.pred = self.trainer.get_predictor(self._in_names, self._out_names)
self.df = get_eval_dataflow()
def _before_train(self):
......@@ -550,6 +532,9 @@ class EvalCallback(Callback):
def init_config():
"""
Initialize config for training.
"""
if config.TRAINER == 'horovod':
ngpu = hvd.size()
else:
......@@ -569,17 +554,23 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--load', help='load a model for evaluation or training')
parser.add_argument('--logdir', help='log directory', default='train_log/maskrcnn')
parser.add_argument('--config', help="A list of KEY=VALUE to overwrite those defined in config.py",
nargs='+')
parser.add_argument('--visualize', action='store_true', help='visualize intermediate results')
parser.add_argument('--evaluate', help="Run evaluation on COCO. "
"This argument is the path to the output json evaluation file")
parser.add_argument('--predict', help="Run prediction on a given image. "
"This argument is the path to the input image file")
parser.add_argument('--config', help="A list of key=value to overwrite those defined in config.py",
nargs='+')
if get_tf_version_number() < 1.6:
# https://github.com/tensorflow/tensorflow/issues/14657
logger.warn("TF<1.6 has a bug which may lead to crash in FasterRCNN training if you're unlucky.")
args = parser.parse_args()
write_config_from_args(args.config)
MODEL = ResNetFPNModel() if config.MODE_FPN else ResNetC4Model()
if args.visualize or args.evaluate or args.predict:
# autotune is too slow for inference
os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0'
......@@ -595,12 +586,12 @@ if __name__ == '__main__':
visualize(args.load)
else:
pred = OfflinePredictor(PredictConfig(
model=get_model(),
model=MODEL,
session_init=get_model_loader(args.load),
input_names=['image'],
output_names=get_model_output_names()))
input_names=MODEL.get_inference_tensor_names()[0],
output_names=MODEL.get_inference_tensor_names()[1]))
if args.evaluate:
assert args.evaluate.endswith('.json')
assert args.evaluate.endswith('.json'), args.evaluate
offline_evaluate(pred, args.evaluate)
elif args.predict:
COCODetection(config.BASEDIR, 'val2014') # Only to load the class names into caches
......@@ -640,7 +631,7 @@ if __name__ == '__main__':
ScheduledHyperParamSetter(
'learning_rate', warmup_schedule, interp='linear', step_based=True),
ScheduledHyperParamSetter('learning_rate', lr_schedule),
EvalCallback(),
EvalCallback(*MODEL.get_inference_tensor_names()),
PeakMemoryTracker(),
EstimatedTimeLeft(),
SessionRunTimeout(60000).set_chief_only(True), # 1 minute timeout
......@@ -649,7 +640,7 @@ if __name__ == '__main__':
callbacks.append(GPUUtilizationTracker())
cfg = TrainConfig(
model=get_model(),
model=MODEL,
data=QueueInput(get_train_dataflow()),
callbacks=callbacks,
steps_per_epoch=stepnum,
......
......@@ -369,7 +369,7 @@ class HorovodTrainer(SingleCostTrainer):
op = hvd.broadcast_global_variables(0)
cb = RunOp(
op, run_before=True,
run_as_trigger=False, verbose=True)
run_as_trigger=True, verbose=True)
return [cb]
@HIDE_DOC
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment