Commit bbf29a18 authored by Yuxin Wu's avatar Yuxin Wu

assert class ids not out of bounds (#1336)

parent 17cb3554
......@@ -48,6 +48,8 @@ def print_class_histogram(roidbs):
# filter crowd?
gt_inds = np.where((entry["class"] > 0) & (entry["is_crowd"] == 0))[0]
gt_classes = entry["class"][gt_inds]
if len(gt_classes):
assert gt_classes.max() <= len(class_names) - 1
gt_hist += np.histogram(gt_classes, bins=hist_bins)[0]
data = list(itertools.chain(*[[class_names[i + 1], v] for i, v in enumerate(gt_hist[1:])]))
COL = min(6, len(data))
......@@ -97,6 +99,7 @@ class TrainingDataPreprocessor:
points = tfms.apply_coords(points)
boxes = point8_to_box(points)
if len(boxes):
assert klass.max() <= cfg.DATA.NUM_CATEGORY, "Invalid category {}!".format(klass.max())
assert np.min(np_area(boxes)) > 0, "Some boxes have zero area!"
ret = {"image": im}
......
......@@ -23,12 +23,12 @@ def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks):
if get_tf_version_tuple() >= (1, 14):
mask_logits = tf.gather(
mask_logits, tf.reshape(fg_labels - 1, [-1, 1]), batch_dims=1)
mask_logits = tf.squeeze(mask_logits, axis=1)
else:
indices = tf.stack([tf.range(tf.size(fg_labels, out_type=tf.int64)),
fg_labels - 1], axis=1) # #fgx2
mask_logits = tf.gather_nd(mask_logits, indices) # #fg x h x w
mask_logits = tf.squeeze(mask_logits, axis=1)
mask_probs = tf.sigmoid(mask_logits)
# add some training visualizations to tensorboard
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment