Commit b6a4f429 authored by Yuxin Wu's avatar Yuxin Wu

a correct size() for JoinData. (#217)

parent 502c461d
...@@ -406,7 +406,8 @@ class JoinData(DataFlow): ...@@ -406,7 +406,8 @@ class JoinData(DataFlow):
""" """
Args: Args:
df_lists (list): a list of DataFlow. df_lists (list): a list of DataFlow.
All must have the same ``size()``, or don't have size. When these dataflows have different sizes, JoinData will stop when any
of them is exhausted.
""" """
self.df_lists = df_lists self.df_lists = df_lists
...@@ -423,7 +424,7 @@ class JoinData(DataFlow): ...@@ -423,7 +424,7 @@ class JoinData(DataFlow):
d.reset_state() d.reset_state()
def size(self): def size(self):
return self.df_lists[0].size() return min([k.size() for k in self.df_lists])
def get_data(self): def get_data(self):
itrs = [k.get_data() for k in self.df_lists] itrs = [k.get_data() for k in self.df_lists]
...@@ -433,7 +434,7 @@ class JoinData(DataFlow): ...@@ -433,7 +434,7 @@ class JoinData(DataFlow):
for itr in itrs: for itr in itrs:
dp.extend(next(itr)) dp.extend(next(itr))
yield dp yield dp
except StopIteration: except StopIteration: # some of them are exhausted
pass pass
finally: finally:
for itr in itrs: for itr in itrs:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment