Commit b3f6e31d authored by Yuxin Wu

[FasterRCNN] move GAP to fastrcnn_head

parent 29c81dd8
......@@ -91,11 +91,10 @@ def pretrained_resnet_conv4(image, num_blocks):
return l
# NOTE(review): this is a diff hunk with the +/- markers and indentation lost.
# The first def/with pair looks like the old header (resnet_conv5_gap) and the
# second pair its replacement (resnet_conv5), matching the renamed import and
# call sites later in this diff -- confirm against the real file.
# The function runs resnet_group 'group3' on `image` with stride 2 under the
# NCHW / no-bias / BatchNorm(use_local_stat=False) argscopes and returns the result.
def resnet_conv5_gap(image, num_block):
with argscope([Conv2D, GlobalAvgPooling, BatchNorm], data_format='NCHW'), \
def resnet_conv5(image, num_block):
with argscope([Conv2D, BatchNorm], data_format='NCHW'), \
argscope(Conv2D, nl=tf.identity, use_bias=False), \
argscope(BatchNorm, use_local_stat=False):
# 14x14:
l = resnet_group(image, 'group3', resnet_bottleneck, 512, num_block, stride=2)
# Presumably the line removed by this commit: the commit title says GAP moves
# to fastrcnn_head, and a GlobalAvgPooling('gap', ...) call appears there.
l = GlobalAvgPooling('gap', l)
return l
......@@ -244,10 +244,10 @@ def get_train_dataflow(add_mask=False):
# one image-sized binary mask per box
masks = []
for box, polys in zip(boxes, segmentation):
for polys in segmentation:
polys = [aug.augment_coords(p, params) for p in polys]
masks.append(segmentation_to_mask(polys, im.shape[0], im.shape[1]))
masks = np.asarray(masks, dtype='uint8')
masks = np.asarray(masks, dtype='uint8') # values in {0, 1}
ret.append(masks)
# from viz import draw_annotation, draw_mask
......
......@@ -8,7 +8,7 @@ from tensorpack.tfutils import get_current_tower_context
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.models import Conv2D, FullyConnected
from tensorpack.models import Conv2D, FullyConnected, GlobalAvgPooling
from utils.box_ops import pairwise_iou
import config
......@@ -375,12 +375,13 @@ def roi_align(featuremap, boxes, output_shape):
def fastrcnn_head(feature, num_classes):
"""
Args:
feature (NxCx1x1):
feature (NxCx7x7):
num_classes(int): num_category + 1
Returns:
cls_logits (Nxnum_class), reg_logits (Nx num_class-1 x 4)
"""
feature = GlobalAvgPooling('gap', feature, data_format='NCHW')
with tf.variable_scope('fastrcnn'):
classification = FullyConnected(
'class', feature, num_classes,
......
......@@ -25,7 +25,7 @@ from tensorpack.utils.gpu import get_nr_gpu
from coco import COCODetection
from basemodel import (
image_preprocess, pretrained_resnet_conv4, resnet_conv5_gap)
image_preprocess, pretrained_resnet_conv4, resnet_conv5)
from model import (
rpn_head, rpn_losses,
decode_bbox_target, encode_bbox_target,
......@@ -103,7 +103,7 @@ class Model(ModelDesc):
proposal_boxes, gt_boxes, gt_labels)
boxes_on_featuremap = rcnn_sampled_boxes * (1.0 / config.ANCHOR_STRIDE)
roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)
feature_fastrcnn = resnet_conv5_gap(roi_resized, config.RESNET_NUM_BLOCK[-1]) # nxc
feature_fastrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1]) # nxc
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head(feature_fastrcnn, config.NUM_CLASS)
fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
......@@ -122,7 +122,7 @@ class Model(ModelDesc):
add_moving_summary(k)
else:
roi_resized = roi_align(featuremap, proposal_boxes * (1.0 / config.ANCHOR_STRIDE), 14)
feature_fastrcnn = resnet_conv5_gap(roi_resized, config.RESNET_NUM_BLOCK[-1]) # nxc
feature_fastrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1]) # nxc
label_logits, fastrcnn_box_logits = fastrcnn_head(feature_fastrcnn, config.NUM_CLASS)
label_probs = tf.nn.softmax(label_logits, name='fastrcnn_all_probs') # NP,
labels = tf.argmax(label_logits, axis=1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment