Commit 994a150b authored by Yuxin Wu

[FasterRCNN] no need to lower bound fg_ratio

parent fa8af3d8
...@@ -40,7 +40,7 @@ FASTRCNN_BATCH_PER_IM = 64 ...@@ -40,7 +40,7 @@ FASTRCNN_BATCH_PER_IM = 64
FASTRCNN_BBOX_REG_WEIGHTS = np.array([10, 10, 5, 5], dtype='float32') FASTRCNN_BBOX_REG_WEIGHTS = np.array([10, 10, 5, 5], dtype='float32')
FASTRCNN_FG_THRESH = 0.5 FASTRCNN_FG_THRESH = 0.5
# keep fg ratio in a batch in this range # keep fg ratio in a batch in this range
FASTRCNN_FG_RATIO = (0.1, 0.25) FASTRCNN_FG_RATIO = 0.25
# testing ----------------------- # testing -----------------------
TEST_PRE_NMS_TOPK = 6000 TEST_PRE_NMS_TOPK = 6000
......
...@@ -261,16 +261,13 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels): ...@@ -261,16 +261,13 @@ def sample_fast_rcnn_targets(boxes, gt_boxes, gt_labels):
tf.int64)], 0) tf.int64)], 0)
num_fg = tf.size(fg_inds) num_fg = tf.size(fg_inds)
num_fg = tf.minimum(int( num_fg = tf.minimum(int(
config.FASTRCNN_BATCH_PER_IM * config.FASTRCNN_FG_RATIO[1]), config.FASTRCNN_BATCH_PER_IM * config.FASTRCNN_FG_RATIO),
num_fg, name='num_fg') num_fg, name='num_fg')
fg_inds = tf.slice(tf.random_shuffle(fg_inds), [0], [num_fg]) fg_inds = tf.slice(tf.random_shuffle(fg_inds), [0], [num_fg])
bg_inds = tf.where(tf.logical_not(fg_mask))[:, 0] bg_inds = tf.where(tf.logical_not(fg_mask))[:, 0]
num_bg = tf.size(bg_inds) num_bg = tf.size(bg_inds)
num_bg = tf.minimum(config.FASTRCNN_BATCH_PER_IM - num_fg, num_bg) num_bg = tf.minimum(config.FASTRCNN_BATCH_PER_IM - num_fg, num_bg, name='num_bg')
num_bg = tf.minimum(
num_bg,
num_fg * int(1.0 / config.FASTRCNN_FG_RATIO[0]), name='num_bg') # don't include too many bg
bg_inds = tf.slice(tf.random_shuffle(bg_inds), [0], [num_bg]) bg_inds = tf.slice(tf.random_shuffle(bg_inds), [0], [num_bg])
add_moving_summary(num_fg, num_bg) add_moving_summary(num_fg, num_bg)
......
...@@ -252,6 +252,7 @@ if __name__ == '__main__': ...@@ -252,6 +252,7 @@ if __name__ == '__main__':
if args.evaluate is not None: if args.evaluate is not None:
assert args.evaluate.endswith('.json') assert args.evaluate.endswith('.json')
assert args.load assert args.load
# autotune is too slow for inference
os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0' os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0'
offline_evaluate(args.load, args.evaluate) offline_evaluate(args.load, args.evaluate)
sys.exit() sys.exit()
...@@ -283,7 +284,7 @@ if __name__ == '__main__': ...@@ -283,7 +284,7 @@ if __name__ == '__main__':
], ],
steps_per_epoch=stepnum, steps_per_epoch=stepnum,
max_epoch=205000 // stepnum, max_epoch=205000 // stepnum,
session_init=get_model_loader(args.load), session_init=get_model_loader(args.load) if args.load else None,
nr_tower=nr_gpu nr_tower=nr_gpu
) )
SyncMultiGPUTrainerReplicated(cfg, gpu_prefetch=False).train() SyncMultiGPUTrainerReplicated(cfg, gpu_prefetch=False).train()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment