Commit 98eb3db5 authored by Yuxin Wu's avatar Yuxin Wu

More data types support in BatchData (fix #983)

parent 135d17e2
...@@ -214,7 +214,7 @@ def fastrcnn_predictions(boxes, scores): ...@@ -214,7 +214,7 @@ def fastrcnn_predictions(boxes, scores):
# sorted_selection = tf.contrib.framework.sort(selection, direction='ASCENDING') # sorted_selection = tf.contrib.framework.sort(selection, direction='ASCENDING')
sorted_selection = -tf.nn.top_k(-selection, k=tf.size(selection))[0] sorted_selection = -tf.nn.top_k(-selection, k=tf.size(selection))[0]
if get_tf_version_tuple() >= (1, 12): if get_tf_version_tuple() >= (1, 13):
mask = tf.sparse.SparseTensor(indices=tf.expand_dims(sorted_selection, 1), mask = tf.sparse.SparseTensor(indices=tf.expand_dims(sorted_selection, 1),
values=tf.ones_like(sorted_selection, dtype=tf.bool), values=tf.ones_like(sorted_selection, dtype=tf.bool),
dense_shape=output_shape) dense_shape=output_shape)
......
...@@ -522,6 +522,7 @@ if __name__ == '__main__': ...@@ -522,6 +522,7 @@ if __name__ == '__main__':
MODEL = ResNetFPNModel() if cfg.MODE_FPN else ResNetC4Model() MODEL = ResNetFPNModel() if cfg.MODE_FPN else ResNetC4Model()
if args.visualize or args.evaluate or args.predict: if args.visualize or args.evaluate or args.predict:
assert tf.test.is_gpu_available()
assert args.load assert args.load
finalize_configs(is_training=False) finalize_configs(is_training=False)
......
...@@ -128,22 +128,26 @@ class BatchData(ProxyDataFlow): ...@@ -128,22 +128,26 @@ class BatchData(ProxyDataFlow):
result.append( result.append(
[x[k] for x in data_holder]) [x[k] for x in data_holder])
else: else:
dt = data_holder[0][k] data = data_holder[0][k]
if type(dt) in list(six.integer_types) + [bool]: if isinstance(data, six.integer_types):
tp = 'int32' dtype = 'int32'
elif type(dt) == float: elif type(data) == bool:
tp = 'float32' dtype = 'bool'
elif type(data) == float:
dtype = 'float32'
elif isinstance(data, (six.binary_type, six.text_type)):
dtype = 'str'
else: else:
try: try:
tp = dt.dtype dtype = data.dtype
except AttributeError: except AttributeError:
raise TypeError("Unsupported type to batch: {}".format(type(dt))) raise TypeError("Unsupported type to batch: {}".format(type(data)))
try: try:
result.append( result.append(
np.asarray([x[k] for x in data_holder], dtype=tp)) np.asarray([x[k] for x in data_holder], dtype=dtype))
except Exception as e: # noqa except Exception as e: # noqa
logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?") logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?")
if isinstance(dt, np.ndarray): if isinstance(data, np.ndarray):
s = pprint.pformat([x[k].shape for x in data_holder]) s = pprint.pformat([x[k].shape for x in data_holder])
logger.error("Shape of all arrays to be batched: " + s) logger.error("Shape of all arrays to be batched: " + s)
try: try:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment