Commit 1f881fcf authored by Yuxin Wu

[FasterRCNN] also support deeper resnet

parent 99d99e7d
...@@ -91,11 +91,11 @@ def pretrained_resnet_conv4(image, num_blocks): ...@@ -91,11 +91,11 @@ def pretrained_resnet_conv4(image, num_blocks):
return l return l
def resnet_conv5(image): def resnet_conv5_gap(image, num_block):
with argscope([Conv2D, GlobalAvgPooling, BatchNorm], data_format='NCHW'), \ with argscope([Conv2D, GlobalAvgPooling, BatchNorm], data_format='NCHW'), \
argscope(Conv2D, nl=tf.identity, use_bias=False), \ argscope(Conv2D, nl=tf.identity, use_bias=False), \
argscope(BatchNorm, use_local_stat=False): argscope(BatchNorm, use_local_stat=False):
# 14x14: # 14x14:
l = resnet_group(image, 'group3', resnet_bottleneck, 512, 3, stride=2) l = resnet_group(image, 'group3', resnet_bottleneck, 512, num_block, stride=2)
l = GlobalAvgPooling('gap', l) l = GlobalAvgPooling('gap', l)
return l return l
...@@ -10,6 +10,9 @@ TRAIN_DATASET = ['train2014', 'valminusminival2014'] ...@@ -10,6 +10,9 @@ TRAIN_DATASET = ['train2014', 'valminusminival2014']
VAL_DATASET = 'minival2014' # only support evaluation on one dataset VAL_DATASET = 'minival2014' # only support evaluation on one dataset
NUM_CLASS = 81 NUM_CLASS = 81
# basemodel ----------------------
RESNET_NUM_BLOCK = [3, 4, 6, 3] # resnet50
# preprocessing -------------------- # preprocessing --------------------
SHORT_EDGE_SIZE = 600 SHORT_EDGE_SIZE = 600
MAX_SIZE = 1024 MAX_SIZE = 1024
......
...@@ -25,7 +25,7 @@ from tensorpack.utils.gpu import get_nr_gpu ...@@ -25,7 +25,7 @@ from tensorpack.utils.gpu import get_nr_gpu
from coco import COCODetection from coco import COCODetection
from basemodel import ( from basemodel import (
image_preprocess, pretrained_resnet_conv4, resnet_conv5) image_preprocess, pretrained_resnet_conv4, resnet_conv5_gap)
from model import ( from model import (
rpn_head, rpn_losses, rpn_head, rpn_losses,
decode_bbox_target, encode_bbox_target, decode_bbox_target, encode_bbox_target,
...@@ -87,8 +87,7 @@ class Model(ModelDesc): ...@@ -87,8 +87,7 @@ class Model(ModelDesc):
fm_anchors = self._get_anchors(image) fm_anchors = self._get_anchors(image)
anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors) anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)
# resnet50 featuremap = pretrained_resnet_conv4(image, config.RESNET_NUM_BLOCK[:3])
featuremap = pretrained_resnet_conv4(image, [3, 4, 6])
rpn_label_logits, rpn_box_logits = rpn_head(featuremap, 1024, config.NR_ANCHOR) rpn_label_logits, rpn_box_logits = rpn_head(featuremap, 1024, config.NR_ANCHOR)
rpn_label_loss, rpn_box_loss = rpn_losses( rpn_label_loss, rpn_box_loss = rpn_losses(
anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits) anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits)
...@@ -104,7 +103,7 @@ class Model(ModelDesc): ...@@ -104,7 +103,7 @@ class Model(ModelDesc):
proposal_boxes, gt_boxes, gt_labels) proposal_boxes, gt_boxes, gt_labels)
boxes_on_featuremap = rcnn_sampled_boxes * (1.0 / config.ANCHOR_STRIDE) boxes_on_featuremap = rcnn_sampled_boxes * (1.0 / config.ANCHOR_STRIDE)
roi_resized = roi_align(featuremap, boxes_on_featuremap, 14) roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)
feature_fastrcnn = resnet_conv5(roi_resized) # nxc feature_fastrcnn = resnet_conv5_gap(roi_resized, config.RESNET_NUM_BLOCK[-1]) # nxc
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head(feature_fastrcnn, config.NUM_CLASS) fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head(feature_fastrcnn, config.NUM_CLASS)
fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses( fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
...@@ -123,7 +122,7 @@ class Model(ModelDesc): ...@@ -123,7 +122,7 @@ class Model(ModelDesc):
add_moving_summary(k) add_moving_summary(k)
else: else:
roi_resized = roi_align(featuremap, proposal_boxes * (1.0 / config.ANCHOR_STRIDE), 14) roi_resized = roi_align(featuremap, proposal_boxes * (1.0 / config.ANCHOR_STRIDE), 14)
feature_fastrcnn = resnet_conv5(roi_resized) # nxc feature_fastrcnn = resnet_conv5_gap(roi_resized, config.RESNET_NUM_BLOCK[-1]) # nxc
label_logits, fastrcnn_box_logits = fastrcnn_head(feature_fastrcnn, config.NUM_CLASS) label_logits, fastrcnn_box_logits = fastrcnn_head(feature_fastrcnn, config.NUM_CLASS)
label_probs = tf.nn.softmax(label_logits, name='fastrcnn_all_probs') # NP, label_probs = tf.nn.softmax(label_logits, name='fastrcnn_all_probs') # NP,
labels = tf.argmax(label_logits, axis=1) labels = tf.argmax(label_logits, axis=1)
......
...@@ -273,4 +273,5 @@ def TryResumeTraining(): ...@@ -273,4 +273,5 @@ def TryResumeTraining():
path = os.path.join(logger.get_logger_dir(), 'checkpoint') path = os.path.join(logger.get_logger_dir(), 'checkpoint')
if not tf.gfile.Exists(path): if not tf.gfile.Exists(path):
return JustCurrentSession() return JustCurrentSession()
logger.info("Found checkpoint at {}.".format(path))
return SaverRestore(path) return SaverRestore(path)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment