Commit da5e9e66 authored by Yuxin Wu's avatar Yuxin Wu

add allow_list to BatchData. Default to False for now to test

parent 7c694aca
...@@ -47,9 +47,14 @@ class TestDataSpeed(ProxyDataFlow): ...@@ -47,9 +47,14 @@ class TestDataSpeed(ProxyDataFlow):
class BatchData(ProxyDataFlow): class BatchData(ProxyDataFlow):
""" """
Group data into batches. Concat datapoints into batches.
It produces datapoints of the same number of components as ``ds``, but
each component has one new extra dimension of size ``batch_size``.
The new component can be a list of the original datapoints, or an ndarray
of the original datapoints.
""" """
def __init__(self, ds, batch_size, remainder=False):
def __init__(self, ds, batch_size, remainder=False, allow_list=False):
""" """
Args: Args:
ds (DataFlow): Its components must be either scalars or :class:`np.ndarray`. ds (DataFlow): Its components must be either scalars or :class:`np.ndarray`.
...@@ -58,6 +63,9 @@ class BatchData(ProxyDataFlow): ...@@ -58,6 +63,9 @@ class BatchData(ProxyDataFlow):
remainder (bool): whether to return the remaining data smaller than a batch_size. remainder (bool): whether to return the remaining data smaller than a batch_size.
If set True, it may possibly generate a data point of a smaller batch size. If set True, it may possibly generate a data point of a smaller batch size.
Otherwise, all generated data are guaranteed to have the same size. Otherwise, all generated data are guaranteed to have the same size.
allow_list (bool): if True, it will run faster by producing a list
of datapoints instead of an ndarray of datapoints, avoiding an
extra copy.
""" """
super(BatchData, self).__init__(ds) super(BatchData, self).__init__(ds)
if not remainder: if not remainder:
...@@ -67,6 +75,7 @@ class BatchData(ProxyDataFlow): ...@@ -67,6 +75,7 @@ class BatchData(ProxyDataFlow):
pass pass
self.batch_size = batch_size self.batch_size = batch_size
self.remainder = remainder self.remainder = remainder
self.allow_list = allow_list
def size(self): def size(self):
ds_size = self.ds.size() ds_size = self.ds.size()
...@@ -85,32 +94,36 @@ class BatchData(ProxyDataFlow): ...@@ -85,32 +94,36 @@ class BatchData(ProxyDataFlow):
for data in self.ds.get_data(): for data in self.ds.get_data():
holder.append(data) holder.append(data)
if len(holder) == self.batch_size: if len(holder) == self.batch_size:
yield BatchData._aggregate_batch(holder) yield BatchData._aggregate_batch(holder, self.allow_list)
del holder[:] del holder[:]
if self.remainder and len(holder) > 0: if self.remainder and len(holder) > 0:
yield BatchData._aggregate_batch(holder) yield BatchData._aggregate_batch(holder, self.allow_list)
@staticmethod @staticmethod
def _aggregate_batch(data_holder): def _aggregate_batch(data_holder, allow_list):
size = len(data_holder[0]) size = len(data_holder[0])
result = [] result = []
for k in range(size): for k in range(size):
dt = data_holder[0][k] if allow_list:
if type(dt) in [int, bool]:
tp = 'int32'
elif type(dt) == float:
tp = 'float32'
else:
tp = dt.dtype
try:
result.append( result.append(
np.array([x[k] for x in data_holder], dtype=tp)) [x[k] for x in data_holder])
except KeyboardInterrupt: else:
raise dt = data_holder[0][k]
except: if type(dt) in [int, bool]:
logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?") tp = 'int32'
import IPython as IP elif type(dt) == float:
IP.embed(config=IP.terminal.ipapp.load_default_config()) tp = 'float32'
else:
tp = dt.dtype
try:
result.append(
np.array([x[k] for x in data_holder], dtype=tp))
except KeyboardInterrupt:
raise
except:
logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?")
import IPython as IP
IP.embed(config=IP.terminal.ipapp.load_default_config())
return result return result
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment