Commit 0e5c83b5 authored by Yuxin Wu

BatchData supports dict (fix #768)

parent 98eb3db5
......@@ -76,7 +76,8 @@ class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False, use_list=False):
"""
Args:
ds (DataFlow): When ``use_list=False``, the components of ``ds``
ds (DataFlow): A dataflow that produces either list or dict.
When ``use_list=False``, the components of ``ds``
must be either scalars or :class:`np.ndarray`, and have to be consistent in shapes.
batch_size(int): batch size
remainder (bool): When the remaining datapoints in ``ds`` is not
......@@ -119,42 +120,54 @@ class BatchData(ProxyDataFlow):
if self.remainder and len(holder) > 0:
yield BatchData._aggregate_batch(holder, self.use_list)
@staticmethod
def _batch_numpy(data_list):
    """
    Stack one component taken from every datapoint into a single numpy array.

    The dtype is chosen by looking at the first element only:
    Python bool -> 'bool', integers -> 'int32', Python float -> 'float32',
    byte/text strings -> 'str'. Any other object must expose a ``.dtype``
    attribute, which is reused as is.

    Args:
        data_list (list): non-empty list holding the same component of
            every datapoint in the batch.

    Returns:
        np.ndarray: the batched component, or None when ``np.asarray``
        failed (the failure is logged instead of being raised).

    Raises:
        TypeError: if the first element has a type that cannot be batched.
    """
    data = data_list[0]
    # bool has to be tested before the integer types: bool is a subclass of
    # int, so the integer check would otherwise claim it and produce int32.
    if isinstance(data, bool):
        dtype = 'bool'
    elif isinstance(data, six.integer_types):
        dtype = 'int32'
    elif type(data) is float:
        # Exact type match on purpose: np.float64 subclasses float but must
        # keep its own dtype through the ``.dtype`` fallback below.
        dtype = 'float32'
    elif isinstance(data, (six.binary_type, six.text_type)):
        dtype = 'str'
    else:
        try:
            dtype = data.dtype
        except AttributeError:
            raise TypeError("Unsupported type to batch: {}".format(type(data)))
    try:
        return np.asarray(data_list, dtype=dtype)
    except Exception:
        logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?")
        if isinstance(data, np.ndarray):
            s = pprint.pformat([x.shape for x in data_list])
            logger.error("Shape of all arrays to be batched: " + s)
        try:
            # open an ipython shell if possible
            import IPython as IP  # noqa
            IP.embed()
        except ImportError:
            pass
    # NOTE: reached only after a batching failure that was logged above;
    # the caller receives None for this component.
    return None
@staticmethod
def _aggregate_batch(data_holder, use_list=False):
    """
    Merge a list of datapoints into one batched datapoint.

    The structure (sequence length or dict keys) is taken from the first
    datapoint; every other datapoint is indexed with the same positions/keys.

    Args:
        data_holder (list): non-empty list of datapoints. Each datapoint is
            either a list/tuple of components or a dict of components.
        use_list (bool): if True, every batched component is a plain Python
            list; otherwise it is produced by ``BatchData._batch_numpy``.

    Returns:
        list or dict: a list with one batched entry per component when the
        datapoints are lists/tuples, or a dict with the same keys as the
        first datapoint when the datapoints are dicts.

    Raises:
        TypeError: if the datapoints are neither list/tuple nor dict.
    """
    first_dp = data_holder[0]
    if isinstance(first_dp, (list, tuple)):
        result = []
        for k in range(len(first_dp)):
            data_list = [x[k] for x in data_holder]
            if use_list:
                result.append(data_list)
            else:
                result.append(BatchData._batch_numpy(data_list))
    elif isinstance(first_dp, dict):
        # The batched datapoint is keyed like its inputs, so it must be a
        # dict and each component must be gathered by key, not by position.
        result = {}
        for key in first_dp:
            data_list = [x[key] for x in data_holder]
            if use_list:
                result[key] = data_list
            else:
                result[key] = BatchData._batch_numpy(data_list)
    else:
        raise TypeError(
            "Datapoint must be a list, tuple or dict to be batched, got {}".format(type(first_dp)))
    return result
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment