Commit 35599506 authored by Yuxin Wu

[MaskRCNN] fallback to tf.gather_nd for older TF versions

parent 2aa760b1
...@@ -151,8 +151,13 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits): ...@@ -151,8 +151,13 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
num_fg = tf.size(fg_inds, out_type=tf.int64) num_fg = tf.size(fg_inds, out_type=tf.int64)
empty_fg = tf.equal(num_fg, 0) empty_fg = tf.equal(num_fg, 0)
if int(fg_box_logits.shape[1]) > 1: if int(fg_box_logits.shape[1]) > 1:
fg_box_logits = tf.batch_gather(fg_box_logits, tf.expand_dims(fg_labels, axis=1)) if get_tf_version_tuple() >= (1, 14):
fg_box_logits = tf.reshape(fg_box_logits, [-1, 4]) fg_labels = tf.expand_dims(fg_labels, axis=1) # nfg x 1
fg_box_logits = tf.gather(fg_box_logits, fg_labels, batch_dims=1)
else:
indices = tf.stack([tf.range(num_fg), fg_labels], axis=1) # nfgx2
fg_box_logits = tf.gather_nd(fg_box_logits, indices)
fg_box_logits = tf.reshape(fg_box_logits, [-1, 4]) # nfg x 4
with tf.name_scope('label_metrics'), tf.device('/cpu:0'): with tf.name_scope('label_metrics'), tf.device('/cpu:0'):
prediction = tf.argmax(label_logits, axis=1, name='label_prediction') prediction = tf.argmax(label_logits, axis=1, name='label_prediction')
......
...@@ -20,7 +20,14 @@ def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks): ...@@ -20,7 +20,14 @@ def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks):
fg_labels: #fg, in 1~#class, int64 fg_labels: #fg, in 1~#class, int64
fg_target_masks: #fgxhxw, float32 fg_target_masks: #fgxhxw, float32
""" """
mask_logits = tf.batch_gather(mask_logits, tf.reshape(fg_labels, [-1, 1]) - 1) if get_tf_version_tuple() >= (1, 14):
mask_logits = tf.gather(
mask_logits, tf.reshape(fg_labels - 1, [-1, 1]), batch_dims=1)
else:
indices = tf.stack([tf.range(tf.size(fg_labels, out_type=tf.int64)),
fg_labels - 1], axis=1) # #fgx2
mask_logits = tf.gather_nd(mask_logits, indices) # #fg x h x w
mask_logits = tf.squeeze(mask_logits, axis=1) mask_logits = tf.squeeze(mask_logits, axis=1)
mask_probs = tf.sigmoid(mask_logits) mask_probs = tf.sigmoid(mask_logits)
......
...@@ -76,12 +76,12 @@ class SessionUpdate(object): ...@@ -76,12 +76,12 @@ class SessionUpdate(object):
if np.prod(varshape) != np.prod(value.shape): if np.prod(varshape) != np.prod(value.shape):
if ignore_mismatch: if ignore_mismatch:
logger.warn( logger.warn(
"Cannot load a tensor of shape {} into the variable '{}' whose shape is {}.".format( "Cannot load an array of shape {} into variable '{}' whose shape is {}.".format(
value.shape, name, varshape)) value.shape, name, varshape))
return None return None
else: else:
raise ValueError( raise ValueError(
"Trying to load a tensor of shape {} into the variable '{}' whose shape is {}.".format( "Trying to load an array of shape {} into variable '{}' whose shape is {}.".format(
value.shape, name, varshape)) value.shape, name, varshape))
# TODO only allow reshape when shape different by empty axis # TODO only allow reshape when shape different by empty axis
logger.warn("The tensor is reshaped from {} to {} when assigned to '{}'".format( logger.warn("The tensor is reshaped from {} to {} when assigned to '{}'".format(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment