Commit 1f881fcf authored by Yuxin Wu's avatar Yuxin Wu

[FasterRCNN] also support deeper resnet

parent 99d99e7d
......@@ -91,11 +91,11 @@ def pretrained_resnet_conv4(image, num_blocks):
return l
def resnet_conv5(image):
def resnet_conv5_gap(image, num_block):
with argscope([Conv2D, GlobalAvgPooling, BatchNorm], data_format='NCHW'), \
argscope(Conv2D, nl=tf.identity, use_bias=False), \
argscope(BatchNorm, use_local_stat=False):
# 14x14:
l = resnet_group(image, 'group3', resnet_bottleneck, 512, 3, stride=2)
l = resnet_group(image, 'group3', resnet_bottleneck, 512, num_block, stride=2)
l = GlobalAvgPooling('gap', l)
return l
......@@ -10,6 +10,9 @@ TRAIN_DATASET = ['train2014', 'valminusminival2014']
VAL_DATASET = 'minival2014' # only support evaluation on one dataset
NUM_CLASS = 81
# basemodel ----------------------
RESNET_NUM_BLOCK = [3, 4, 6, 3] # resnet50
# preprocessing --------------------
SHORT_EDGE_SIZE = 600
MAX_SIZE = 1024
......
......@@ -25,7 +25,7 @@ from tensorpack.utils.gpu import get_nr_gpu
from coco import COCODetection
from basemodel import (
image_preprocess, pretrained_resnet_conv4, resnet_conv5)
image_preprocess, pretrained_resnet_conv4, resnet_conv5_gap)
from model import (
rpn_head, rpn_losses,
decode_bbox_target, encode_bbox_target,
......@@ -87,8 +87,7 @@ class Model(ModelDesc):
fm_anchors = self._get_anchors(image)
anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)
# resnet50
featuremap = pretrained_resnet_conv4(image, [3, 4, 6])
featuremap = pretrained_resnet_conv4(image, config.RESNET_NUM_BLOCK[:3])
rpn_label_logits, rpn_box_logits = rpn_head(featuremap, 1024, config.NR_ANCHOR)
rpn_label_loss, rpn_box_loss = rpn_losses(
anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits)
......@@ -104,7 +103,7 @@ class Model(ModelDesc):
proposal_boxes, gt_boxes, gt_labels)
boxes_on_featuremap = rcnn_sampled_boxes * (1.0 / config.ANCHOR_STRIDE)
roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)
feature_fastrcnn = resnet_conv5(roi_resized) # nxc
feature_fastrcnn = resnet_conv5_gap(roi_resized, config.RESNET_NUM_BLOCK[-1]) # nxc
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head(feature_fastrcnn, config.NUM_CLASS)
fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
......@@ -123,7 +122,7 @@ class Model(ModelDesc):
add_moving_summary(k)
else:
roi_resized = roi_align(featuremap, proposal_boxes * (1.0 / config.ANCHOR_STRIDE), 14)
feature_fastrcnn = resnet_conv5(roi_resized) # nxc
feature_fastrcnn = resnet_conv5_gap(roi_resized, config.RESNET_NUM_BLOCK[-1]) # nxc
label_logits, fastrcnn_box_logits = fastrcnn_head(feature_fastrcnn, config.NUM_CLASS)
label_probs = tf.nn.softmax(label_logits, name='fastrcnn_all_probs') # NP,
labels = tf.argmax(label_logits, axis=1)
......
......@@ -273,4 +273,5 @@ def TryResumeTraining():
path = os.path.join(logger.get_logger_dir(), 'checkpoint')
if not tf.gfile.Exists(path):
return JustCurrentSession()
logger.info("Found checkpoint at {}.".format(path))
return SaverRestore(path)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment