Commit bbf29a18 authored by Yuxin Wu's avatar Yuxin Wu

assert class ids not out of bounds (#1336)

parent 17cb3554
...@@ -48,6 +48,8 @@ def print_class_histogram(roidbs): ...@@ -48,6 +48,8 @@ def print_class_histogram(roidbs):
# filter crowd? # filter crowd?
gt_inds = np.where((entry["class"] > 0) & (entry["is_crowd"] == 0))[0] gt_inds = np.where((entry["class"] > 0) & (entry["is_crowd"] == 0))[0]
gt_classes = entry["class"][gt_inds] gt_classes = entry["class"][gt_inds]
if len(gt_classes):
assert gt_classes.max() <= len(class_names) - 1
gt_hist += np.histogram(gt_classes, bins=hist_bins)[0] gt_hist += np.histogram(gt_classes, bins=hist_bins)[0]
data = list(itertools.chain(*[[class_names[i + 1], v] for i, v in enumerate(gt_hist[1:])])) data = list(itertools.chain(*[[class_names[i + 1], v] for i, v in enumerate(gt_hist[1:])]))
COL = min(6, len(data)) COL = min(6, len(data))
...@@ -97,6 +99,7 @@ class TrainingDataPreprocessor: ...@@ -97,6 +99,7 @@ class TrainingDataPreprocessor:
points = tfms.apply_coords(points) points = tfms.apply_coords(points)
boxes = point8_to_box(points) boxes = point8_to_box(points)
if len(boxes): if len(boxes):
assert klass.max() <= cfg.DATA.NUM_CATEGORY, "Invalid category {}!".format(klass.max())
assert np.min(np_area(boxes)) > 0, "Some boxes have zero area!" assert np.min(np_area(boxes)) > 0, "Some boxes have zero area!"
ret = {"image": im} ret = {"image": im}
......
...@@ -23,12 +23,12 @@ def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks): ...@@ -23,12 +23,12 @@ def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks):
if get_tf_version_tuple() >= (1, 14): if get_tf_version_tuple() >= (1, 14):
mask_logits = tf.gather( mask_logits = tf.gather(
mask_logits, tf.reshape(fg_labels - 1, [-1, 1]), batch_dims=1) mask_logits, tf.reshape(fg_labels - 1, [-1, 1]), batch_dims=1)
mask_logits = tf.squeeze(mask_logits, axis=1)
else: else:
indices = tf.stack([tf.range(tf.size(fg_labels, out_type=tf.int64)), indices = tf.stack([tf.range(tf.size(fg_labels, out_type=tf.int64)),
fg_labels - 1], axis=1) # #fgx2 fg_labels - 1], axis=1) # #fgx2
mask_logits = tf.gather_nd(mask_logits, indices) # #fg x h x w mask_logits = tf.gather_nd(mask_logits, indices) # #fg x h x w
mask_logits = tf.squeeze(mask_logits, axis=1)
mask_probs = tf.sigmoid(mask_logits) mask_probs = tf.sigmoid(mask_logits)
# add some training visualizations to tensorboard # add some training visualizations to tensorboard
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment