Commit 192de99a authored by Yuxin Wu's avatar Yuxin Wu

fix bug in FasterRCNN SparseTensor & parallel map data

parent 8132387f
......@@ -208,21 +208,21 @@ def fastrcnn_predictions(boxes, scores):
# NMS within each class
selection = tf.image.non_max_suppression(
box, prob, cfg.TEST.RESULTS_PER_IM, cfg.TEST.FRCNN_NMS_THRESH)
selection = tf.to_int32(tf.gather(ids, selection))
selection = tf.gather(ids, selection)
# sort available in TF>1.4.0
# sorted_selection = tf.contrib.framework.sort(selection, direction='ASCENDING')
sorted_selection = -tf.nn.top_k(-selection, k=tf.size(selection))[0]
if get_tf_version_tuple() >= (1, 12):
mask = tf.sparse.SparseTensor(indices=sorted_selection,
mask = tf.sparse.SparseTensor(indices=tf.expand_dims(sorted_selection, 1),
values=tf.ones_like(sorted_selection, dtype=tf.bool),
dense_shape=tf.shape(prob))
dense_shape=tf.shape(prob, out_type=tf.int64))
mask = tf.sparse.to_dense(mask, default_value=False)
else:
# deprecated by TF
# this function is deprecated by TF
mask = tf.sparse_to_dense(
sparse_indices=sorted_selection,
output_shape=tf.shape(prob),
output_shape=tf.shape(prob, out_type=tf.int64),
sparse_values=True,
default_value=False)
return mask
......
......@@ -25,8 +25,6 @@ __all__ = ['ThreadedMapData', 'MultiThreadMapData',
class _ParallelMapData(ProxyDataFlow):
def __init__(self, ds, buffer_size, strict=False):
if not strict:
ds = RepeatedData(ds, -1)
super(_ParallelMapData, self).__init__(ds)
assert buffer_size > 0, buffer_size
self._buffer_size = buffer_size
......@@ -35,7 +33,11 @@ class _ParallelMapData(ProxyDataFlow):
def reset_state(self):
super(_ParallelMapData, self).reset_state()
self._iter = self.ds.__iter__()
if not self._strict:
ds = RepeatedData(self.ds, -1)
else:
ds = self.ds
self._iter = ds.__iter__()
def _recv(self):
pass
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment