Commit 35599506 authored by Yuxin Wu

[MaskRCNN] fallback to tf.gather_nd for older TF versions

parent 2aa760b1
@@ -151,8 +151,13 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
     num_fg = tf.size(fg_inds, out_type=tf.int64)
     empty_fg = tf.equal(num_fg, 0)
     if int(fg_box_logits.shape[1]) > 1:
-        fg_box_logits = tf.batch_gather(fg_box_logits, tf.expand_dims(fg_labels, axis=1))
-    fg_box_logits = tf.reshape(fg_box_logits, [-1, 4])
+        if get_tf_version_tuple() >= (1, 14):
+            fg_labels = tf.expand_dims(fg_labels, axis=1)  # nfg x 1
+            fg_box_logits = tf.gather(fg_box_logits, fg_labels, batch_dims=1)
+        else:
+            indices = tf.stack([tf.range(num_fg), fg_labels], axis=1)  # nfg x 2
+            fg_box_logits = tf.gather_nd(fg_box_logits, indices)
+    fg_box_logits = tf.reshape(fg_box_logits, [-1, 4])  # nfg x 4
     with tf.name_scope('label_metrics'), tf.device('/cpu:0'):
         prediction = tf.argmax(label_logits, axis=1, name='label_prediction')
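Note on the change above: `tf.gather` only accepts the `batch_dims` argument from TF 1.14 on, so older TF versions fall back to `tf.gather_nd` over explicit (row, class) index pairs. Below is a minimal standalone sketch of the two equivalent selection paths; it is not part of this commit, and the toy shapes plus TF 2.x eager execution are assumptions for illustration.

import numpy as np
import tensorflow as tf

# Toy inputs: 5 foreground boxes, 3 classes, 4 box-regression values per class.
fg_box_logits = tf.constant(np.random.rand(5, 3, 4), dtype=tf.float32)  # nfg x C x 4
fg_labels = tf.constant([2, 0, 1, 2, 1], dtype=tf.int64)                # nfg

# Path used on TF >= 1.14: gather one class per row via batch_dims=1.
a = tf.gather(fg_box_logits, tf.expand_dims(fg_labels, axis=1), batch_dims=1)
a = tf.reshape(a, [-1, 4])                                              # nfg x 4

# Fallback path: build explicit (row, class) indices for tf.gather_nd.
num_fg = tf.size(fg_labels, out_type=tf.int64)
indices = tf.stack([tf.range(num_fg), fg_labels], axis=1)               # nfg x 2
b = tf.gather_nd(fg_box_logits, indices)                                # nfg x 4

assert np.allclose(a.numpy(), b.numpy())  # both select the same per-class box logits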
@@ -20,7 +20,14 @@ def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks):
         fg_labels: #fg, in 1~#class, int64
         fg_target_masks: #fgxhxw, float32
     """
-    mask_logits = tf.batch_gather(mask_logits, tf.reshape(fg_labels, [-1, 1]) - 1)
-    mask_logits = tf.squeeze(mask_logits, axis=1)
+    if get_tf_version_tuple() >= (1, 14):
+        mask_logits = tf.gather(
+            mask_logits, tf.reshape(fg_labels - 1, [-1, 1]), batch_dims=1)
+        mask_logits = tf.squeeze(mask_logits, axis=1)  # #fg x h x w
+    else:
+        indices = tf.stack([tf.range(tf.size(fg_labels, out_type=tf.int64)),
+                            fg_labels - 1], axis=1)  # #fg x 2
+        mask_logits = tf.gather_nd(mask_logits, indices)  # #fg x h x w
     mask_probs = tf.sigmoid(mask_logits)
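The same selection is applied here to pick each instance's mask channel; the `batch_dims` path leaves a singleton class axis that must be squeezed away, while the `gather_nd` fallback already yields `#fg x h x w`. A standalone sketch follows (toy shapes and TF 2.x eager execution assumed; not part of the commit).

import numpy as np
import tensorflow as tf

# Toy inputs: 4 foreground instances, 3 classes, 7x7 mask logits per class.
mask_logits = tf.constant(np.random.rand(4, 3, 7, 7), dtype=tf.float32)  # #fg x C x h x w
fg_labels = tf.constant([1, 3, 2, 1], dtype=tf.int64)                    # 1-based class ids

# Path used on TF >= 1.14: batch_dims gathering leaves a singleton class axis.
a = tf.gather(mask_logits, tf.reshape(fg_labels - 1, [-1, 1]), batch_dims=1)  # #fg x 1 x h x w
a = tf.squeeze(a, axis=1)                                                     # #fg x h x w

# Fallback path: tf.gather_nd returns #fg x h x w directly, no squeeze needed.
indices = tf.stack([tf.range(tf.size(fg_labels, out_type=tf.int64)),
                    fg_labels - 1], axis=1)                                   # #fg x 2
b = tf.gather_nd(mask_logits, indices)                                        # #fg x h x w

assert np.allclose(a.numpy(), b.numpy())  # both pick the same per-class mask logits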
@@ -76,12 +76,12 @@ class SessionUpdate(object):
             if np.prod(varshape) != np.prod(value.shape):
                 if ignore_mismatch:
                     logger.warn(
-                        "Cannot load a tensor of shape {} into the variable '{}' whose shape is {}.".format(
+                        "Cannot load an array of shape {} into variable '{}' whose shape is {}.".format(
                             value.shape, name, varshape))
                     return None
                 else:
                     raise ValueError(
-                        "Trying to load a tensor of shape {} into the variable '{}' whose shape is {}.".format(
+                        "Trying to load an array of shape {} into variable '{}' whose shape is {}.".format(
                             value.shape, name, varshape))
             # TODO only allow reshape when shape different by empty axis
             logger.warn("The tensor is reshaped from {} to {} when assigned to '{}'".format(
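The check above distinguishes arrays whose element count can never fit the variable (warn or raise, depending on `ignore_mismatch`) from arrays that merely need a reshape. The following is a hypothetical standalone sketch of that decision; the helper name and `print` calls are illustrative stand-ins, not tensorpack's API.

import numpy as np

def coerce_to_variable_shape(value, varshape, name, ignore_mismatch=False):
    """Illustrative only: mimic the shape handling shown in the diff above."""
    if np.prod(varshape) != np.prod(value.shape):
        msg = "Cannot load an array of shape {} into variable '{}' whose shape is {}.".format(
            value.shape, name, varshape)
        if ignore_mismatch:
            print("WARN: " + msg)   # stand-in for logger.warn
            return None             # skip this variable
        raise ValueError(msg)
    if tuple(value.shape) != tuple(varshape):
        # Same number of elements, different layout: reshape and warn.
        print("WARN: reshaping {} to {} for '{}'".format(value.shape, varshape, name))
        value = value.reshape(varshape)
    return value

# e.g. a (4, 1) array fits a (4,) variable via reshape, while a (3,) array raises ValueError.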