Commit c712e8dd authored by Yuxin Wu

[MaskRCNN] support visualization for FPN

parent 37530d96
...@@ -162,6 +162,7 @@ def multilevel_rpn_losses( ...@@ -162,6 +162,7 @@ def multilevel_rpn_losses(
return total_label_loss, total_box_loss return total_label_loss, total_box_loss
@under_name_scope()
def generate_fpn_proposals( def generate_fpn_proposals(
multilevel_anchors, multilevel_label_logits, multilevel_anchors, multilevel_label_logits,
multilevel_box_logits, image_shape2d): multilevel_box_logits, image_shape2d):
...@@ -186,7 +187,7 @@ def generate_fpn_proposals( ...@@ -186,7 +187,7 @@ def generate_fpn_proposals(
if cfg.FPN.PROPOSAL_MODE == 'Level': if cfg.FPN.PROPOSAL_MODE == 'Level':
fpn_nms_topk = cfg.RPN.TRAIN_PER_LEVEL_NMS_TOPK if ctx.is_training else cfg.RPN.TEST_PER_LEVEL_NMS_TOPK fpn_nms_topk = cfg.RPN.TRAIN_PER_LEVEL_NMS_TOPK if ctx.is_training else cfg.RPN.TEST_PER_LEVEL_NMS_TOPK
for lvl in range(num_lvl): for lvl in range(num_lvl):
with tf.name_scope('FPNProposal_Lvl{}'.format(lvl + 2)): with tf.name_scope('Lvl{}'.format(lvl + 2)):
anchors = multilevel_anchors[lvl] anchors = multilevel_anchors[lvl]
pred_boxes_decoded = anchors.decode_logits(multilevel_box_logits[lvl]) pred_boxes_decoded = anchors.decode_logits(multilevel_box_logits[lvl])
...@@ -204,7 +205,7 @@ def generate_fpn_proposals( ...@@ -204,7 +205,7 @@ def generate_fpn_proposals(
proposal_boxes = tf.gather(proposal_boxes, topk_indices) proposal_boxes = tf.gather(proposal_boxes, topk_indices)
else: else:
for lvl in range(num_lvl): for lvl in range(num_lvl):
with tf.name_scope('FPNProposal_Lvl{}'.format(lvl + 2)): with tf.name_scope('Lvl{}'.format(lvl + 2)):
anchors = multilevel_anchors[lvl] anchors = multilevel_anchors[lvl]
pred_boxes_decoded = anchors.decode_logits(multilevel_box_logits[lvl]) pred_boxes_decoded = anchors.decode_logits(multilevel_box_logits[lvl])
all_boxes.append(tf.reshape(pred_boxes_decoded, [-1, 4])) all_boxes.append(tf.reshape(pred_boxes_decoded, [-1, 4]))
...@@ -216,4 +217,6 @@ def generate_fpn_proposals( ...@@ -216,4 +217,6 @@ def generate_fpn_proposals(
cfg.RPN.TRAIN_PRE_NMS_TOPK if ctx.is_training else cfg.RPN.TEST_PRE_NMS_TOPK, cfg.RPN.TRAIN_PRE_NMS_TOPK if ctx.is_training else cfg.RPN.TEST_PRE_NMS_TOPK,
cfg.RPN.TRAIN_POST_NMS_TOPK if ctx.is_training else cfg.RPN.TEST_POST_NMS_TOPK) cfg.RPN.TRAIN_POST_NMS_TOPK if ctx.is_training else cfg.RPN.TEST_POST_NMS_TOPK)
return proposal_boxes, proposal_scores tf.sigmoid(proposal_scores, name='probs') # for visualization
return tf.identity(proposal_boxes, name='boxes'), \
tf.identity(proposal_scores, name='scores')
...@@ -148,7 +148,7 @@ def generate_rpn_proposals(boxes, scores, img_shape, ...@@ -148,7 +148,7 @@ def generate_rpn_proposals(boxes, scores, img_shape,
iou_threshold=cfg.RPN.PROPOSAL_NMS_THRESH) iou_threshold=cfg.RPN.PROPOSAL_NMS_THRESH)
topk_valid_boxes = tf.reshape(topk_valid_boxes_x1y1x2y2, (-1, 4)) topk_valid_boxes = tf.reshape(topk_valid_boxes_x1y1x2y2, (-1, 4))
final_boxes = tf.gather(topk_valid_boxes, nms_indices) proposal_boxes = tf.gather(topk_valid_boxes, nms_indices)
final_scores = tf.gather(topk_valid_scores, nms_indices) proposal_scores = tf.gather(topk_valid_scores, nms_indices)
tf.sigmoid(final_scores, name='probs') # for visualization tf.sigmoid(proposal_scores, name='probs') # for visualization
return tf.stop_gradient(final_boxes, name='boxes'), tf.stop_gradient(final_scores, name='scores') return tf.stop_gradient(proposal_boxes, name='boxes'), tf.stop_gradient(proposal_scores, name='scores')
...@@ -396,7 +396,7 @@ class ResNetFPNModel(DetectionModel): ...@@ -396,7 +396,7 @@ class ResNetFPNModel(DetectionModel):
tf.sigmoid(final_mask_logits, name='final_masks') tf.sigmoid(final_mask_logits, name='final_masks')
def visualize(model_path, nr_visualize=50, output_dir='output'): def visualize(model, model_path, nr_visualize=100, output_dir='output'):
""" """
Visualize some intermediate results (proposals, raw predictions) inside the pipeline. Visualize some intermediate results (proposals, raw predictions) inside the pipeline.
Does not support FPN. Does not support FPN.
...@@ -405,12 +405,12 @@ def visualize(model_path, nr_visualize=50, output_dir='output'): ...@@ -405,12 +405,12 @@ def visualize(model_path, nr_visualize=50, output_dir='output'):
df.reset_state() df.reset_state()
pred = OfflinePredictor(PredictConfig( pred = OfflinePredictor(PredictConfig(
model=ResNetC4Model(), model=model,
session_init=get_model_loader(model_path), session_init=get_model_loader(model_path),
input_names=['image', 'gt_boxes', 'gt_labels'], input_names=['image', 'gt_boxes', 'gt_labels'],
output_names=[ output_names=[
'generate_rpn_proposals/boxes', 'generate_{}_proposals/boxes'.format('fpn' if cfg.MODE_FPN else 'rpn'),
'generate_rpn_proposals/probs', 'generate_{}_proposals/probs'.format('fpn' if cfg.MODE_FPN else 'rpn'),
'fastrcnn_all_probs', 'fastrcnn_all_probs',
'final_boxes', 'final_boxes',
'final_probs', 'final_probs',
...@@ -422,7 +422,11 @@ def visualize(model_path, nr_visualize=50, output_dir='output'): ...@@ -422,7 +422,11 @@ def visualize(model_path, nr_visualize=50, output_dir='output'):
utils.fs.mkdir_p(output_dir) utils.fs.mkdir_p(output_dir)
with tqdm.tqdm(total=nr_visualize) as pbar: with tqdm.tqdm(total=nr_visualize) as pbar:
for idx, dp in itertools.islice(enumerate(df.get_data()), nr_visualize): for idx, dp in itertools.islice(enumerate(df.get_data()), nr_visualize):
img, _, _, gt_boxes, gt_labels = dp img = dp[0]
if cfg.MODE_MASK:
gt_boxes, gt_labels, gt_masks = dp[-3:]
else:
gt_boxes, gt_labels = dp[-2:]
rpn_boxes, rpn_scores, all_probs, \ rpn_boxes, rpn_scores, all_probs, \
final_boxes, final_probs, final_labels = pred(img, gt_boxes, gt_labels) final_boxes, final_probs, final_labels = pred(img, gt_boxes, gt_labels)
...@@ -530,8 +534,7 @@ if __name__ == '__main__': ...@@ -530,8 +534,7 @@ if __name__ == '__main__':
cfg.TEST.RESULT_SCORE_THRESH = cfg.TEST.RESULT_SCORE_THRESH_VIS cfg.TEST.RESULT_SCORE_THRESH = cfg.TEST.RESULT_SCORE_THRESH_VIS
if args.visualize: if args.visualize:
assert not cfg.MODE_FPN, "FPN visualize is not supported!" visualize(MODEL, args.load)
visualize(args.load)
else: else:
pred = OfflinePredictor(PredictConfig( pred = OfflinePredictor(PredictConfig(
model=MODEL, model=MODEL,
......
...@@ -26,6 +26,7 @@ class ModelSaver(Callback): ...@@ -26,6 +26,7 @@ class ModelSaver(Callback):
Args: Args:
max_to_keep (int): the same as in ``tf.train.Saver``. max_to_keep (int): the same as in ``tf.train.Saver``.
keep_checkpoint_every_n_hours (float): the same as in ``tf.train.Saver``. keep_checkpoint_every_n_hours (float): the same as in ``tf.train.Saver``.
Note that "keep" does not mean "create", but means "don't delete".
checkpoint_dir (str): Defaults to ``logger.get_logger_dir()``. checkpoint_dir (str): Defaults to ``logger.get_logger_dir()``.
var_collections (str or list of str): collection of the variables (or list of collections) to save. var_collections (str or list of str): collection of the variables (or list of collections) to save.
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment