Commit 0e5c83b5 authored by Yuxin Wu

BatchData supports dict (fix #768)

parent 98eb3db5
...@@ -76,7 +76,8 @@ class BatchData(ProxyDataFlow): ...@@ -76,7 +76,8 @@ class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False, use_list=False): def __init__(self, ds, batch_size, remainder=False, use_list=False):
""" """
Args: Args:
ds (DataFlow): When ``use_list=False``, the components of ``ds`` ds (DataFlow): A dataflow that produces either list or dict.
When ``use_list=False``, the components of ``ds``
must be either scalars or :class:`np.ndarray`, and have to be consistent in shapes. must be either scalars or :class:`np.ndarray`, and have to be consistent in shapes.
batch_size(int): batch size batch_size(int): batch size
remainder (bool): When the remaining datapoints in ``ds`` is not remainder (bool): When the remaining datapoints in ``ds`` is not
...@@ -120,15 +121,8 @@ class BatchData(ProxyDataFlow): ...@@ -120,15 +121,8 @@ class BatchData(ProxyDataFlow):
yield BatchData._aggregate_batch(holder, self.use_list) yield BatchData._aggregate_batch(holder, self.use_list)
@staticmethod @staticmethod
def _aggregate_batch(data_holder, use_list=False): def _batch_numpy(data_list):
size = len(data_holder[0]) data = data_list[0]
result = []
for k in range(size):
if use_list:
result.append(
[x[k] for x in data_holder])
else:
data = data_holder[0][k]
if isinstance(data, six.integer_types): if isinstance(data, six.integer_types):
dtype = 'int32' dtype = 'int32'
elif type(data) == bool: elif type(data) == bool:
...@@ -143,18 +137,37 @@ class BatchData(ProxyDataFlow): ...@@ -143,18 +137,37 @@ class BatchData(ProxyDataFlow):
except AttributeError: except AttributeError:
raise TypeError("Unsupported type to batch: {}".format(type(data))) raise TypeError("Unsupported type to batch: {}".format(type(data)))
try: try:
result.append( return np.asarray(data_list, dtype=dtype)
np.asarray([x[k] for x in data_holder], dtype=dtype))
except Exception as e: # noqa except Exception as e: # noqa
logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?") logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?")
if isinstance(data, np.ndarray): if isinstance(data, np.ndarray):
s = pprint.pformat([x[k].shape for x in data_holder]) s = pprint.pformat([x.shape for x in data_list])
logger.error("Shape of all arrays to be batched: " + s) logger.error("Shape of all arrays to be batched: " + s)
try: try:
# open an ipython shell if possible # open an ipython shell if possible
import IPython as IP; IP.embed() # noqa import IPython as IP; IP.embed() # noqa
except ImportError: except ImportError:
pass pass
@staticmethod
def _aggregate_batch(data_holder, use_list=False):
    """
    Combine a list of datapoints into a single batched datapoint.

    Args:
        data_holder (list): the datapoints to batch. Each datapoint is a
            list/tuple or a dict; the structure of the first datapoint
            decides how the rest are indexed.
        use_list (bool): if True, every batched component is kept as a
            plain Python list. Otherwise each component is converted by
            ``BatchData._batch_numpy``.

    Returns:
        list or dict: a list with one batched entry per component when the
        datapoints are lists/tuples, or a dict with one batched entry per
        key when the datapoints are dicts.

    Raises:
        ValueError: if the first datapoint is not a list, tuple or dict.
    """
    first_dp = data_holder[0]

    def batch_component(data_list):
        # Either keep the raw per-datapoint values or stack them.
        if use_list:
            return data_list
        return BatchData._batch_numpy(data_list)

    if isinstance(first_dp, (list, tuple)):
        return [
            batch_component([dp[k] for dp in data_holder])
            for k in range(len(first_dp))
        ]

    if isinstance(first_dp, dict):
        # The result must be a dict and each datapoint must be indexed by
        # its key (not by a positional index) for dict datapoints.
        return {
            key: batch_component([dp[key] for dp in data_holder])
            for key in first_dp
        }

    # Fail loudly instead of hitting an unbound ``result`` variable.
    raise ValueError(
        "Data point has to be list/tuple/dict. Got {}".format(type(first_dp)))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment