Commit 99d99e7d authored by Yuxin Wu

[FasterRCNN] split some methods from build_graph

parent c878fb1f
......@@ -39,7 +39,7 @@ def resnet_shortcut(l, n_out, stride, nl=tf.identity):
data_format = get_arg_scope()['Conv2D']['data_format']
n_in = l.get_shape().as_list()[1 if data_format == 'NCHW' else 3]
if n_in != n_out: # change dimension when channel is not the same
if stride == 2 and 'group3' not in tf.get_variable_scope().name:
if stride == 2:
l = l[:, :, :-1, :-1]
return Conv2D('convshortcut', l, n_out, 1,
stride=stride, padding='VALID', nl=nl)
......@@ -53,7 +53,7 @@ def resnet_shortcut(l, n_out, stride, nl=tf.identity):
def resnet_bottleneck(l, ch_out, stride):
l, shortcut = l, l
l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
if stride == 2 and 'group3' not in tf.get_variable_scope().name:
if stride == 2:
l = tf.pad(l, [[0, 0], [0, 0], [0, 1], [0, 1]])
l = Conv2D('conv2', l, ch_out, 3, stride=2, nl=BNReLU, padding='VALID')
else:
......
......@@ -60,11 +60,16 @@ class Model(ModelDesc):
InputDesc(tf.int64, (None,), 'gt_labels'),
]
def _build_graph(self, inputs):
is_training = get_current_tower_context().is_training
image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
def _preprocess(self, image):
    """Prepare a single input image tensor for the backbone network.

    The steps, in order:
      1. insert a new leading axis (axis 0) into ``image``;
      2. pass the result through ``image_preprocess`` with ``bgr=True``;
      3. permute the axes with ``[0, 3, 1, 2]``.

    NOTE(review): presumably this turns one HxWxC image into a 1xCxHxW
    (NCHW) batch -- confirm against ``image_preprocess`` and the input
    description of the model.

    Args:
        image: the image tensor taken from the model inputs.

    Returns:
        The transposed 4-D tensor.
    """
    batched = tf.expand_dims(image, 0)
    normalized = image_preprocess(batched, bgr=True)
    channels_first = tf.transpose(normalized, [0, 3, 1, 2])
    return channels_first
def _get_anchors(self, image):
"""
Returns:
FSxFSxNAx4 anchors,
"""
# FSxFSxNAx4 (FS=MAX_SIZE//ANCHOR_STRIDE)
with tf.name_scope('anchors'):
all_anchors = tf.constant(get_all_anchors(), name='all_anchors', dtype=tf.float32)
......@@ -73,11 +78,15 @@ class Model(ModelDesc):
tf.shape(image)[1] // config.ANCHOR_STRIDE,
tf.shape(image)[2] // config.ANCHOR_STRIDE,
-1, -1]), name='fm_anchors')
anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)
return fm_anchors
image = image_preprocess(image, bgr=True)
image = tf.transpose(image, [0, 3, 1, 2])
def _build_graph(self, inputs):
is_training = get_current_tower_context().is_training
image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
image = self._preprocess(image)
fm_anchors = self._get_anchors(image)
anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)
# resnet50
featuremap = pretrained_resnet_conv4(image, [3, 4, 6])
rpn_label_logits, rpn_box_logits = rpn_head(featuremap, 1024, config.NR_ANCHOR)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment