Commit 0e5c83b5 authored by Yuxin Wu's avatar Yuxin Wu

BatchData supports dict (fix #768)

parent 98eb3db5
...@@ -76,7 +76,8 @@ class BatchData(ProxyDataFlow): ...@@ -76,7 +76,8 @@ class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False, use_list=False): def __init__(self, ds, batch_size, remainder=False, use_list=False):
""" """
Args: Args:
ds (DataFlow): When ``use_list=False``, the components of ``ds`` ds (DataFlow): A dataflow that produces either list or dict.
When ``use_list=False``, the components of ``ds``
must be either scalars or :class:`np.ndarray`, and have to be consistent in shapes. must be either scalars or :class:`np.ndarray`, and have to be consistent in shapes.
batch_size(int): batch size batch_size(int): batch size
remainder (bool): When the remaining datapoints in ``ds`` is not remainder (bool): When the remaining datapoints in ``ds`` is not
...@@ -119,42 +120,54 @@ class BatchData(ProxyDataFlow): ...@@ -119,42 +120,54 @@ class BatchData(ProxyDataFlow):
if self.remainder and len(holder) > 0: if self.remainder and len(holder) > 0:
yield BatchData._aggregate_batch(holder, self.use_list) yield BatchData._aggregate_batch(holder, self.use_list)
@staticmethod
def _batch_numpy(data_list):
    """
    Stack one component collected from several datapoints into a single array.

    The dtype is decided from the first element only: Python bools become
    ``'bool'``, Python integers ``'int32'``, Python floats ``'float32'``,
    byte/text strings ``'str'``; anything else must expose a ``dtype``
    attribute, which is reused as is.

    Args:
        data_list (list): the same component taken from every datapoint
            of the batch. Must be non-empty.

    Returns:
        np.ndarray: the batched component. If the conversion fails, the
        error is logged, an IPython shell is opened when IPython is
        importable, and ``None`` is returned.

    Raises:
        TypeError: if the first element is of an unsupported type.
    """
    data = data_list[0]
    # bool is a subclass of int, so it has to be tested before the integer
    # types; otherwise boolean components would be batched as int32.
    if isinstance(data, bool):
        dtype = 'bool'
    elif isinstance(data, six.integer_types):
        dtype = 'int32'
    elif type(data) == float:
        # Exact type check on purpose: float subclasses that carry their own
        # ``dtype`` fall through to the generic branch below and keep it.
        dtype = 'float32'
    elif isinstance(data, (six.binary_type, six.text_type)):
        dtype = 'str'
    else:
        try:
            dtype = data.dtype
        except AttributeError:
            raise TypeError("Unsupported type to batch: {}".format(type(data)))
    try:
        return np.asarray(data_list, dtype=dtype)
    except Exception:
        logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?")
        if isinstance(data, np.ndarray):
            s = pprint.pformat([x.shape for x in data_list])
            logger.error("Shape of all arrays to be batched: " + s)
        try:
            # open an ipython shell if possible
            import IPython as IP
            IP.embed()
        except ImportError:
            pass
    # Only reached after a failed conversion that was logged above.
    return None
@staticmethod
def _aggregate_batch(data_holder, use_list=False):
    """
    Aggregate a list of datapoints into one batched datapoint.

    The container type of the first datapoint decides the container type
    of the result; every datapoint in ``data_holder`` is indexed with the
    positions / keys of the first one.

    Args:
        data_holder (list): non-empty list of datapoints, each either a
            list/tuple of components or a dict mapping keys to components.
        use_list (bool): if True, every batched component is a plain
            Python list of the individual values; otherwise it is built
            by :meth:`BatchData._batch_numpy`.

    Returns:
        list or dict: a list with one batched entry per component when
        the datapoints are lists/tuples, or a dict with one batched entry
        per key when the datapoints are dicts.

    Raises:
        TypeError: if the first datapoint is neither a list/tuple nor a dict.
    """
    first_dp = data_holder[0]
    if isinstance(first_dp, (list, tuple)):
        result = []
        for k in range(len(first_dp)):
            data_list = [x[k] for x in data_holder]
            if use_list:
                result.append(data_list)
            else:
                result.append(BatchData._batch_numpy(data_list))
    elif isinstance(first_dp, dict):
        # The result has to be a dict as well, and each datapoint must be
        # indexed with the key currently being iterated.
        result = {}
        for key in first_dp:
            data_list = [x[key] for x in data_holder]
            if use_list:
                result[key] = data_list
            else:
                result[key] = BatchData._batch_numpy(data_list)
    else:
        raise TypeError(
            "Unsupported datapoint type to batch: {}".format(type(first_dp)))
    return result
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment