Commit b6a4f429 authored by Yuxin Wu's avatar Yuxin Wu

a correct size() for JoinData. (#217)

parent 502c461d
......@@ -406,7 +406,8 @@ class JoinData(DataFlow):
"""
Args:
df_lists (list): a list of DataFlow.
All must have the same ``size()``, or don't have size.
When these dataflows have different sizes, JoinData will stop when any
of them is exhausted.
"""
self.df_lists = df_lists
......@@ -423,7 +424,7 @@ class JoinData(DataFlow):
d.reset_state()
def size(self):
return self.df_lists[0].size()
return min([k.size() for k in self.df_lists])
def get_data(self):
itrs = [k.get_data() for k in self.df_lists]
......@@ -433,7 +434,7 @@ class JoinData(DataFlow):
for itr in itrs:
dp.extend(next(itr))
yield dp
except StopIteration:
except StopIteration: # some of them are exhausted
pass
finally:
for itr in itrs:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment