Commit acfb57c2 authored by Yuxin Wu

[FasterRCNN] make rpn_head and fastrcnn_head layers

parent f3c50d39
...@@ -8,21 +8,22 @@ from tensorpack.tfutils import get_current_tower_context ...@@ -8,21 +8,22 @@ from tensorpack.tfutils import get_current_tower_context
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.argscope import argscope from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.scope_utils import under_name_scope from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.models import Conv2D, FullyConnected, GlobalAvgPooling from tensorpack.models import (
Conv2D, FullyConnected, GlobalAvgPooling, layer_register)
from utils.box_ops import pairwise_iou from utils.box_ops import pairwise_iou
import config import config
@layer_register(log_shape=True)
def rpn_head(featuremap, channel, num_anchors): def rpn_head(featuremap, channel, num_anchors):
""" """
Returns: Returns:
label_logits: fHxfWxNA label_logits: fHxfWxNA
box_logits: fHxfWxNAx4 box_logits: fHxfWxNAx4
""" """
with tf.variable_scope('rpn'), \ with argscope(Conv2D, data_format='NCHW',
argscope(Conv2D, data_format='NCHW', W_init=tf.random_normal_initializer(stddev=0.01)):
W_init=tf.random_normal_initializer(stddev=0.01)):
hidden = Conv2D('conv0', featuremap, channel, 3, nl=tf.nn.relu) hidden = Conv2D('conv0', featuremap, channel, 3, nl=tf.nn.relu)
label_logits = Conv2D('class', hidden, num_anchors, 1) label_logits = Conv2D('class', hidden, num_anchors, 1)
...@@ -371,6 +372,7 @@ def roi_align(featuremap, boxes, output_shape): ...@@ -371,6 +372,7 @@ def roi_align(featuremap, boxes, output_shape):
return ret return ret
@layer_register(log_shape=True)
def fastrcnn_head(feature, num_classes): def fastrcnn_head(feature, num_classes):
""" """
Args: Args:
...@@ -381,15 +383,14 @@ def fastrcnn_head(feature, num_classes): ...@@ -381,15 +383,14 @@ def fastrcnn_head(feature, num_classes):
cls_logits (Nxnum_class), reg_logits (Nx num_class-1 x 4) cls_logits (Nxnum_class), reg_logits (Nx num_class-1 x 4)
""" """
feature = GlobalAvgPooling('gap', feature, data_format='NCHW') feature = GlobalAvgPooling('gap', feature, data_format='NCHW')
with tf.variable_scope('fastrcnn'): classification = FullyConnected(
classification = FullyConnected( 'class', feature, num_classes,
'class', feature, num_classes, W_init=tf.random_normal_initializer(stddev=0.01))
W_init=tf.random_normal_initializer(stddev=0.01)) box_regression = FullyConnected(
box_regression = FullyConnected( 'box', feature, (num_classes - 1) * 4,
'box', feature, (num_classes - 1) * 4, W_init=tf.random_normal_initializer(stddev=0.001))
W_init=tf.random_normal_initializer(stddev=0.001)) box_regression = tf.reshape(box_regression, (-1, num_classes - 1, 4))
box_regression = tf.reshape(box_regression, (-1, num_classes - 1, 4)) return classification, box_regression
return classification, box_regression
@under_name_scope() @under_name_scope()
......
...@@ -88,7 +88,7 @@ class Model(ModelDesc): ...@@ -88,7 +88,7 @@ class Model(ModelDesc):
anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors) anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)
featuremap = pretrained_resnet_conv4(image, config.RESNET_NUM_BLOCK[:3]) featuremap = pretrained_resnet_conv4(image, config.RESNET_NUM_BLOCK[:3])
rpn_label_logits, rpn_box_logits = rpn_head(featuremap, 1024, config.NUM_ANCHOR) rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024, config.NUM_ANCHOR)
rpn_label_loss, rpn_box_loss = rpn_losses( rpn_label_loss, rpn_box_loss = rpn_losses(
anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits) anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits)
...@@ -99,13 +99,19 @@ class Model(ModelDesc): ...@@ -99,13 +99,19 @@ class Model(ModelDesc):
tf.shape(image)[2:]) tf.shape(image)[2:])
if is_training: if is_training:
# sample proposal boxes in training
rcnn_sampled_boxes, rcnn_encoded_boxes, rcnn_labels = sample_fast_rcnn_targets( rcnn_sampled_boxes, rcnn_encoded_boxes, rcnn_labels = sample_fast_rcnn_targets(
proposal_boxes, gt_boxes, gt_labels) proposal_boxes, gt_boxes, gt_labels)
boxes_on_featuremap = rcnn_sampled_boxes * (1.0 / config.ANCHOR_STRIDE) boxes_on_featuremap = rcnn_sampled_boxes * (1.0 / config.ANCHOR_STRIDE)
roi_resized = roi_align(featuremap, boxes_on_featuremap, 14) else:
feature_fastrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1]) # nxc # use all proposal boxes in inference
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head(feature_fastrcnn, config.NUM_CLASS) boxes_on_featuremap = proposal_boxes * (1.0 / config.ANCHOR_STRIDE)
roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)
feature_fastrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1]) # nxc
fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head('fastrcnn', feature_fastrcnn, config.NUM_CLASS)
if is_training:
fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses( fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
rcnn_labels, rcnn_encoded_boxes, fastrcnn_label_logits, fastrcnn_box_logits) rcnn_labels, rcnn_encoded_boxes, fastrcnn_label_logits, fastrcnn_box_logits)
...@@ -121,11 +127,8 @@ class Model(ModelDesc): ...@@ -121,11 +127,8 @@ class Model(ModelDesc):
for k in self.cost, wd_cost: for k in self.cost, wd_cost:
add_moving_summary(k) add_moving_summary(k)
else: else:
roi_resized = roi_align(featuremap, proposal_boxes * (1.0 / config.ANCHOR_STRIDE), 14) label_probs = tf.nn.softmax(fastrcnn_label_logits, name='fastrcnn_all_probs') # NP,
feature_fastrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1]) # nxc labels = tf.argmax(fastrcnn_label_logits, axis=1)
label_logits, fastrcnn_box_logits = fastrcnn_head(feature_fastrcnn, config.NUM_CLASS)
label_probs = tf.nn.softmax(label_logits, name='fastrcnn_all_probs') # NP,
labels = tf.argmax(label_logits, axis=1)
fg_ind, fg_box_logits = fastrcnn_predict_boxes(labels, fastrcnn_box_logits) fg_ind, fg_box_logits = fastrcnn_predict_boxes(labels, fastrcnn_box_logits)
fg_label_probs = tf.gather(label_probs, fg_ind, name='fastrcnn_fg_probs') fg_label_probs = tf.gather(label_probs, fg_ind, name='fastrcnn_fg_probs')
fg_boxes = tf.gather(proposal_boxes, fg_ind) fg_boxes = tf.gather(proposal_boxes, fg_ind)
......
...@@ -39,7 +39,10 @@ class GPUUtilizationTracker(Callback): ...@@ -39,7 +39,10 @@ class GPUUtilizationTracker(Callback):
"Will monitor all visible GPUs!") "Will monitor all visible GPUs!")
self._devices = list(map(str, range(get_nr_gpu()))) self._devices = list(map(str, range(get_nr_gpu())))
else: else:
self._devices = env.split(',') if len(env):
self._devices = env.split(',')
else:
self._devices = []
else: else:
self._devices = list(map(str, devices)) self._devices = list(map(str, devices))
assert len(self._devices), "[GPUUtilizationTracker] No GPU device given!" assert len(self._devices), "[GPUUtilizationTracker] No GPU device given!"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment