Commit 7b77da36 authored by Yuxin Wu's avatar Yuxin Wu

bug fix in aggregate batch

parent 27d73303
...@@ -102,7 +102,7 @@ class BatchData(ProxyDataFlow): ...@@ -102,7 +102,7 @@ class BatchData(ProxyDataFlow):
yield BatchData._aggregate_batch(holder, self.use_list) yield BatchData._aggregate_batch(holder, self.use_list)
@staticmethod @staticmethod
def _aggregate_batch(data_holder, use_list): def _aggregate_batch(data_holder, use_list=False):
size = len(data_holder[0]) size = len(data_holder[0])
result = [] result = []
for k in range(size): for k in range(size):
...@@ -390,20 +390,25 @@ class JoinData(DataFlow): ...@@ -390,20 +390,25 @@ class JoinData(DataFlow):
def __init__(self, df_lists): def __init__(self, df_lists):
""" """
Args: Args:
df_lists (list): a list of DataFlow. All must have the same ``size()``. df_lists (list): a list of DataFlow.
All must have the same ``size()``, or don't have size.
""" """
self.df_lists = df_lists self.df_lists = df_lists
self._size = self.df_lists[0].size()
for d in self.df_lists: try:
assert d.size() == self._size, \ self._size = self.df_lists[0].size()
"All DataFlow must have the same size! {} != {}".format(d.size(), self._size) for d in self.df_lists:
assert d.size() == self._size, \
"All DataFlow must have the same size! {} != {}".format(d.size(), self._size)
except Exception:
logger.info("[JoinData] Size check failed for the list of dataflow to be joined!")
def reset_state(self): def reset_state(self):
for d in self.df_lists: for d in self.df_lists:
d.reset_state() d.reset_state()
def size(self): def size(self):
return self._size return self.df_lists[0].size()
def get_data(self): def get_data(self):
itrs = [k.get_data() for k in self.df_lists] itrs = [k.get_data() for k in self.df_lists]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment