Commit 7b8728f9 authored by Patrick Wieschollek, committed by Yuxin Wu

make dataflow idiomatic python container objects (fix #869) (#872)

* make dataflow idiomatic python container objects (fix #869)

* use ABCMeta and fix tqdm

* python2 compatible version

* fix

* some more replacement

* rollback changes of _reset_df_and_get_size

* fix unrelated pbar issue

* revert the __len__ changes on InputSource
parent 3476cb43
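
In short, the commit replaces the old `get_data()` / `size()` protocol with the standard Python container protocol `__iter__()` / `__len__()`. A minimal caller-side sketch using the built-in MNIST dataflow (illustrative, not part of the diff):

```python
from tensorpack.dataflow import dataset

ds = dataset.Mnist('train')   # any built-in DataFlow
ds.reset_state()

n = len(ds)               # previously: ds.size()
for img, label in ds:     # previously: for img, label in ds.get_data()
    pass
```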
......@@ -56,7 +56,7 @@ To actually access the datapoints generated by the dataflow, you can use
```python
ds.reset_state()
for dp in ds.get_data():
for dp in ds:
print dp[0] # this is an RGB image!
```
This kind of iteration is used behind the scenes to feed data for training.
......
......@@ -5,11 +5,10 @@
DataFlow is a library to build Python iterators for efficient data loading.
**Definition**: A DataFlow is something that has a `get_data()` generator method,
which yields `datapoints`.
**Definition**: A DataFlow is an idiomatic Python container object that has a `__iter__()` generator method, which yields `datapoints`, and a `__len__()` method returning the size of the flow.
A datapoint is a **list** of Python objects which are called the `components` of a datapoint.
**Example**: to train on MNIST dataset, you may need a DataFlow with a `get_data()` method
**Example**: to train on MNIST dataset, you may need a DataFlow with a `__iter__()` method
that yields datapoints (lists) of two components:
a numpy array of shape (64, 28, 28), and an array of shape (64,).
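
A minimal DataFlow written against this definition might look like the sketch below (the class name and sizes are illustrative, chosen to match the MNIST example above):

```python
import numpy as np
from tensorpack.dataflow import DataFlow

class FakeMnistBatches(DataFlow):
    """Yields datapoints of two components: a batch of images and a batch of labels."""
    def __len__(self):
        return 1000   # number of datapoints (batches) per epoch

    def __iter__(self):
        for _ in range(len(self)):
            images = np.random.rand(64, 28, 28).astype('float32')
            labels = np.random.randint(10, size=(64,)).astype('int32')
            yield [images, labels]
```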
......@@ -67,7 +66,6 @@ and then use the generator however you like:
df = SomeDataFlow()
df.reset_state()
generator = df.get_data()
for dp in generator:
for dp in df:
# dp is now a list. do whatever
```
......@@ -143,8 +143,8 @@ We can also dump the dataset into one single LMDB file and read it sequentially.
```python
from tensorpack.dataflow import *
class BinaryILSVRC12(dataset.ILSVRC12Files):
def get_data(self):
for fname, label in super(BinaryILSVRC12, self).get_data():
def __iter__(self):
for fname, label in super(BinaryILSVRC12, self).__iter__():
with open(fname, 'rb') as f:
jpeg = f.read()
jpeg = np.asarray(bytearray(jpeg), dtype='uint8')
......
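
Writing the LMDB file and reading it back would then go through the serializer touched later in this diff; a hedged sketch of the two calls (paths are placeholders):

```python
from tensorpack.dataflow import LMDBSerializer

ds = BinaryILSVRC12('/path/to/ILSVRC12', 'train')   # class defined above
LMDBSerializer.save(ds, '/path/to/ILSVRC12-train.lmdb')

# later: read it back sequentially, without shuffling
ds = LMDBSerializer.load('/path/to/ILSVRC12-train.lmdb', shuffle=False)
ds.reset_state()
for jpeg_bytes, label in ds:
    pass
```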
......@@ -23,7 +23,7 @@ To write more complicated DataFlow, you need to inherit the base `DataFlow` clas
Usually, you just need to implement the `get_data()` method which yields a datapoint every time.
```python
class MyDataFlow(DataFlow):
def get_data(self):
def __iter__(self):
for k in range(100):
digit = np.random.rand(28, 28)
label = np.random.randint(10)
......@@ -56,7 +56,7 @@ class ProcessingDataFlow(DataFlow):
def __init__(self, ds):
self.ds = ds
def get_data(self):
def __iter__(self):
for datapoint in self.ds.get_data():
# do something
yield new_datapoint
......
......@@ -89,10 +89,10 @@ class RawTIMIT(DataFlow):
assert label in ['phoneme', 'letter'], label
self.label = label
def size(self):
def __len__(self):
return len(self.filelists)
def get_data(self):
def __iter__(self):
for f in self.filelists:
feat = get_feature(f)
if self.label == 'phoneme':
......@@ -106,12 +106,10 @@ def compute_mean_std(db, fname):
ds = LMDBSerializer.load(db, shuffle=False)
ds.reset_state()
o = OnlineMoments()
with get_tqdm(total=ds.size()) as bar:
for dp in ds.get_data():
for dp in get_tqdm(ds):
feat = dp[0] # len x dim
for f in feat:
o.feed(f)
bar.update()
logger.info("Writing to {} ...".format(fname))
with open(fname, 'wb') as f:
f.write(serialize.dumps([o.mean, o.std]))
......
......@@ -39,12 +39,12 @@ class TIMITBatch(ProxyDataFlow):
self.batch = batch
self.ds = ds
def size(self):
return self.ds.size() // self.batch
def __len__(self):
return len(self.ds) // self.batch
def get_data(self):
itr = self.ds.get_data()
for _ in range(self.size()):
def __iter__(self):
itr = self.ds.__iter__()
for _ in range(self.__len__()):
feats = []
labs = []
for b in range(self.batch):
......
......@@ -58,10 +58,10 @@ class CharRNNData(RNGDataFlow):
self.whole_seq = np.array([self.char2idx[c] for c in data], dtype='int32')
logger.info("Corpus loaded. Vocab size: {}".format(self.vocab_size))
def size(self):
def __len__(self):
return self._size
def get_data(self):
def __iter__(self):
random_starts = self.rng.randint(
0, self.whole_seq.shape[0] - self.seq_length - 1, (self._size,))
for st in random_starts:
......
......@@ -241,7 +241,7 @@ class ExpReplay(DataFlow, Callback):
return [state, action, reward, isOver]
# DataFlow method:
def get_data(self):
def __iter__(self):
# wait for memory to be initialized
self._init_memory_flag.wait()
......
......@@ -14,8 +14,8 @@ class DisturbLabel(ProxyDataFlow, RNGDataFlow):
RNGDataFlow.reset_state(self)
ProxyDataFlow.reset_state(self)
def get_data(self):
for dp in self.ds.get_data():
def __iter__(self):
for dp in self.ds:
img, l = dp
if self.rng.rand() < self.prob:
l = self.rng.choice(10)
......
......@@ -214,8 +214,8 @@ class ThetaImages(ProxyDataFlow, RNGDataFlow):
ProxyDataFlow.reset_state(self)
RNGDataFlow.reset_state(self)
def get_data(self):
for image, label in self.ds.get_data():
def __iter__(self):
for image, label in self.ds:
theta = self.rng.uniform(0, 2 * np.pi)
filtered_image, gt_filter = ThetaImages.filter_with_theta(image, theta)
yield [theta, image, filtered_image, gt_filter]
......@@ -245,7 +245,7 @@ def get_config():
OnlineTensorboardExport()
],
model=Model(),
steps_per_epoch=dataset_train.size(),
steps_per_epoch=len(dataset_train),
max_epoch=50,
)
......
......@@ -15,10 +15,10 @@ class DataFromListOfDict(RNGDataFlow):
self._shuffle = shuffle
self._size = len(lst)
def size(self):
def __len__(self):
return self._size
def get_data(self):
def __iter__(self):
if self._shuffle:
self.rng.shuffle(self._lst)
for dic in self._lst:
......
......@@ -412,5 +412,5 @@ if __name__ == '__main__':
ds = PrintData(ds, 100)
TestDataSpeed(ds, 50000).start()
ds.reset_state()
for k in ds.get_data():
for k in ds:
pass
......@@ -109,7 +109,7 @@ def eval_coco(df, detect_func, tqdm_bar=None):
if tqdm_bar is None:
tqdm_bar = stack.enter_context(
tqdm.tqdm(total=df.size(), **get_tqdm_kwargs()))
for img, img_id in df.get_data():
for img, img_id in df:
results = detect_func(img)
for r in results:
box = r.box
......
......@@ -348,7 +348,7 @@ def visualize(model, model_path, nr_visualize=100, output_dir='output'):
shutil.rmtree(output_dir)
utils.fs.mkdir_p(output_dir)
with tqdm.tqdm(total=nr_visualize) as pbar:
for idx, dp in itertools.islice(enumerate(df.get_data()), nr_visualize):
for idx, dp in itertools.islice(enumerate(df), nr_visualize):
img = dp[0]
if cfg.MODE_MASK:
gt_boxes, gt_labels, gt_masks = dp[-3:]
......
......@@ -193,7 +193,7 @@ class VisualizeTestSet(Callback):
def _trigger(self):
idx = 0
for iA, iB in self.val_ds.get_data():
for iA, iB in self.val_ds:
vizA, vizB = self.pred(iA, iB)
self.trainer.monitors.put_image('testA-{}'.format(idx), vizA)
self.trainer.monitors.put_image('testB-{}'.format(idx), vizB)
......@@ -223,6 +223,6 @@ if __name__ == '__main__':
PeriodicTrigger(VisualizeTestSet(), every_k_epochs=3),
],
max_epoch=195,
steps_per_epoch=data.size(),
steps_per_epoch=len(data),
session_init=SaverRestore(args.load) if args.load else None
)
......@@ -191,6 +191,6 @@ class RandomZData(DataFlow):
super(RandomZData, self).__init__()
self.shape = shape
def get_data(self):
def __iter__(self):
while True:
yield [np.random.uniform(-1, 1, size=self.shape)]
......@@ -223,7 +223,7 @@ if __name__ == '__main__':
PeriodicTrigger(ModelSaver(), every_k_epochs=3),
ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
],
steps_per_epoch=data.size(),
steps_per_epoch=len(data),
max_epoch=300,
session_init=SaverRestore(args.load) if args.load else None
)
......@@ -237,7 +237,7 @@ def get_data(name):
def view_data():
ds = RepeatedData(get_data('train'), -1)
ds.reset_state()
for ims, edgemaps in ds.get_data():
for ims, edgemaps in ds:
for im, edgemap in zip(ims, edgemaps):
assert im.shape[0] % 16 == 0 and im.shape[1] % 16 == 0, im.shape
cv2.imshow("im", im / 255.0)
......@@ -249,7 +249,7 @@ def view_data():
def get_config():
logger.auto_set_dir()
dataset_train = get_data('train')
steps_per_epoch = dataset_train.size() * 40
steps_per_epoch = len(dataset_train) * 40
dataset_val = get_data('val')
return TrainConfig(
......
......@@ -58,7 +58,7 @@ class SintelData(DataFlow):
def size(self):
return len(self.flows)
def get_data(self):
def __iter__(self):
for flow_path in self.flows:
input_path = flow_path.replace(
self.path_prefix, os.path.join(self.data_path, 'clean'))
......@@ -95,7 +95,7 @@ def inference(model, model_path, sintel_path):
ds.reset_state()
# look at shape information (all images in Sintel have the same shape)
h, w = next(ds.get_data())[0].shape[2:]
h, w = next(ds.__iter__())[0].shape[2:]
pred = PredictConfig(
model=model(height=h, width=w),
......
......@@ -138,8 +138,6 @@ if __name__ == '__main__':
dataset_train = get_data('train', args.mixup, args.alpha)
dataset_test = get_data('test', args.mixup, args.alpha)
steps_per_epoch = dataset_train.size()
config = TrainConfig(
model=ResNet_Cifar(),
data=QueueInput(dataset_train),
......@@ -150,7 +148,7 @@ if __name__ == '__main__':
ScheduledHyperParamSetter('learning_rate', LR_SCHEDULE)
],
max_epoch=200,
steps_per_epoch=steps_per_epoch,
steps_per_epoch=len(dataset_train),
session_init=SaverRestore(args.load) if args.load else None
)
launch_train_with_config(config, SimpleTrainer())
......@@ -41,7 +41,7 @@ class MnistPairs(dataset.Mnist):
idx = self.rng.randint(len(self.data_dict[label]))
return self.data_dict[label][idx].astype(np.float32)
def get_data(self):
def __iter__(self):
while True:
y = self.rng.randint(2)
if y == 0:
......@@ -54,7 +54,7 @@ class MnistPairs(dataset.Mnist):
class MnistTriplets(MnistPairs):
def get_data(self):
def __iter__(self):
while True:
pick_label, pick_other = self.rng.choice(10, size=2, replace=False)
yield [self.pick(pick_label), self.pick(pick_label), self.pick(pick_other)]
......@@ -379,7 +379,7 @@ def visualize(model_path, model, algo_name):
ds = get_test_data()
ds.reset_state()
for offset, dp in enumerate(ds.get_data()):
for offset, dp in enumerate(ds):
digit, label = dp
prediction = pred(digit)[0]
embed[offset * BATCH_SIZE:offset * BATCH_SIZE + BATCH_SIZE, ...] = prediction
......
......@@ -222,7 +222,7 @@ def view_warp(modelpath):
ds = get_data(False)
ds.reset_state()
for k in ds.get_data():
for k in ds:
img, label = k
outputs, affine1, affine2 = pred(img)
for idx, viz in enumerate(outputs):
......@@ -238,7 +238,7 @@ def get_config():
logger.auto_set_dir()
dataset_train, dataset_test = get_data(True), get_data(False)
steps_per_epoch = dataset_train.size() * 5
steps_per_epoch = len(dataset_train) * 5
return TrainConfig(
model=Model(),
......
......@@ -34,7 +34,7 @@ class ImageDataFromZIPFile(RNGDataFlow):
def size(self):
return len(self.archivefiles)
def get_data(self):
def __iter__(self):
if self.shuffle:
self.rng.shuffle(self.archivefiles)
for archive in self.archivefiles:
......@@ -111,6 +111,6 @@ if __name__ == '__main__':
LMDBSerializer.save(ds, args.lmdb)
if args.debug:
ds.reset_state()
for i in ds.get_data():
for i in ds:
cv2.imshow('example', i[0])
cv2.waitKey(0)
......@@ -298,6 +298,6 @@ if __name__ == '__main__':
ModelSaver(keep_checkpoint_every_n_hours=2)
],
session_init=session_init,
steps_per_epoch=data.size() // 4,
steps_per_epoch=len(data) // 4,
max_epoch=300
)
......@@ -104,8 +104,8 @@ if __name__ == '__main__':
dataset_train, dataset_test = get_data()
# How many iterations you want in each epoch.
# This (data.size()) is the default value.
steps_per_epoch = dataset_train.size()
# This len(data) is the default value.
steps_per_epoch = len(dataset_train)
# get the config which contains everything necessary in a training
config = TrainConfig(
......
......@@ -108,8 +108,8 @@ if __name__ == '__main__':
dataset_train, dataset_test = get_data()
# How many iterations you want in each epoch.
# This (data.size()) is the default value.
steps_per_epoch = dataset_train.size()
# This len(data) is the default value.
steps_per_epoch = len(dataset_train)
# get the config which contains everything necessary in a training
config = TrainConfig(
......
......@@ -132,7 +132,7 @@ if __name__ == '__main__':
InferenceRunner(
dataset_test, ScalarStats(['cross_entropy_loss', 'accuracy'])),
],
steps_per_epoch=dataset_train.size(),
steps_per_epoch=len(dataset_train),
max_epoch=100,
)
......
......@@ -55,7 +55,7 @@ def get_config():
ModelSaver(),
InferenceRunner(ds_test, [ScalarStats('total_costs')]),
],
steps_per_epoch=ds_train.size(),
steps_per_epoch=len(ds_train),
max_epoch=100,
)
......
......@@ -69,7 +69,7 @@ if __name__ == '__main__':
)
M.fit(
validation_data=dataset_test,
steps_per_epoch=dataset_train.size(),
steps_per_epoch=len(dataset_train),
callbacks=[
ModelSaver()
]
......
......@@ -51,7 +51,7 @@ def _inference_context():
yield
except (StopIteration, tf.errors.CancelledError):
logger.error(
"[InferenceRunner] input stopped before reaching its size()! " + msg)
"[InferenceRunner] input stopped before reaching its __len__()! " + msg)
raise
except tf.errors.OutOfRangeError: # tf.data reaches an end
pass
......@@ -63,7 +63,7 @@ class InferenceRunnerBase(Callback):
Note:
1. InferenceRunner will use `input.size()` to determine
how many iterations to run, so you're responsible for ensuring that
`size()` is reasonable.
`input.size()` is reasonable.
2. Only works with instances of `TowerTrainer`.
"""
def __init__(self, input, infs):
......
......@@ -39,12 +39,30 @@ class DataFlowReentrantGuard(object):
return False
@six.add_metaclass(ABCMeta)
# NOTE: we cannot use six here
class DataFlowMeta(ABCMeta):
"""
DataFlow uses "__iter__()" and "__len__()" instead of
"get_data()" and "size()". This add back-compatibility.
"""
def __new__(mcls, name, bases, namespace, **kwargs):
def hot_patch(required, existing):
if required not in namespace and existing in namespace:
namespace[required] = namespace[existing]
hot_patch('__iter__', 'get_data')
hot_patch('__len__', 'size')
return ABCMeta.__new__(mcls, name, bases, namespace, **kwargs)
@six.add_metaclass(DataFlowMeta)
class DataFlow(object):
""" Base class for all DataFlow """
@abstractmethod
def get_data(self):
def __iter__(self):
"""
The method to generate datapoints.
......@@ -52,7 +70,10 @@ class DataFlow(object):
list: The datapoint, i.e. list of components.
"""
def size(self):
def get_data(self):
return self.__iter__()
def __len__(self):
"""
Returns:
int: size of this data flow.
......@@ -62,6 +83,9 @@ class DataFlow(object):
"""
raise NotImplementedError()
def size(self):
return self.__len__()
def reset_state(self):
"""
Reset state of the dataflow.
......@@ -102,8 +126,8 @@ class ProxyDataFlow(DataFlow):
def reset_state(self):
self.ds.reset_state()
def size(self):
return self.ds.size()
def __len__(self):
return self.ds.__len__()
def get_data(self):
return self.ds.get_data()
def __iter__(self):
return self.ds.__iter__()
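
The `DataFlowMeta` metaclass and the `get_data()` / `size()` shims above keep old-style subclasses working under the new protocol; a small sketch of the behaviour they preserve (the class below is illustrative):

```python
from tensorpack.dataflow import DataFlow

class LegacyFlow(DataFlow):
    # old-style subclass: only get_data() and size() are defined
    def get_data(self):
        for k in range(3):
            yield [k]

    def size(self):
        return 3

df = LegacyFlow()
df.reset_state()
assert len(df) == 3                          # __len__ aliased from size() by the metaclass
assert [dp[0] for dp in df] == [0, 1, 2]     # __iter__ aliased from get_data()
```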
......@@ -36,10 +36,10 @@ class TestDataSpeed(ProxyDataFlow):
self.test_size = int(size)
self.warmup = int(warmup)
def get_data(self):
def __iter__(self):
""" Will run testing at the beginning, then produce data normally. """
self.start_test()
for dp in self.ds.get_data():
for dp in self.ds:
yield dp
def start_test(self):
......@@ -47,7 +47,7 @@ class TestDataSpeed(ProxyDataFlow):
Start testing with a progress bar.
"""
self.ds.reset_state()
itr = self.ds.get_data()
itr = self.ds.__iter__()
if self.warmup:
for d in tqdm.trange(self.warmup, **get_tqdm_kwargs()):
next(itr)
......@@ -85,35 +85,35 @@ class BatchData(ProxyDataFlow):
enough to form a batch, whether or not to also produce the remaining
data as a smaller batch.
If set to False, all produced datapoints are guaranteed to have the same batch size.
If set to True, `ds.size()` must be accurate.
If set to True, `len(ds)` must be accurate.
use_list (bool): if True, each component will contain a list
of datapoints instead of a numpy array with an extra dimension.
"""
super(BatchData, self).__init__(ds)
if not remainder:
try:
assert batch_size <= ds.size()
assert batch_size <= len(ds)
except NotImplementedError:
pass
self.batch_size = int(batch_size)
self.remainder = remainder
self.use_list = use_list
def size(self):
ds_size = self.ds.size()
def __len__(self):
ds_size = len(self.ds)
div = ds_size // self.batch_size
rem = ds_size % self.batch_size
if rem == 0:
return div
return div + int(self.remainder)
def get_data(self):
def __iter__(self):
"""
Yields:
Batched data by stacking each component on an extra 0th dimension.
"""
holder = []
for data in self.ds.get_data():
for data in self.ds:
holder.append(data)
if len(holder) == self.batch_size:
yield BatchData._aggregate_batch(holder, self.use_list)
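
A short usage sketch of the batching and `remainder` behaviour described above (the source dataflow is illustrative):

```python
from tensorpack.dataflow import DataFromList, BatchData

ds = DataFromList([[i] for i in range(10)], shuffle=False)
ds = BatchData(ds, batch_size=4, remainder=True)
ds.reset_state()

print(len(ds))            # 3: two full batches of 4, plus a remainder batch of 2
for dp in ds:
    print(dp[0].shape)    # (4,), (4,), (2,)
```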
......@@ -184,9 +184,9 @@ class BatchDataByShape(BatchData):
super(BatchDataByShape, self).reset_state()
self.holder = defaultdict(list)
def get_data(self):
def __iter__(self):
with self._guard:
for dp in self.ds.get_data():
for dp in self.ds:
shp = dp[self.idx].shape
holder = self.holder[shp]
holder.append(dp)
......@@ -204,7 +204,7 @@ class FixedSizeData(ProxyDataFlow):
ds (DataFlow): input dataflow
size (int): size
keep_state (bool): keep the iterator state of ``ds``
between calls to :meth:`get_data()`, so that the
between calls to :meth:`__iter__()`, so that the
next call will continue the previous iteration over ``ds``,
instead of reinitializing an iterator.
......@@ -223,23 +223,23 @@ class FixedSizeData(ProxyDataFlow):
self._guard = DataFlowReentrantGuard()
self._keep = keep_state
def size(self):
def __len__(self):
return self._size
def reset_state(self):
super(FixedSizeData, self).reset_state()
self.itr = self.ds.get_data()
self.itr = self.ds.__iter__()
def get_data(self):
def __iter__(self):
with self._guard:
if self.itr is None:
self.itr = self.ds.get_data()
self.itr = self.ds.__iter__()
cnt = 0
while True:
try:
dp = next(self.itr)
except StopIteration:
self.itr = self.ds.get_data()
self.itr = self.ds.__iter__()
dp = next(self.itr)
cnt += 1
......@@ -257,7 +257,7 @@ class MapData(ProxyDataFlow):
Note:
1. Please make sure func doesn't modify the components
unless you're certain it's safe.
2. If you discard some datapoints, ``ds.size()`` will be incorrect.
2. If you discard some datapoints, ``len(ds)`` will be incorrect.
"""
def __init__(self, ds, func):
......@@ -270,8 +270,8 @@ class MapData(ProxyDataFlow):
super(MapData, self).__init__(ds)
self.func = func
def get_data(self):
for dp in self.ds.get_data():
def __iter__(self):
for dp in self.ds:
ret = self.func(copy(dp)) # shallow copy the list
if ret is not None:
yield ret
......@@ -285,7 +285,7 @@ class MapDataComponent(MapData):
1. This dataflow itself doesn't modify the datapoints.
But please make sure func doesn't modify the components
unless you're certain it's safe.
2. If you discard some datapoints, ``ds.size()`` will be incorrect.
2. If you discard some datapoints, ``len(ds)`` will be incorrect.
"""
def __init__(self, ds, func, index=0):
"""
......@@ -324,23 +324,23 @@ class RepeatedData(ProxyDataFlow):
self.nr = nr
super(RepeatedData, self).__init__(ds)
def size(self):
def __len__(self):
"""
Raises:
:class:`NotImplementedError` when nr == -1.
"""
if self.nr == -1:
raise NotImplementedError("size() is unavailable for infinite dataflow")
return self.ds.size() * self.nr
raise NotImplementedError("__len__() is unavailable for infinite dataflow")
return len(self.ds) * self.nr
def get_data(self):
def __iter__(self):
if self.nr == -1:
while True:
for dp in self.ds.get_data():
for dp in self.ds:
yield dp
else:
for _ in range(self.nr):
for dp in self.ds.get_data():
for dp in self.ds:
yield dp
......@@ -360,11 +360,11 @@ class RepeatedDataPoint(ProxyDataFlow):
assert self.nr >= 1, self.nr
super(RepeatedDataPoint, self).__init__(ds)
def size(self):
return self.ds.size() * self.nr
def __len__(self):
return len(self.ds) * self.nr
def get_data(self):
for dp in self.ds.get_data():
def __iter__(self):
for dp in self.ds:
for _ in range(self.nr):
yield dp
......@@ -397,8 +397,8 @@ class RandomChooseData(RNGDataFlow):
else:
d.reset_state()
def get_data(self):
itrs = [v[0].get_data() for v in self.df_lists]
def __iter__(self):
itrs = [v[0].__iter__() for v in self.df_lists]
probs = np.array([v[1] for v in self.df_lists])
try:
while True:
......@@ -410,35 +410,35 @@ class RandomChooseData(RNGDataFlow):
class RandomMixData(RNGDataFlow):
"""
Perfectly mix datapoints from several DataFlow using their :meth:`size()`.
Will stop when all DataFlow exhausted.
Perfectly mix datapoints from several DataFlow using their
:meth:`__len__()`. Will stop when all DataFlow exhausted.
"""
def __init__(self, df_lists):
"""
Args:
df_lists (list): a list of DataFlow.
All DataFlow must implement ``size()``.
All DataFlow must implement ``__len__()``.
"""
super(RandomMixData, self).__init__()
self.df_lists = df_lists
self.sizes = [k.size() for k in self.df_lists]
self.sizes = [len(k) for k in self.df_lists]
def reset_state(self):
super(RandomMixData, self).reset_state()
for d in self.df_lists:
d.reset_state()
def size(self):
def __len__(self):
return sum(self.sizes)
def get_data(self):
def __iter__(self):
sums = np.cumsum(self.sizes)
idxs = np.arange(self.size())
idxs = np.arange(self.__len__())
self.rng.shuffle(idxs)
idxs = np.array(list(map(
lambda x: np.searchsorted(sums, x, 'right'), idxs)))
itrs = [k.get_data() for k in self.df_lists]
itrs = [k.__iter__() for k in self.df_lists]
assert idxs.max() == len(itrs) - 1, "{}!={}".format(idxs.max(), len(itrs) - 1)
for k in idxs:
yield next(itrs[k])
......@@ -463,12 +463,12 @@ class ConcatData(DataFlow):
for d in self.df_lists:
d.reset_state()
def size(self):
return sum([x.size() for x in self.df_lists])
def __len__(self):
return sum([len(x) for x in self.df_lists])
def get_data(self):
def __iter__(self):
for d in self.df_lists:
for dp in d.get_data():
for dp in d.__iter__():
yield dp
......@@ -492,15 +492,15 @@ class JoinData(DataFlow):
When these dataflows have different sizes, JoinData will stop when any
of them is exhausted.
The list could contain the same DataFlow instance more than once,
but note that `get_data` will then also be called many times.
but note that `__iter__` will then also be called many times.
"""
self.df_lists = df_lists
try:
self._size = self.df_lists[0].size()
self._size = len(self.df_lists[0])
for d in self.df_lists:
assert d.size() == self._size, \
"All DataFlow must have the same size! {} != {}".format(d.size(), self._size)
assert len(d) == self._size, \
"All DataFlow must have the same size! {} != {}".format(len(d), self._size)
except Exception:
logger.info("[JoinData] Size check failed for the list of dataflow to be joined!")
......@@ -508,14 +508,14 @@ class JoinData(DataFlow):
for d in set(self.df_lists):
d.reset_state()
def size(self):
def __len__(self):
"""
Return the minimum size among all.
"""
return min([k.size() for k in self.df_lists])
return min([len(k) for k in self.df_lists])
def get_data(self):
itrs = [k.get_data() for k in self.df_lists]
def __iter__(self):
itrs = [k.__iter__() for k in self.df_lists]
try:
while True:
dp = []
......@@ -574,7 +574,7 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
def reset_state(self):
ProxyDataFlow.reset_state(self)
RNGDataFlow.reset_state(self)
self.ds_itr = RepeatedData(self.ds, -1).get_data()
self.ds_itr = RepeatedData(self.ds, -1).__iter__()
self.current_cnt = 0
def _add_data(self):
......@@ -582,13 +582,13 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
for _ in range(self.nr_reuse):
self.q.append(dp)
def get_data(self):
def __iter__(self):
with self._guard:
# fill queue
while self.q.maxlen > len(self.q):
self._add_data()
sz = self.size()
sz = self.__len__()
cnt = 0
while True:
self.rng.shuffle(self.q)
......@@ -626,7 +626,7 @@ class CacheData(ProxyDataFlow):
self.rng = get_rng(self)
self.buffer = []
def get_data(self):
def __iter__(self):
with self._guard:
if len(self.buffer):
if self.shuffle:
......@@ -634,7 +634,7 @@ class CacheData(ProxyDataFlow):
for dp in self.buffer:
yield dp
else:
for dp in self.ds.get_data():
for dp in self.ds:
yield dp
self.buffer.append(dp)
......@@ -648,12 +648,12 @@ class PrintData(ProxyDataFlow):
.. code-block:: python
def get_data():
def __iter__():
ds = SomeDataSource('path/to/lmdb')
ds = SomeInscrutableMappings(ds)
ds = PrintData(ds, num=2, max_list=2)
return ds
ds = get_data()
ds = __iter__()
The output looks like:
......@@ -766,8 +766,8 @@ class PrintData(ProxyDataFlow):
msg.append(self._analyze_input_data(entry, k, max_depth=self.max_depth, max_list=self.max_list))
return u'\n'.join(msg)
def get_data(self):
for dp in self.ds.get_data():
def __iter__(self):
for dp in self.ds:
# it is important to place this here! otherwise it mixes the output of multiple PrintData
if self.cnt == 0:
label = ' (%s)' % self.name if self.name is not None else ""
......
......@@ -78,10 +78,10 @@ class BSDS500(RNGDataFlow):
self.data[idx] = im
self.label[idx] = gt
def size(self):
def __len__(self):
return self.data.shape[0]
def get_data(self):
def __iter__(self):
idxs = np.arange(self.data.shape[0])
if self.shuffle:
self.rng.shuffle(idxs)
......@@ -99,6 +99,6 @@ except ImportError:
if __name__ == '__main__':
a = BSDS500('val')
a.reset_state()
for k in a.get_data():
for k in a:
cv2.imshow("haha", k[1].astype('uint8') * 255)
cv2.waitKey(1000)
......@@ -106,10 +106,10 @@ class CifarBase(RNGDataFlow):
self.dir = dir
self.shuffle = shuffle
def size(self):
def __len__(self):
return 50000 if self.train_or_test == 'train' else 10000
def get_data(self):
def __iter__(self):
idxs = np.arange(len(self.data))
if self.shuffle:
self.rng.shuffle(idxs)
......@@ -171,7 +171,7 @@ if __name__ == '__main__':
import cv2
ds.reset_state()
for i, dp in enumerate(ds.get_data()):
for i, dp in enumerate(ds):
if i == 100:
break
img = dp[0]
......
......@@ -163,10 +163,10 @@ class ILSVRC12Files(RNGDataFlow):
fname = os.path.join(self.full_dir, fname)
assert os.path.isfile(fname), fname
def size(self):
def __len__(self):
return len(self.imglist)
def get_data(self):
def __iter__(self):
idxs = np.arange(len(self.imglist))
if self.shuffle:
self.rng.shuffle(idxs)
......@@ -251,8 +251,8 @@ class ILSVRC12(ILSVRC12Files):
There are some CMYK / png images, but cv2 seems robust to them.
https://github.com/tensorflow/models/blob/c0cd713f59cfe44fa049b3120c417cc4079c17e3/research/inception/inception/data/build_imagenet_data.py#L264-L300
"""
def get_data(self):
for fname, label in super(ILSVRC12, self).get_data():
def __iter__(self):
for fname, label in super(ILSVRC12, self).__iter__():
im = cv2.imread(fname, cv2.IMREAD_COLOR)
assert im is not None, fname
yield [im, label]
......@@ -299,7 +299,7 @@ if __name__ == '__main__':
ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', shuffle=False)
ds.reset_state()
for k in ds.get_data():
for k in ds:
from IPython import embed
embed()
break
......@@ -99,11 +99,11 @@ class Mnist(RNGDataFlow):
't10k-images-idx3-ubyte.gz',
't10k-labels-idx1-ubyte.gz')
def size(self):
def __len__(self):
return self.images.shape[0]
def get_data(self):
idxs = list(range(self.size()))
def __iter__(self):
idxs = list(range(self.__len__()))
if self.shuffle:
self.rng.shuffle(idxs)
for k in idxs:
......@@ -133,7 +133,7 @@ class FashionMnist(Mnist):
if __name__ == '__main__':
ds = Mnist('train')
ds.reset_state()
for (img, label) in ds.get_data():
for (img, label) in ds:
from IPython import embed
embed()
break
......@@ -49,10 +49,10 @@ class SVHNDigit(RNGDataFlow):
self.Y[self.Y == 10] = 0
SVHNDigit._Cache[name] = (self.X, self.Y)
def size(self):
def __len__(self):
return self.X.shape[0]
def get_data(self):
def __iter__(self):
n = self.X.shape[0]
idxs = np.arange(n)
if self.shuffle:
......
......@@ -43,7 +43,7 @@ def dump_dataflow_to_process_queue(df, size, nr_consumer):
def run(self):
self.df.reset_state()
try:
for idx, dp in enumerate(self.df.get_data()):
for idx, dp in enumerate(self.df):
self.q.put((idx, dp))
finally:
for _ in range(nr_consumer):
......
......@@ -50,10 +50,10 @@ class HDF5Data(RNGDataFlow):
self._size = lens[0]
self.shuffle = shuffle
def size(self):
def __len__(self):
return self._size
def get_data(self):
def __iter__(self):
idxs = list(range(self._size))
if self.shuffle:
self.rng.shuffle(idxs)
......@@ -132,10 +132,10 @@ class LMDBData(RNGDataFlow):
super(LMDBData, self).reset_state()
self._open_lmdb()
def size(self):
def __len__(self):
return self._size
def get_data(self):
def __iter__(self):
with self._guard:
if not self._shuffle:
c = self._txn.cursor()
......@@ -231,11 +231,11 @@ class SVMLightData(RNGDataFlow):
self.X = np.asarray(self.X.todense())
self.shuffle = shuffle
def size(self):
def __len__(self):
return len(self.y)
def get_data(self):
idxs = np.arange(self.size())
def __iter__(self):
idxs = np.arange(self.__len__())
if self.shuffle:
self.rng.shuffle(idxs)
for id in idxs:
......@@ -248,12 +248,12 @@ class TFRecordData(DataFlow):
self._path = path
self._size = int(size)
def size(self):
def __len__(self):
if self._size:
return self._size
return super(TFRecordData, self).size()
return len(super(TFRecordData, self))
def get_data(self):
def __iter__(self):
gen = tf.python_io.tf_record_iterator(self._path)
for dp in gen:
yield loads(dp)
......
......@@ -63,10 +63,10 @@ class ImageFromFile(RNGDataFlow):
self.resize = resize
self.shuffle = shuffle
def size(self):
def __len__(self):
return len(self.files)
def get_data(self):
def __iter__(self):
if self.shuffle:
self.rng.shuffle(self.files)
for f in self.files:
......
......@@ -159,7 +159,7 @@ class MultiProcessPrefetchData(ProxyDataFlow):
# reset all ds so each process will produce different data
self.ds.reset_state()
while True:
for dp in self.ds.get_data():
for dp in self.ds:
self.queue.put(dp)
def __init__(self, ds, nr_prefetch, nr_proc):
......@@ -175,7 +175,7 @@ However, windows requires more strict picklability on processes, which may \
lead to failures in some of the code.")
super(MultiProcessPrefetchData, self).__init__(ds)
try:
self._size = ds.size()
self._size = len(ds)
except NotImplementedError:
self._size = -1
self.nr_proc = nr_proc
......@@ -191,7 +191,7 @@ lead of failure on some of the code.")
ensure_proc_terminate(self.procs)
start_proc_mask_signal(self.procs)
def get_data(self):
def __iter__(self):
for k in itertools.count():
if self._size > 0 and k >= self._size:
break
......@@ -264,7 +264,7 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
socket.connect(self.conn_name)
try:
while True:
for dp in self.ds.get_data():
for dp in self.ds:
socket.send(dumps(dp), copy=False)
# sigint could still propagate here, e.g. when nested
except KeyboardInterrupt:
......@@ -291,17 +291,17 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
logger.info("[PrefetchDataZMQ] Will fork a dataflow more than one times. "
"This assumes the datapoints are i.i.d.")
try:
self._size = ds.size()
self._size = ds.__len__()
except NotImplementedError:
self._size = -1
def _recv(self):
return loads(self.socket.recv(copy=False))
def size(self):
return self.ds.size()
def __len__(self):
return self.ds.__len__()
def get_data(self):
def __iter__(self):
with self._guard, _zmq_catch_error('PrefetchDataZMQ'):
for k in itertools.count():
if self._size > 0 and k >= self._size:
......@@ -360,7 +360,7 @@ class MultiThreadPrefetchData(DataFlow):
def run(self):
self.df.reset_state()
try:
for dp in self.df.get_data():
for dp in self.df:
if self.stopped():
return
self.queue_put_stoppable(self.queue, dp)
......@@ -391,10 +391,10 @@ class MultiThreadPrefetchData(DataFlow):
th.df.reset_state()
th.start()
def size(self):
return self.threads[0].size()
def __len__(self):
return self.threads[0].__len__()
def get_data(self):
def __iter__(self):
while True:
yield self.queue.get()
......@@ -419,8 +419,8 @@ class PlasmaPutData(ProxyDataFlow):
super(PlasmaPutData, self).reset_state()
self.client = plasma.connect(self._socket, "", 0)
def get_data(self):
for dp in self.ds.get_data():
def __iter__(self):
for dp in self.ds:
oid = self.client.put(dp)
yield [oid.binary()]
......@@ -440,8 +440,8 @@ class PlasmaGetData(ProxyDataFlow):
super(PlasmaGetData, self).reset_state()
self.client = plasma.connect(self._socket, "", 0)
def get_data(self):
for dp in self.ds.get_data():
def __iter__(self):
for dp in self.ds:
oid = plasma.ObjectID(dp[0])
dp = self.client.get(oid)
yield dp
......
......@@ -61,7 +61,7 @@ class _ParallelMapData(ProxyDataFlow):
if ret is not None:
yield ret
self._iter = self.ds.get_data() # refresh
self._iter = self.ds.__iter__() # refresh
for _ in range(self._buffer_size):
self._send(next(self._iter))
ret = self._recv()
......@@ -73,7 +73,7 @@ class _ParallelMapData(ProxyDataFlow):
for dp in self._iter:
self._send(dp)
yield self._recv_filter_none()
self._iter = self.ds.get_data() # refresh
self._iter = self.ds.__iter__() # refresh
# first clear the buffer, then fill
for k in range(self._buffer_size):
......@@ -100,11 +100,11 @@ class MultiThreadMapData(_ParallelMapData):
2. Threads run in parallel and can take different time to run the
mapping function. Therefore the order of datapoints won't be
preserved, and datapoints from one pass of `df.get_data()` might get
preserved, and datapoints from one pass of `df.__iter__()` might get
mixed with datapoints from the next pass.
You can use **strict mode**, where `MultiThreadMapData.get_data()`
is guaranteed to produce the exact set which `df.get_data()`
You can use **strict mode**, where `MultiThreadMapData.__iter__()`
is guaranteed to produce the exact set which `df.__iter__()`
produces. Although the order of data still isn't preserved.
"""
class _Worker(StoppableThread):
......@@ -165,7 +165,7 @@ class MultiThreadMapData(_ParallelMapData):
for t in self._threads:
t.start()
self._iter = self.ds.get_data()
self._iter = self.ds.__iter__()
self._guard = DataFlowReentrantGuard()
# Call once at the beginning, to ensure inq+outq has a total of buffer_size elements
......@@ -177,7 +177,7 @@ class MultiThreadMapData(_ParallelMapData):
def _send(self, dp):
self._in_queue.put(dp)
def get_data(self):
def __iter__(self):
with self._guard:
if self._strict:
for dp in self.get_data_strict():
......@@ -208,11 +208,11 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
Note:
1. Processes run in parallel and can take different time to run the
mapping function. Therefore the order of datapoints won't be
preserved, and datapoints from one pass of `df.get_data()` might get
preserved, and datapoints from one pass of `df.__iter__()` might get
mixed with datapoints from the next pass.
You can use **strict mode**, where `MultiProcessMapData.get_data()`
is guaranteed to produce the exact set which `df.get_data()`
You can use **strict mode**, where `MultiProcessMapData.__iter__()`
is guaranteed to produce the exact set which `df.__iter__()`
produces. Although the order of data still isn't preserved.
"""
class _Worker(mp.Process):
......@@ -267,7 +267,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
for k in range(self.nr_proc)]
self.ds.reset_state()
self._iter = self.ds.get_data()
self._iter = self.ds.__iter__()
self._start_processes()
self._fill_buffer() # pre-fill the bufer
......@@ -284,7 +284,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
dp = loads(msg[1])
return dp
def get_data(self):
def __iter__(self):
with self._guard, _zmq_catch_error('MultiProcessMapData'):
if self._strict:
for dp in self.get_data_strict():
......@@ -362,13 +362,13 @@ class MultiProcessMapDataComponentSharedArray(DataFlow):
arr = mp.RawArray(ctype, int(np.prod(self.output_shape)))
return arr
def size(self):
return self.ds.size()
def __len__(self):
return len(self.ds)
def reset_state(self):
self.ds.reset_state()
def get_data(self):
def __iter__(self):
ds_itr = _repeat_iter(self.ds.get_data)
with self._guard:
while True:
......@@ -392,16 +392,16 @@ if __name__ == '__main__':
def __init__(self, size):
self._size = size
def get_data(self):
def __iter__(self):
for k in range(self._size):
yield [k]
def size(self):
def __len__(self):
return self._size
ds = Zero(300)
ds = MultiProcessMapData(ds, 3, lambda x: [x[0] + 1], strict=True)
ds.reset_state()
for k in ds.get_data():
for k in ds:
print("Bang!", k)
print("END!")
......@@ -34,10 +34,10 @@ class FakeData(RNGDataFlow):
assert len(self.dtype) == len(self.shapes)
assert len(self.domain) == len(self.domain)
def size(self):
def __len__(self):
return self._size
def get_data(self):
def __iter__(self):
if self.random:
for _ in range(self._size):
val = []
......@@ -63,7 +63,7 @@ class DataFromQueue(DataFlow):
"""
self.queue = queue
def get_data(self):
def __iter__(self):
while True:
yield self.queue.get()
......@@ -81,10 +81,10 @@ class DataFromList(RNGDataFlow):
self.lst = lst
self.shuffle = shuffle
def size(self):
def __len__(self):
return len(self.lst)
def get_data(self):
def __iter__(self):
if not self.shuffle:
for k in self.lst:
yield k
......@@ -112,7 +112,7 @@ class DataFromGenerator(DataFlow):
if size is not None:
log_deprecated("DataFromGenerator(size=)", "It doesn't make much sense.", "2018-03-31")
def get_data(self):
def __iter__(self):
# yield from
for dp in self._gen():
yield dp
......@@ -128,9 +128,9 @@ class DataFromIterable(DataFlow):
self._itr = iterable
self._len = len(iterable)
def size(self):
def __len__(self):
return self._len
def get_data(self):
def __iter__(self):
for dp in self._itr:
yield dp
......@@ -58,14 +58,14 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False):
q = deque(maxlen=INTERVAL)
try:
total = df.size()
total = len(df)
except NotImplementedError:
total = 0
tqdm_args = get_tqdm_kwargs(leave=True, smoothing=0.8)
tqdm_args['bar_format'] = tqdm_args['bar_format'] + "{postfix}"
while True:
with tqdm.trange(total, **tqdm_args) as pbar:
for dp in df.get_data():
for dp in df:
start = time.time()
socket.send(dump_fn(dp), copy=False)
q.append(time.time() - start)
......@@ -117,7 +117,7 @@ class RemoteDataZMQ(DataFlow):
else:
socket.connect(addr)
def get_data(self):
def __iter__(self):
with self._guard:
try:
ctx = zmq.Context()
......
......@@ -20,7 +20,7 @@ __all__ = ['LMDBSerializer', 'NumpySerializer', 'TFRecordSerializer', 'HDF5Seria
def _reset_df_and_get_size(df):
df.reset_state()
try:
sz = df.size()
sz = len(df)
except NotImplementedError:
sz = 0
return sz
......@@ -57,7 +57,7 @@ class LMDBSerializer():
# LMDB transaction is not exception-safe!
# although it has a context manager interface
txn = db.begin(write=True)
for idx, dp in enumerate(df.get_data()):
for idx, dp in enumerate(df):
txn.put(u'{:08}'.format(idx).encode('ascii'), dumps(dp))
pbar.update()
if (idx + 1) % write_frequency == 0:
......@@ -101,7 +101,7 @@ class NumpySerializer():
buffer = []
size = _reset_df_and_get_size(df)
with get_tqdm(total=size) as pbar:
for dp in df.get_data():
for dp in df:
buffer.append(dp)
pbar.update()
np.savez_compressed(path, buffer=np.asarray(buffer, dtype=np.object))
......@@ -135,7 +135,7 @@ class TFRecordSerializer():
size = _reset_df_and_get_size(df)
with tf.python_io.TFRecordWriter(path) as writer, get_tqdm(total=size) as pbar:
for dp in df.get_data():
for dp in df:
writer.write(_dumps(dp))
pbar.update()
......@@ -143,7 +143,7 @@ class TFRecordSerializer():
def load(path, size=None):
"""
Args:
size (int): total number of records. If not provided, the returned dataflow will have no `size()`.
size (int): total number of records. If not provided, the returned dataflow will have no `__len__()`.
It's needed because this metadata is not stored in the TFRecord file.
"""
gen = tf.python_io.tf_record_iterator(path)
......@@ -175,15 +175,16 @@ class HDF5Serializer():
buffer = defaultdict(list)
with get_tqdm(total=size) as pbar:
for dp in df.get_data():
for dp in df:
assert len(dp) == len(data_paths), "Datapoint has {} components!".format(len(dp))
for k, el in zip(data_paths, dp):
buffer[k].append(el)
pbar.update()
with h5py.File(path, 'w') as hf, get_tqdm(total=size) as pbar:
with h5py.File(path, 'w') as hf, get_tqdm(total=len(data_paths)) as pbar:
for data_path in data_paths:
hf.create_dataset(data_path, data=buffer[data_path])
pbar.update()
@staticmethod
def load(path, data_paths, shuffle=True):
......@@ -221,7 +222,7 @@ if __name__ == '__main__':
print(time.time())
df = TFRecordSerializer.load('out.tfrecords', size=1000)
df.reset_state()
for idx, dp in enumerate(df.get_data()):
for idx, dp in enumerate(df):
pass
print("TF Finished, ", idx)
print(time.time())
......@@ -230,7 +231,7 @@ if __name__ == '__main__':
print(time.time())
df = LMDBSerializer.load('out.lmdb')
df.reset_state()
for idx, dp in enumerate(df.get_data()):
for idx, dp in enumerate(df):
pass
print("LMDB Finished, ", idx)
print(time.time())
......@@ -239,7 +240,7 @@ if __name__ == '__main__':
print(time.time())
df = NumpySerializer.load('out.npz')
df.reset_state()
for idx, dp in enumerate(df.get_data()):
for idx, dp in enumerate(df):
pass
print("Numpy Finished, ", idx)
print(time.time())
......@@ -248,7 +249,7 @@ if __name__ == '__main__':
print(time.time())
df = HDF5Serializer.load('out.h5')
df.reset_state()
for idx, dp in enumerate(df.get_data()):
for idx, dp in enumerate(df):
pass
print("HDF5 Finished, ", idx)
print(time.time())
......@@ -77,7 +77,7 @@ class FeedInput(InputSource):
class _FeedCallback(Callback):
def __init__(self, ds, placeholders):
self._ds = ds
self._itr = self._ds.get_data()
self._itr = self._ds.__iter__()
self._placeholders = placeholders
def _before_run(self, _):
......@@ -87,7 +87,7 @@ class FeedInput(InputSource):
return tf.train.SessionRunArgs(fetches=[], feed_dict=feed)
def _reset(self):
self._itr = self._ds.get_data()
self._itr = self._ds.__iter__()
def __init__(self, ds, infinite=True):
"""
......@@ -105,7 +105,7 @@ class FeedInput(InputSource):
self._iter_ds = self.ds
def _size(self):
return self.ds.size()
return len(self.ds)
def _setup(self, inputs):
# placeholders as input are always safe to reuse.
......@@ -175,7 +175,7 @@ class EnqueueThread(ShareSessionThread):
logger.info("{} Exited.".format(self.name))
def reinitialize_dataflow(self):
self._itr = self.dataflow.get_data()
self._itr = self.dataflow.__iter__()
def pause(self):
self._running.clear()
......@@ -207,7 +207,7 @@ class QueueInput(FeedfreeInput):
self._started = False
def _size(self):
return self.ds.size()
return len(self.ds)
def _setup(self, inputs):
self._input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
......@@ -225,7 +225,7 @@ class QueueInput(FeedfreeInput):
def refill_queue(self):
"""
Clear the queue, then call dataflow.get_data() again and fill into the queue.
Clear the queue, then call dataflow.__iter__() again and fill into the queue.
"""
self.thread.pause() # pause enqueue
......@@ -292,7 +292,7 @@ class BatchQueueInput(QueueInput):
self.batch_size = int(batch_size)
def _size(self):
return self.ds.size() // self.batch_size
return len(self.ds) // self.batch_size
def _setup(self, inputs):
logger.info("Setting up the queue for CPU prefetching ...")
......
......@@ -67,11 +67,11 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
def get_result(self):
self.dataset.reset_state()
try:
sz = self.dataset.size()
sz = len(self.dataset)
except NotImplementedError:
sz = 0
with get_tqdm(total=sz, disable=(sz == 0)) as pbar:
for dp in self.dataset.get_data():
for dp in self.dataset:
res = self.predictor(*dp)
yield res
pbar.update()
......@@ -146,7 +146,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
def get_result(self):
try:
sz = self.dataset.size()
sz = len(self.dataset)
except NotImplementedError:
sz = 0
with get_tqdm(total=sz, disable=(sz == 0)) as pbar:
......
......@@ -131,13 +131,13 @@ class TrainConfig(object):
if steps_per_epoch is None:
try:
if dataflow is not None:
steps_per_epoch = dataflow.size()
steps_per_epoch = len(dataflow)
elif data is not None:
steps_per_epoch = data.size()
else:
raise NotImplementedError()
except NotImplementedError:
logger.error("You must set `TrainConfig(steps_per_epoch)` if data.size() is not available.")
logger.error("You must set `TrainConfig(steps_per_epoch)` if the size of your input is not available.")
raise
else:
steps_per_epoch = int(steps_per_epoch)
......
......@@ -214,7 +214,7 @@ def get_tqdm_kwargs(**kwargs):
return default
def get_tqdm(**kwargs):
""" Similar to :func:`get_tqdm_kwargs`,
but returns the tqdm object directly. """
return tqdm(**get_tqdm_kwargs(**kwargs))
def get_tqdm(*args, **kwargs):
""" Similar to :func:`tqdm.tqdm()`,
but uses tensorpack's default options for a consistent style. """
return tqdm(*args, **get_tqdm_kwargs(**kwargs))
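
With `get_tqdm` now forwarding positional arguments to `tqdm`, a dataflow (or any iterable with `__len__`) can be wrapped directly, as in the `compute_mean_std` hunk earlier; a minimal sketch (the import path and toy dataflow are assumptions):

```python
from tensorpack.dataflow import DataFromList
from tensorpack.utils.utils import get_tqdm

ds = DataFromList([[i] for i in range(1000)], shuffle=False)
ds.reset_state()
for dp in get_tqdm(ds):   # the progress bar total comes from len(ds)
    pass
```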
......@@ -308,7 +308,7 @@ def dump_dataflow_images(df, index=0, batched=True,
df.reset_state()
cnt = 0
while True:
for dp in df.get_data():
for dp in df:
if not batched:
imgbatch = [dp[index]]
else:
......
......@@ -31,10 +31,10 @@ class SeededFakeDataFlow(DataFlow):
img = np.random.randn(28, 28, 3)
self.cache.append([label, img])
def size(self):
def __len__(self):
return self._size
def get_data(self):
def __iter__(self):
for dp in self.cache:
yield dp
......@@ -52,7 +52,7 @@ class SerializerTest(unittest.TestCase):
ds_actual.reset_state()
ds_expected.reset_state()
for dp_expected, dp_actual in zip(ds_expected.get_data(), ds_actual.get_data()):
for dp_expected, dp_actual in zip(ds_expected.__iter__(), ds_actual.__iter__()):
self.assertEqual(dp_expected[0], dp_actual[0])
self.assertTrue(np.allclose(dp_expected[1], dp_actual[1]))
except ImportError:
......