Commit 192de99a authored by Yuxin Wu's avatar Yuxin Wu

fix bug in FasterRCNN SparseTensor & parallel map data

parent 8132387f
...@@ -208,21 +208,21 @@ def fastrcnn_predictions(boxes, scores): ...@@ -208,21 +208,21 @@ def fastrcnn_predictions(boxes, scores):
# NMS within each class # NMS within each class
selection = tf.image.non_max_suppression( selection = tf.image.non_max_suppression(
box, prob, cfg.TEST.RESULTS_PER_IM, cfg.TEST.FRCNN_NMS_THRESH) box, prob, cfg.TEST.RESULTS_PER_IM, cfg.TEST.FRCNN_NMS_THRESH)
selection = tf.to_int32(tf.gather(ids, selection)) selection = tf.gather(ids, selection)
# sort available in TF>1.4.0 # sort available in TF>1.4.0
# sorted_selection = tf.contrib.framework.sort(selection, direction='ASCENDING') # sorted_selection = tf.contrib.framework.sort(selection, direction='ASCENDING')
sorted_selection = -tf.nn.top_k(-selection, k=tf.size(selection))[0] sorted_selection = -tf.nn.top_k(-selection, k=tf.size(selection))[0]
if get_tf_version_tuple() >= (1, 12): if get_tf_version_tuple() >= (1, 12):
mask = tf.sparse.SparseTensor(indices=sorted_selection, mask = tf.sparse.SparseTensor(indices=tf.expand_dims(sorted_selection, 1),
values=tf.ones_like(sorted_selection, dtype=tf.bool), values=tf.ones_like(sorted_selection, dtype=tf.bool),
dense_shape=tf.shape(prob)) dense_shape=tf.shape(prob, out_type=tf.int64))
mask = tf.sparse.to_dense(mask, default_value=False) mask = tf.sparse.to_dense(mask, default_value=False)
else: else:
# deprecated by TF # this function is deprecated by TF
mask = tf.sparse_to_dense( mask = tf.sparse_to_dense(
sparse_indices=sorted_selection, sparse_indices=sorted_selection,
output_shape=tf.shape(prob), output_shape=tf.shape(prob, out_type=tf.int64),
sparse_values=True, sparse_values=True,
default_value=False) default_value=False)
return mask return mask
......
...@@ -25,8 +25,6 @@ __all__ = ['ThreadedMapData', 'MultiThreadMapData', ...@@ -25,8 +25,6 @@ __all__ = ['ThreadedMapData', 'MultiThreadMapData',
class _ParallelMapData(ProxyDataFlow): class _ParallelMapData(ProxyDataFlow):
def __init__(self, ds, buffer_size, strict=False): def __init__(self, ds, buffer_size, strict=False):
if not strict:
ds = RepeatedData(ds, -1)
super(_ParallelMapData, self).__init__(ds) super(_ParallelMapData, self).__init__(ds)
assert buffer_size > 0, buffer_size assert buffer_size > 0, buffer_size
self._buffer_size = buffer_size self._buffer_size = buffer_size
...@@ -35,7 +33,11 @@ class _ParallelMapData(ProxyDataFlow): ...@@ -35,7 +33,11 @@ class _ParallelMapData(ProxyDataFlow):
def reset_state(self): def reset_state(self):
super(_ParallelMapData, self).reset_state() super(_ParallelMapData, self).reset_state()
self._iter = self.ds.__iter__() if not self._strict:
ds = RepeatedData(self.ds, -1)
else:
ds = self.ds
self._iter = ds.__iter__()
def _recv(self): def _recv(self):
pass pass
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment