Commit 6fc4378c authored by Yuxin Wu's avatar Yuxin Wu

[FasterRCNN] fix bug in gradient propagation

parent 7556cc1e
...@@ -292,7 +292,7 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels): ...@@ -292,7 +292,7 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
ret_labels = tf.concat( ret_labels = tf.concat(
[tf.gather(gt_labels, fg_inds_wrt_gt), [tf.gather(gt_labels, fg_inds_wrt_gt),
tf.zeros_like(bg_inds, dtype=tf.int64)], axis=0, name='sampled_labels') tf.zeros_like(bg_inds, dtype=tf.int64)], axis=0, name='sampled_labels')
return ret_boxes, tf.stop_gradient(ret_labels), fg_inds_wrt_gt return tf.stop_gradient(ret_boxes), tf.stop_gradient(ret_labels), fg_inds_wrt_gt
@under_name_scope() @under_name_scope()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment