Commit 7b8728f9 authored by Patrick Wieschollek, committed by Yuxin Wu

make dataflow idiomatic python container objects (fix #869) (#872)

* make dataflow idiomatic python container objects (fix #869)

* use ABCMeta and fix tqdm

* python2 compatible version

* fix

* some more replacement

* rollback changes of _reset_df_and_get_size

* fix unrelated pbar issue

* revert the __len__ changes on InputSource
parent 3476cb43
...@@ -56,7 +56,7 @@ To actually access the datapoints generated by the dataflow, you can use ...@@ -56,7 +56,7 @@ To actually access the datapoints generated by the dataflow, you can use
```python ```python
ds.reset_state() ds.reset_state()
for dp in ds.get_data(): for dp in ds:
print dp[0] # this is an RGB image! print dp[0] # this is an RGB image!
``` ```
This kind of iteration is used behind the scenes to feed data for training. This kind of iteration is used behind the scenes to feed data for training.
......
...@@ -5,11 +5,10 @@ ...@@ -5,11 +5,10 @@
DataFlow is a library to build Python iterators for efficient data loading. DataFlow is a library to build Python iterators for efficient data loading.
**Definition**: A DataFlow is something that has a `get_data()` generator method, which yields `datapoints`. **Definition**: A DataFlow is an idiomatic Python container object that has an `__iter__()` generator method, which yields `datapoints`, and a `__len__()` method returning the size of the flow.
A datapoint is a **list** of Python objects which are called the `components` of a datapoint. A datapoint is a **list** of Python objects which are called the `components` of a datapoint.
**Example**: to train on the MNIST dataset, you may need a DataFlow with a `get_data()` method **Example**: to train on the MNIST dataset, you may need a DataFlow with an `__iter__()` method
that yields datapoints (lists) of two components: that yields datapoints (lists) of two components:
a numpy array of shape (64, 28, 28), and an array of shape (64,). a numpy array of shape (64, 28, 28), and an array of shape (64,).
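To make the updated definition concrete, here is a minimal sketch of a DataFlow written against the new interface (the class name, batch count, and random data are made up for illustration; it assumes tensorpack is installed):

```python
import numpy as np
from tensorpack.dataflow import DataFlow

class FakeMnistBatches(DataFlow):
    """Hypothetical flow yielding fake (images, labels) batches of the shapes above."""
    def __init__(self, num_batches=100):
        self._num = num_batches

    def __iter__(self):
        # every datapoint is a *list* of components
        for _ in range(self._num):
            yield [np.random.rand(64, 28, 28).astype('float32'),
                   np.random.randint(10, size=(64,))]

    def __len__(self):
        return self._num
```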
...@@ -67,7 +66,6 @@ and then use the generator however you like: ...@@ -67,7 +66,6 @@ and then use the generator however you like:
df = SomeDataFlow() df = SomeDataFlow()
df.reset_state() df.reset_state()
generator = df.get_data() for dp in df:
for dp in generator:
# dp is now a list. do whatever # dp is now a list. do whatever
``` ```
...@@ -143,8 +143,8 @@ We can also dump the dataset into one single LMDB file and read it sequentially. ...@@ -143,8 +143,8 @@ We can also dump the dataset into one single LMDB file and read it sequentially.
```python ```python
from tensorpack.dataflow import * from tensorpack.dataflow import *
class BinaryILSVRC12(dataset.ILSVRC12Files): class BinaryILSVRC12(dataset.ILSVRC12Files):
def get_data(self): def __iter__(self):
for fname, label in super(BinaryILSVRC12, self).get_data(): for fname, label in super(BinaryILSVRC12, self).__iter__():
with open(fname, 'rb') as f: with open(fname, 'rb') as f:
jpeg = f.read() jpeg = f.read()
jpeg = np.asarray(bytearray(jpeg), dtype='uint8') jpeg = np.asarray(bytearray(jpeg), dtype='uint8')
......
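For context, the surrounding tutorial then dumps this flow into the single LMDB file roughly as follows; the paths are placeholders and the exact tutorial wording is not reproduced here:

```python
from tensorpack.dataflow import LMDBSerializer, PrefetchDataZMQ

# BinaryILSVRC12 is the class defined in the snippet above
ds = BinaryILSVRC12('/path/to/ILSVRC12', 'train')
ds = PrefetchDataZMQ(ds, nr_proc=1)          # read files in a separate process
LMDBSerializer.save(ds, '/path/to/ILSVRC12-train.lmdb')
```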
...@@ -23,7 +23,7 @@ To write more complicated DataFlow, you need to inherit the base `DataFlow` clas ...@@ -23,7 +23,7 @@ To write more complicated DataFlow, you need to inherit the base `DataFlow` clas
Usually, you just need to implement the `get_data()` method which yields a datapoint every time. Usually, you just need to implement the `get_data()` method which yields a datapoint every time.
```python ```python
class MyDataFlow(DataFlow): class MyDataFlow(DataFlow):
def get_data(self): def __iter__(self):
for k in range(100): for k in range(100):
digit = np.random.rand(28, 28) digit = np.random.rand(28, 28)
label = np.random.randint(10) label = np.random.randint(10)
...@@ -56,7 +56,7 @@ class ProcessingDataFlow(DataFlow): ...@@ -56,7 +56,7 @@ class ProcessingDataFlow(DataFlow):
def __init__(self, ds): def __init__(self, ds):
self.ds = ds self.ds = ds
def get_data(self): def __iter__(self):
for datapoint in self.ds.get_data(): for datapoint in self.ds.get_data():
# do something # do something
yield new_datapoint yield new_datapoint
......
...@@ -89,10 +89,10 @@ class RawTIMIT(DataFlow): ...@@ -89,10 +89,10 @@ class RawTIMIT(DataFlow):
assert label in ['phoneme', 'letter'], label assert label in ['phoneme', 'letter'], label
self.label = label self.label = label
def size(self): def __len__(self):
return len(self.filelists) return len(self.filelists)
def get_data(self): def __iter__(self):
for f in self.filelists: for f in self.filelists:
feat = get_feature(f) feat = get_feature(f)
if self.label == 'phoneme': if self.label == 'phoneme':
...@@ -106,12 +106,10 @@ def compute_mean_std(db, fname): ...@@ -106,12 +106,10 @@ def compute_mean_std(db, fname):
ds = LMDBSerializer.load(db, shuffle=False) ds = LMDBSerializer.load(db, shuffle=False)
ds.reset_state() ds.reset_state()
o = OnlineMoments() o = OnlineMoments()
with get_tqdm(total=ds.size()) as bar: for dp in get_tqdm(ds):
for dp in ds.get_data(): feat = dp[0] # len x dim
feat = dp[0] # len x dim for f in feat:
for f in feat: o.feed(f)
o.feed(f)
bar.update()
logger.info("Writing to {} ...".format(fname)) logger.info("Writing to {} ...".format(fname))
with open(fname, 'wb') as f: with open(fname, 'wb') as f:
f.write(serialize.dumps([o.mean, o.std])) f.write(serialize.dumps([o.mean, o.std]))
......
...@@ -39,12 +39,12 @@ class TIMITBatch(ProxyDataFlow): ...@@ -39,12 +39,12 @@ class TIMITBatch(ProxyDataFlow):
self.batch = batch self.batch = batch
self.ds = ds self.ds = ds
def size(self): def __len__(self):
return self.ds.size() // self.batch return len(self.ds) // self.batch
def get_data(self): def __iter__(self):
itr = self.ds.get_data() itr = self.ds.__iter__()
for _ in range(self.size()): for _ in range(self.__len__()):
feats = [] feats = []
labs = [] labs = []
for b in range(self.batch): for b in range(self.batch):
......
...@@ -58,10 +58,10 @@ class CharRNNData(RNGDataFlow): ...@@ -58,10 +58,10 @@ class CharRNNData(RNGDataFlow):
self.whole_seq = np.array([self.char2idx[c] for c in data], dtype='int32') self.whole_seq = np.array([self.char2idx[c] for c in data], dtype='int32')
logger.info("Corpus loaded. Vocab size: {}".format(self.vocab_size)) logger.info("Corpus loaded. Vocab size: {}".format(self.vocab_size))
def size(self): def __len__(self):
return self._size return self._size
def get_data(self): def __iter__(self):
random_starts = self.rng.randint( random_starts = self.rng.randint(
0, self.whole_seq.shape[0] - self.seq_length - 1, (self._size,)) 0, self.whole_seq.shape[0] - self.seq_length - 1, (self._size,))
for st in random_starts: for st in random_starts:
......
...@@ -241,7 +241,7 @@ class ExpReplay(DataFlow, Callback): ...@@ -241,7 +241,7 @@ class ExpReplay(DataFlow, Callback):
return [state, action, reward, isOver] return [state, action, reward, isOver]
# DataFlow method: # DataFlow method:
def get_data(self): def __iter__(self):
# wait for memory to be initialized # wait for memory to be initialized
self._init_memory_flag.wait() self._init_memory_flag.wait()
......
...@@ -14,8 +14,8 @@ class DisturbLabel(ProxyDataFlow, RNGDataFlow): ...@@ -14,8 +14,8 @@ class DisturbLabel(ProxyDataFlow, RNGDataFlow):
RNGDataFlow.reset_state(self) RNGDataFlow.reset_state(self)
ProxyDataFlow.reset_state(self) ProxyDataFlow.reset_state(self)
def get_data(self): def __iter__(self):
for dp in self.ds.get_data(): for dp in self.ds:
img, l = dp img, l = dp
if self.rng.rand() < self.prob: if self.rng.rand() < self.prob:
l = self.rng.choice(10) l = self.rng.choice(10)
......
...@@ -214,8 +214,8 @@ class ThetaImages(ProxyDataFlow, RNGDataFlow): ...@@ -214,8 +214,8 @@ class ThetaImages(ProxyDataFlow, RNGDataFlow):
ProxyDataFlow.reset_state(self) ProxyDataFlow.reset_state(self)
RNGDataFlow.reset_state(self) RNGDataFlow.reset_state(self)
def get_data(self): def __iter__(self):
for image, label in self.ds.get_data(): for image, label in self.ds:
theta = self.rng.uniform(0, 2 * np.pi) theta = self.rng.uniform(0, 2 * np.pi)
filtered_image, gt_filter = ThetaImages.filter_with_theta(image, theta) filtered_image, gt_filter = ThetaImages.filter_with_theta(image, theta)
yield [theta, image, filtered_image, gt_filter] yield [theta, image, filtered_image, gt_filter]
...@@ -245,7 +245,7 @@ def get_config(): ...@@ -245,7 +245,7 @@ def get_config():
OnlineTensorboardExport() OnlineTensorboardExport()
], ],
model=Model(), model=Model(),
steps_per_epoch=dataset_train.size(), steps_per_epoch=len(dataset_train),
max_epoch=50, max_epoch=50,
) )
......
...@@ -15,10 +15,10 @@ class DataFromListOfDict(RNGDataFlow): ...@@ -15,10 +15,10 @@ class DataFromListOfDict(RNGDataFlow):
self._shuffle = shuffle self._shuffle = shuffle
self._size = len(lst) self._size = len(lst)
def size(self): def __len__(self):
return self._size return self._size
def get_data(self): def __iter__(self):
if self._shuffle: if self._shuffle:
self.rng.shuffle(self._lst) self.rng.shuffle(self._lst)
for dic in self._lst: for dic in self._lst:
......
...@@ -412,5 +412,5 @@ if __name__ == '__main__': ...@@ -412,5 +412,5 @@ if __name__ == '__main__':
ds = PrintData(ds, 100) ds = PrintData(ds, 100)
TestDataSpeed(ds, 50000).start() TestDataSpeed(ds, 50000).start()
ds.reset_state() ds.reset_state()
for k in ds.get_data(): for k in ds:
pass pass
...@@ -109,7 +109,7 @@ def eval_coco(df, detect_func, tqdm_bar=None): ...@@ -109,7 +109,7 @@ def eval_coco(df, detect_func, tqdm_bar=None):
if tqdm_bar is None: if tqdm_bar is None:
tqdm_bar = stack.enter_context( tqdm_bar = stack.enter_context(
tqdm.tqdm(total=df.size(), **get_tqdm_kwargs())) tqdm.tqdm(total=df.size(), **get_tqdm_kwargs()))
for img, img_id in df.get_data(): for img, img_id in df:
results = detect_func(img) results = detect_func(img)
for r in results: for r in results:
box = r.box box = r.box
......
...@@ -348,7 +348,7 @@ def visualize(model, model_path, nr_visualize=100, output_dir='output'): ...@@ -348,7 +348,7 @@ def visualize(model, model_path, nr_visualize=100, output_dir='output'):
shutil.rmtree(output_dir) shutil.rmtree(output_dir)
utils.fs.mkdir_p(output_dir) utils.fs.mkdir_p(output_dir)
with tqdm.tqdm(total=nr_visualize) as pbar: with tqdm.tqdm(total=nr_visualize) as pbar:
for idx, dp in itertools.islice(enumerate(df.get_data()), nr_visualize): for idx, dp in itertools.islice(enumerate(df), nr_visualize):
img = dp[0] img = dp[0]
if cfg.MODE_MASK: if cfg.MODE_MASK:
gt_boxes, gt_labels, gt_masks = dp[-3:] gt_boxes, gt_labels, gt_masks = dp[-3:]
......
...@@ -193,7 +193,7 @@ class VisualizeTestSet(Callback): ...@@ -193,7 +193,7 @@ class VisualizeTestSet(Callback):
def _trigger(self): def _trigger(self):
idx = 0 idx = 0
for iA, iB in self.val_ds.get_data(): for iA, iB in self.val_ds:
vizA, vizB = self.pred(iA, iB) vizA, vizB = self.pred(iA, iB)
self.trainer.monitors.put_image('testA-{}'.format(idx), vizA) self.trainer.monitors.put_image('testA-{}'.format(idx), vizA)
self.trainer.monitors.put_image('testB-{}'.format(idx), vizB) self.trainer.monitors.put_image('testB-{}'.format(idx), vizB)
...@@ -223,6 +223,6 @@ if __name__ == '__main__': ...@@ -223,6 +223,6 @@ if __name__ == '__main__':
PeriodicTrigger(VisualizeTestSet(), every_k_epochs=3), PeriodicTrigger(VisualizeTestSet(), every_k_epochs=3),
], ],
max_epoch=195, max_epoch=195,
steps_per_epoch=data.size(), steps_per_epoch=len(data),
session_init=SaverRestore(args.load) if args.load else None session_init=SaverRestore(args.load) if args.load else None
) )
...@@ -191,6 +191,6 @@ class RandomZData(DataFlow): ...@@ -191,6 +191,6 @@ class RandomZData(DataFlow):
super(RandomZData, self).__init__() super(RandomZData, self).__init__()
self.shape = shape self.shape = shape
def get_data(self): def __iter__(self):
while True: while True:
yield [np.random.uniform(-1, 1, size=self.shape)] yield [np.random.uniform(-1, 1, size=self.shape)]
...@@ -223,7 +223,7 @@ if __name__ == '__main__': ...@@ -223,7 +223,7 @@ if __name__ == '__main__':
PeriodicTrigger(ModelSaver(), every_k_epochs=3), PeriodicTrigger(ModelSaver(), every_k_epochs=3),
ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)]) ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
], ],
steps_per_epoch=data.size(), steps_per_epoch=len(data),
max_epoch=300, max_epoch=300,
session_init=SaverRestore(args.load) if args.load else None session_init=SaverRestore(args.load) if args.load else None
) )
...@@ -237,7 +237,7 @@ def get_data(name): ...@@ -237,7 +237,7 @@ def get_data(name):
def view_data(): def view_data():
ds = RepeatedData(get_data('train'), -1) ds = RepeatedData(get_data('train'), -1)
ds.reset_state() ds.reset_state()
for ims, edgemaps in ds.get_data(): for ims, edgemaps in ds:
for im, edgemap in zip(ims, edgemaps): for im, edgemap in zip(ims, edgemaps):
assert im.shape[0] % 16 == 0 and im.shape[1] % 16 == 0, im.shape assert im.shape[0] % 16 == 0 and im.shape[1] % 16 == 0, im.shape
cv2.imshow("im", im / 255.0) cv2.imshow("im", im / 255.0)
...@@ -249,7 +249,7 @@ def view_data(): ...@@ -249,7 +249,7 @@ def view_data():
def get_config(): def get_config():
logger.auto_set_dir() logger.auto_set_dir()
dataset_train = get_data('train') dataset_train = get_data('train')
steps_per_epoch = dataset_train.size() * 40 steps_per_epoch = len(dataset_train) * 40
dataset_val = get_data('val') dataset_val = get_data('val')
return TrainConfig( return TrainConfig(
......
...@@ -58,7 +58,7 @@ class SintelData(DataFlow): ...@@ -58,7 +58,7 @@ class SintelData(DataFlow):
def size(self): def size(self):
return len(self.flows) return len(self.flows)
def get_data(self): def __iter__(self):
for flow_path in self.flows: for flow_path in self.flows:
input_path = flow_path.replace( input_path = flow_path.replace(
self.path_prefix, os.path.join(self.data_path, 'clean')) self.path_prefix, os.path.join(self.data_path, 'clean'))
...@@ -95,7 +95,7 @@ def inference(model, model_path, sintel_path): ...@@ -95,7 +95,7 @@ def inference(model, model_path, sintel_path):
ds.reset_state() ds.reset_state()
# look at shape information (all images in Sintel have the same shape) # look at shape information (all images in Sintel have the same shape)
h, w = next(ds.get_data())[0].shape[2:] h, w = next(ds.__iter__())[0].shape[2:]
pred = PredictConfig( pred = PredictConfig(
model=model(height=h, width=w), model=model(height=h, width=w),
......
...@@ -138,8 +138,6 @@ if __name__ == '__main__': ...@@ -138,8 +138,6 @@ if __name__ == '__main__':
dataset_train = get_data('train', args.mixup, args.alpha) dataset_train = get_data('train', args.mixup, args.alpha)
dataset_test = get_data('test', args.mixup, args.alpha) dataset_test = get_data('test', args.mixup, args.alpha)
steps_per_epoch = dataset_train.size()
config = TrainConfig( config = TrainConfig(
model=ResNet_Cifar(), model=ResNet_Cifar(),
data=QueueInput(dataset_train), data=QueueInput(dataset_train),
...@@ -150,7 +148,7 @@ if __name__ == '__main__': ...@@ -150,7 +148,7 @@ if __name__ == '__main__':
ScheduledHyperParamSetter('learning_rate', LR_SCHEDULE) ScheduledHyperParamSetter('learning_rate', LR_SCHEDULE)
], ],
max_epoch=200, max_epoch=200,
steps_per_epoch=steps_per_epoch, steps_per_epoch=len(dataset_train),
session_init=SaverRestore(args.load) if args.load else None session_init=SaverRestore(args.load) if args.load else None
) )
launch_train_with_config(config, SimpleTrainer()) launch_train_with_config(config, SimpleTrainer())
...@@ -41,7 +41,7 @@ class MnistPairs(dataset.Mnist): ...@@ -41,7 +41,7 @@ class MnistPairs(dataset.Mnist):
idx = self.rng.randint(len(self.data_dict[label])) idx = self.rng.randint(len(self.data_dict[label]))
return self.data_dict[label][idx].astype(np.float32) return self.data_dict[label][idx].astype(np.float32)
def get_data(self): def __iter__(self):
while True: while True:
y = self.rng.randint(2) y = self.rng.randint(2)
if y == 0: if y == 0:
...@@ -54,7 +54,7 @@ class MnistPairs(dataset.Mnist): ...@@ -54,7 +54,7 @@ class MnistPairs(dataset.Mnist):
class MnistTriplets(MnistPairs): class MnistTriplets(MnistPairs):
def get_data(self): def __iter__(self):
while True: while True:
pick_label, pick_other = self.rng.choice(10, size=2, replace=False) pick_label, pick_other = self.rng.choice(10, size=2, replace=False)
yield [self.pick(pick_label), self.pick(pick_label), self.pick(pick_other)] yield [self.pick(pick_label), self.pick(pick_label), self.pick(pick_other)]
...@@ -379,7 +379,7 @@ def visualize(model_path, model, algo_name): ...@@ -379,7 +379,7 @@ def visualize(model_path, model, algo_name):
ds = get_test_data() ds = get_test_data()
ds.reset_state() ds.reset_state()
for offset, dp in enumerate(ds.get_data()): for offset, dp in enumerate(ds):
digit, label = dp digit, label = dp
prediction = pred(digit)[0] prediction = pred(digit)[0]
embed[offset * BATCH_SIZE:offset * BATCH_SIZE + BATCH_SIZE, ...] = prediction embed[offset * BATCH_SIZE:offset * BATCH_SIZE + BATCH_SIZE, ...] = prediction
......
...@@ -222,7 +222,7 @@ def view_warp(modelpath): ...@@ -222,7 +222,7 @@ def view_warp(modelpath):
ds = get_data(False) ds = get_data(False)
ds.reset_state() ds.reset_state()
for k in ds.get_data(): for k in ds:
img, label = k img, label = k
outputs, affine1, affine2 = pred(img) outputs, affine1, affine2 = pred(img)
for idx, viz in enumerate(outputs): for idx, viz in enumerate(outputs):
...@@ -238,7 +238,7 @@ def get_config(): ...@@ -238,7 +238,7 @@ def get_config():
logger.auto_set_dir() logger.auto_set_dir()
dataset_train, dataset_test = get_data(True), get_data(False) dataset_train, dataset_test = get_data(True), get_data(False)
steps_per_epoch = dataset_train.size() * 5 steps_per_epoch = len(dataset_train) * 5
return TrainConfig( return TrainConfig(
model=Model(), model=Model(),
......
...@@ -34,7 +34,7 @@ class ImageDataFromZIPFile(RNGDataFlow): ...@@ -34,7 +34,7 @@ class ImageDataFromZIPFile(RNGDataFlow):
def size(self): def size(self):
return len(self.archivefiles) return len(self.archivefiles)
def get_data(self): def __iter__(self):
if self.shuffle: if self.shuffle:
self.rng.shuffle(self.archivefiles) self.rng.shuffle(self.archivefiles)
for archive in self.archivefiles: for archive in self.archivefiles:
...@@ -111,6 +111,6 @@ if __name__ == '__main__': ...@@ -111,6 +111,6 @@ if __name__ == '__main__':
LMDBSerializer.save(ds, args.lmdb) LMDBSerializer.save(ds, args.lmdb)
if args.debug: if args.debug:
ds.reset_state() ds.reset_state()
for i in ds.get_data(): for i in ds:
cv2.imshow('example', i[0]) cv2.imshow('example', i[0])
cv2.waitKey(0) cv2.waitKey(0)
...@@ -298,6 +298,6 @@ if __name__ == '__main__': ...@@ -298,6 +298,6 @@ if __name__ == '__main__':
ModelSaver(keep_checkpoint_every_n_hours=2) ModelSaver(keep_checkpoint_every_n_hours=2)
], ],
session_init=session_init, session_init=session_init,
steps_per_epoch=data.size() // 4, steps_per_epoch=len(data) // 4,
max_epoch=300 max_epoch=300
) )
...@@ -104,8 +104,8 @@ if __name__ == '__main__': ...@@ -104,8 +104,8 @@ if __name__ == '__main__':
dataset_train, dataset_test = get_data() dataset_train, dataset_test = get_data()
# How many iterations you want in each epoch. # How many iterations you want in each epoch.
# This (data.size()) is the default value. # This len(data) is the default value.
steps_per_epoch = dataset_train.size() steps_per_epoch = len(dataset_train)
# get the config which contains everything necessary in a training # get the config which contains everything necessary in a training
config = TrainConfig( config = TrainConfig(
......
...@@ -108,8 +108,8 @@ if __name__ == '__main__': ...@@ -108,8 +108,8 @@ if __name__ == '__main__':
dataset_train, dataset_test = get_data() dataset_train, dataset_test = get_data()
# How many iterations you want in each epoch. # How many iterations you want in each epoch.
# This (data.size()) is the default value. # This len(data) is the default value.
steps_per_epoch = dataset_train.size() steps_per_epoch = len(dataset_train)
# get the config which contains everything necessary in a training # get the config which contains everything necessary in a training
config = TrainConfig( config = TrainConfig(
......
...@@ -132,7 +132,7 @@ if __name__ == '__main__': ...@@ -132,7 +132,7 @@ if __name__ == '__main__':
InferenceRunner( InferenceRunner(
dataset_test, ScalarStats(['cross_entropy_loss', 'accuracy'])), dataset_test, ScalarStats(['cross_entropy_loss', 'accuracy'])),
], ],
steps_per_epoch=dataset_train.size(), steps_per_epoch=len(dataset_train),
max_epoch=100, max_epoch=100,
) )
......
...@@ -55,7 +55,7 @@ def get_config(): ...@@ -55,7 +55,7 @@ def get_config():
ModelSaver(), ModelSaver(),
InferenceRunner(ds_test, [ScalarStats('total_costs')]), InferenceRunner(ds_test, [ScalarStats('total_costs')]),
], ],
steps_per_epoch=ds_train.size(), steps_per_epoch=len(ds_train),
max_epoch=100, max_epoch=100,
) )
......
...@@ -69,7 +69,7 @@ if __name__ == '__main__': ...@@ -69,7 +69,7 @@ if __name__ == '__main__':
) )
M.fit( M.fit(
validation_data=dataset_test, validation_data=dataset_test,
steps_per_epoch=dataset_train.size(), steps_per_epoch=len(dataset_train),
callbacks=[ callbacks=[
ModelSaver() ModelSaver()
] ]
......
...@@ -51,7 +51,7 @@ def _inference_context(): ...@@ -51,7 +51,7 @@ def _inference_context():
yield yield
except (StopIteration, tf.errors.CancelledError): except (StopIteration, tf.errors.CancelledError):
logger.error( logger.error(
"[InferenceRunner] input stopped before reaching its size()! " + msg) "[InferenceRunner] input stopped before reaching its __len__()! " + msg)
raise raise
except tf.errors.OutOfRangeError: # tf.data reaches an end except tf.errors.OutOfRangeError: # tf.data reaches an end
pass pass
...@@ -63,7 +63,7 @@ class InferenceRunnerBase(Callback): ...@@ -63,7 +63,7 @@ class InferenceRunnerBase(Callback):
Note: Note:
1. InferenceRunner will use `input.size()` to determine 1. InferenceRunner will use `input.size()` to determine
how many iterations to run, so you're responsible for ensuring that how many iterations to run, so you're responsible for ensuring that
`size()` is reasonable. `input.size()` is reasonable.
2. Only works with instances of `TowerTrainer`. 2. Only works with instances of `TowerTrainer`.
""" """
def __init__(self, input, infs): def __init__(self, input, infs):
......
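Since the runner stops after that many iterations, an infinite or unsized flow has to be bounded first; a small sketch of one way to do that (the flow `df` and the size 500 are placeholders):

```python
from tensorpack.dataflow import FixedSizeData

# df is a DataFlow whose length is unavailable (e.g. an endless generator);
# wrapping it gives InferenceRunner a well-defined number of iterations.
bounded = FixedSizeData(df, 500)
assert len(bounded) == 500
```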
...@@ -39,12 +39,30 @@ class DataFlowReentrantGuard(object): ...@@ -39,12 +39,30 @@ class DataFlowReentrantGuard(object):
return False return False
@six.add_metaclass(ABCMeta) # NOTE: we cannot use six here
class DataFlowMeta(ABCMeta):
"""
DataFlow uses "__iter__()" and "__len__()" instead of
"get_data()" and "size()". This add back-compatibility.
"""
def __new__(mcls, name, bases, namespace, **kwargs):
def hot_patch(required, existing):
if required not in namespace and existing in namespace:
namespace[required] = namespace[existing]
hot_patch('__iter__', 'get_data')
hot_patch('__len__', 'size')
return ABCMeta.__new__(mcls, name, bases, namespace, **kwargs)
@six.add_metaclass(DataFlowMeta)
class DataFlow(object): class DataFlow(object):
""" Base class for all DataFlow """ """ Base class for all DataFlow """
@abstractmethod @abstractmethod
def get_data(self): def __iter__(self):
""" """
The method to generate datapoints. The method to generate datapoints.
...@@ -52,7 +70,10 @@ class DataFlow(object): ...@@ -52,7 +70,10 @@ class DataFlow(object):
list: The datapoint, i.e. list of components. list: The datapoint, i.e. list of components.
""" """
def size(self): def get_data(self):
return self.__iter__()
def __len__(self):
""" """
Returns: Returns:
int: size of this data flow. int: size of this data flow.
...@@ -62,6 +83,9 @@ class DataFlow(object): ...@@ -62,6 +83,9 @@ class DataFlow(object):
""" """
raise NotImplementedError() raise NotImplementedError()
def size(self):
return self.__len__()
def reset_state(self): def reset_state(self):
""" """
Reset state of the dataflow. Reset state of the dataflow.
...@@ -102,8 +126,8 @@ class ProxyDataFlow(DataFlow): ...@@ -102,8 +126,8 @@ class ProxyDataFlow(DataFlow):
def reset_state(self): def reset_state(self):
self.ds.reset_state() self.ds.reset_state()
def size(self): def __len__(self):
return self.ds.size() return self.ds.__len__()
def get_data(self): def __iter__(self):
return self.ds.get_data() return self.ds.__iter__()
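One consequence of the `DataFlowMeta` shim above is that a legacy subclass written against the old interface keeps working through the new one; a hypothetical check (not part of the commit):

```python
from tensorpack.dataflow import DataFlow

class LegacyFlow(DataFlow):
    # deliberately implements only the old-style methods
    def get_data(self):
        for k in range(3):
            yield [k]

    def size(self):
        return 3

ds = LegacyFlow()
ds.reset_state()
print(len(ds))         # 3, via the hot-patched __len__
print(list(iter(ds)))  # [[0], [1], [2]], via the hot-patched __iter__
```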
...@@ -36,10 +36,10 @@ class TestDataSpeed(ProxyDataFlow): ...@@ -36,10 +36,10 @@ class TestDataSpeed(ProxyDataFlow):
self.test_size = int(size) self.test_size = int(size)
self.warmup = int(warmup) self.warmup = int(warmup)
def get_data(self): def __iter__(self):
""" Will run testing at the beginning, then produce data normally. """ """ Will run testing at the beginning, then produce data normally. """
self.start_test() self.start_test()
for dp in self.ds.get_data(): for dp in self.ds:
yield dp yield dp
def start_test(self): def start_test(self):
...@@ -47,7 +47,7 @@ class TestDataSpeed(ProxyDataFlow): ...@@ -47,7 +47,7 @@ class TestDataSpeed(ProxyDataFlow):
Start testing with a progress bar. Start testing with a progress bar.
""" """
self.ds.reset_state() self.ds.reset_state()
itr = self.ds.get_data() itr = self.ds.__iter__()
if self.warmup: if self.warmup:
for d in tqdm.trange(self.warmup, **get_tqdm_kwargs()): for d in tqdm.trange(self.warmup, **get_tqdm_kwargs()):
next(itr) next(itr)
...@@ -85,35 +85,35 @@ class BatchData(ProxyDataFlow): ...@@ -85,35 +85,35 @@ class BatchData(ProxyDataFlow):
enough to form a batch, whether or not to also produce the remaining enough to form a batch, whether or not to also produce the remaining
data as a smaller batch. data as a smaller batch.
If set to False, all produced datapoints are guaranteed to have the same batch size. If set to False, all produced datapoints are guaranteed to have the same batch size.
If set to True, `ds.size()` must be accurate. If set to True, `len(ds)` must be accurate.
use_list (bool): if True, each component will contain a list use_list (bool): if True, each component will contain a list
of datapoints instead of a numpy array with an extra dimension. of datapoints instead of a numpy array with an extra dimension.
""" """
super(BatchData, self).__init__(ds) super(BatchData, self).__init__(ds)
if not remainder: if not remainder:
try: try:
assert batch_size <= ds.size() assert batch_size <= len(ds)
except NotImplementedError: except NotImplementedError:
pass pass
self.batch_size = int(batch_size) self.batch_size = int(batch_size)
self.remainder = remainder self.remainder = remainder
self.use_list = use_list self.use_list = use_list
def size(self): def __len__(self):
ds_size = self.ds.size() ds_size = len(self.ds)
div = ds_size // self.batch_size div = ds_size // self.batch_size
rem = ds_size % self.batch_size rem = ds_size % self.batch_size
if rem == 0: if rem == 0:
return div return div
return div + int(self.remainder) return div + int(self.remainder)
def get_data(self): def __iter__(self):
""" """
Yields: Yields:
Batched data by stacking each component on an extra 0th dimension. Batched data by stacking each component on an extra 0th dimension.
""" """
holder = [] holder = []
for data in self.ds.get_data(): for data in self.ds:
holder.append(data) holder.append(data)
if len(holder) == self.batch_size: if len(holder) == self.batch_size:
yield BatchData._aggregate_batch(holder, self.use_list) yield BatchData._aggregate_batch(holder, self.use_list)
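As a quick sanity check of the `__len__` arithmetic above (the data below is made up): a flow of 10 datapoints batched by 3 reports 3 full batches, or 4 when `remainder=True`:

```python
from tensorpack.dataflow import BatchData, DataFromList

ds = DataFromList([[i] for i in range(10)], shuffle=False)
print(len(BatchData(ds, 3)))                  # 10 // 3 == 3
print(len(BatchData(ds, 3, remainder=True)))  # 3 + 1 == 4
```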
...@@ -184,9 +184,9 @@ class BatchDataByShape(BatchData): ...@@ -184,9 +184,9 @@ class BatchDataByShape(BatchData):
super(BatchDataByShape, self).reset_state() super(BatchDataByShape, self).reset_state()
self.holder = defaultdict(list) self.holder = defaultdict(list)
def get_data(self): def __iter__(self):
with self._guard: with self._guard:
for dp in self.ds.get_data(): for dp in self.ds:
shp = dp[self.idx].shape shp = dp[self.idx].shape
holder = self.holder[shp] holder = self.holder[shp]
holder.append(dp) holder.append(dp)
...@@ -204,7 +204,7 @@ class FixedSizeData(ProxyDataFlow): ...@@ -204,7 +204,7 @@ class FixedSizeData(ProxyDataFlow):
ds (DataFlow): input dataflow ds (DataFlow): input dataflow
size (int): size size (int): size
keep_state (bool): keep the iterator state of ``ds`` keep_state (bool): keep the iterator state of ``ds``
between calls to :meth:`get_data()`, so that the between calls to :meth:`__iter__()`, so that the
next call will continue the previous iteration over ``ds``, next call will continue the previous iteration over ``ds``,
instead of reinitializing an iterator. instead of reinitializing an iterator.
...@@ -223,23 +223,23 @@ class FixedSizeData(ProxyDataFlow): ...@@ -223,23 +223,23 @@ class FixedSizeData(ProxyDataFlow):
self._guard = DataFlowReentrantGuard() self._guard = DataFlowReentrantGuard()
self._keep = keep_state self._keep = keep_state
def size(self): def __len__(self):
return self._size return self._size
def reset_state(self): def reset_state(self):
super(FixedSizeData, self).reset_state() super(FixedSizeData, self).reset_state()
self.itr = self.ds.get_data() self.itr = self.ds.__iter__()
def get_data(self): def __iter__(self):
with self._guard: with self._guard:
if self.itr is None: if self.itr is None:
self.itr = self.ds.get_data() self.itr = self.ds.__iter__()
cnt = 0 cnt = 0
while True: while True:
try: try:
dp = next(self.itr) dp = next(self.itr)
except StopIteration: except StopIteration:
self.itr = self.ds.get_data() self.itr = self.ds.__iter__()
dp = next(self.itr) dp = next(self.itr)
cnt += 1 cnt += 1
...@@ -257,7 +257,7 @@ class MapData(ProxyDataFlow): ...@@ -257,7 +257,7 @@ class MapData(ProxyDataFlow):
Note: Note:
1. Please make sure func doesn't modify the components 1. Please make sure func doesn't modify the components
unless you're certain it's safe. unless you're certain it's safe.
2. If you discard some datapoints, ``ds.size()`` will be incorrect. 2. If you discard some datapoints, ``len(ds)`` will be incorrect.
""" """
def __init__(self, ds, func): def __init__(self, ds, func):
...@@ -270,8 +270,8 @@ class MapData(ProxyDataFlow): ...@@ -270,8 +270,8 @@ class MapData(ProxyDataFlow):
super(MapData, self).__init__(ds) super(MapData, self).__init__(ds)
self.func = func self.func = func
def get_data(self): def __iter__(self):
for dp in self.ds.get_data(): for dp in self.ds:
ret = self.func(copy(dp)) # shallow copy the list ret = self.func(copy(dp)) # shallow copy the list
if ret is not None: if ret is not None:
yield ret yield ret
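A small sketch of the caveat above (made-up data): returning `None` from the mapping function drops the datapoint, so the reported length no longer matches what is actually produced:

```python
from tensorpack.dataflow import DataFromList, MapData

ds = DataFromList([[i] for i in range(6)], shuffle=False)
ds = MapData(ds, lambda dp: dp if dp[0] % 2 == 0 else None)  # keep even entries only
ds.reset_state()
print(len(ds))    # still 6, inherited from the wrapped flow
print(list(ds))   # [[0], [2], [4]] -- only three datapoints are produced
```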
...@@ -285,7 +285,7 @@ class MapDataComponent(MapData): ...@@ -285,7 +285,7 @@ class MapDataComponent(MapData):
1. This dataflow itself doesn't modify the datapoints. 1. This dataflow itself doesn't modify the datapoints.
But please make sure func doesn't modify the components But please make sure func doesn't modify the components
unless you're certain it's safe. unless you're certain it's safe.
2. If you discard some datapoints, ``ds.size()`` will be incorrect. 2. If you discard some datapoints, ``len(ds)`` will be incorrect.
""" """
def __init__(self, ds, func, index=0): def __init__(self, ds, func, index=0):
""" """
...@@ -324,23 +324,23 @@ class RepeatedData(ProxyDataFlow): ...@@ -324,23 +324,23 @@ class RepeatedData(ProxyDataFlow):
self.nr = nr self.nr = nr
super(RepeatedData, self).__init__(ds) super(RepeatedData, self).__init__(ds)
def size(self): def __len__(self):
""" """
Raises: Raises:
:class:`NotImplementedError` when nr == -1. :class:`NotImplementedError` when nr == -1.
""" """
if self.nr == -1: if self.nr == -1:
raise NotImplementedError("size() is unavailable for infinite dataflow") raise NotImplementedError("__len__() is unavailable for infinite dataflow")
return self.ds.size() * self.nr return len(self.ds) * self.nr
def get_data(self): def __iter__(self):
if self.nr == -1: if self.nr == -1:
while True: while True:
for dp in self.ds.get_data(): for dp in self.ds:
yield dp yield dp
else: else:
for _ in range(self.nr): for _ in range(self.nr):
for dp in self.ds.get_data(): for dp in self.ds:
yield dp yield dp
...@@ -360,11 +360,11 @@ class RepeatedDataPoint(ProxyDataFlow): ...@@ -360,11 +360,11 @@ class RepeatedDataPoint(ProxyDataFlow):
assert self.nr >= 1, self.nr assert self.nr >= 1, self.nr
super(RepeatedDataPoint, self).__init__(ds) super(RepeatedDataPoint, self).__init__(ds)
def size(self): def __len__(self):
return self.ds.size() * self.nr return len(self.ds) * self.nr
def get_data(self): def __iter__(self):
for dp in self.ds.get_data(): for dp in self.ds:
for _ in range(self.nr): for _ in range(self.nr):
yield dp yield dp
...@@ -397,8 +397,8 @@ class RandomChooseData(RNGDataFlow): ...@@ -397,8 +397,8 @@ class RandomChooseData(RNGDataFlow):
else: else:
d.reset_state() d.reset_state()
def get_data(self): def __iter__(self):
itrs = [v[0].get_data() for v in self.df_lists] itrs = [v[0].__iter__() for v in self.df_lists]
probs = np.array([v[1] for v in self.df_lists]) probs = np.array([v[1] for v in self.df_lists])
try: try:
while True: while True:
...@@ -410,35 +410,35 @@ class RandomChooseData(RNGDataFlow): ...@@ -410,35 +410,35 @@ class RandomChooseData(RNGDataFlow):
class RandomMixData(RNGDataFlow): class RandomMixData(RNGDataFlow):
""" """
Perfectly mix datapoints from several DataFlow using their :meth:`size()`. Perfectly mix datapoints from several DataFlow using their :meth:`__len__()`.
Will stop when all DataFlow are exhausted. Will stop when all DataFlow are exhausted.
""" """
def __init__(self, df_lists): def __init__(self, df_lists):
""" """
Args: Args:
df_lists (list): a list of DataFlow. df_lists (list): a list of DataFlow.
All DataFlow must implement ``size()``. All DataFlow must implement ``__len__()``.
""" """
super(RandomMixData, self).__init__() super(RandomMixData, self).__init__()
self.df_lists = df_lists self.df_lists = df_lists
self.sizes = [k.size() for k in self.df_lists] self.sizes = [len(k) for k in self.df_lists]
def reset_state(self): def reset_state(self):
super(RandomMixData, self).reset_state() super(RandomMixData, self).reset_state()
for d in self.df_lists: for d in self.df_lists:
d.reset_state() d.reset_state()
def size(self): def __len__(self):
return sum(self.sizes) return sum(self.sizes)
def get_data(self): def __iter__(self):
sums = np.cumsum(self.sizes) sums = np.cumsum(self.sizes)
idxs = np.arange(self.size()) idxs = np.arange(self.__len__())
self.rng.shuffle(idxs) self.rng.shuffle(idxs)
idxs = np.array(list(map( idxs = np.array(list(map(
lambda x: np.searchsorted(sums, x, 'right'), idxs))) lambda x: np.searchsorted(sums, x, 'right'), idxs)))
itrs = [k.get_data() for k in self.df_lists] itrs = [k.__iter__() for k in self.df_lists]
assert idxs.max() == len(itrs) - 1, "{}!={}".format(idxs.max(), len(itrs) - 1) assert idxs.max() == len(itrs) - 1, "{}!={}".format(idxs.max(), len(itrs) - 1)
for k in idxs: for k in idxs:
yield next(itrs[k]) yield next(itrs[k])
...@@ -463,12 +463,12 @@ class ConcatData(DataFlow): ...@@ -463,12 +463,12 @@ class ConcatData(DataFlow):
for d in self.df_lists: for d in self.df_lists:
d.reset_state() d.reset_state()
def size(self): def __len__(self):
return sum([x.size() for x in self.df_lists]) return sum([len(x) for x in self.df_lists])
def get_data(self): def __iter__(self):
for d in self.df_lists: for d in self.df_lists:
for dp in d.get_data(): for dp in d.__iter__():
yield dp yield dp
...@@ -492,15 +492,15 @@ class JoinData(DataFlow): ...@@ -492,15 +492,15 @@ class JoinData(DataFlow):
When these dataflows have different sizes, JoinData will stop when any When these dataflows have different sizes, JoinData will stop when any
of them is exhausted. of them is exhausted.
The list could contain the same DataFlow instance more than once, The list could contain the same DataFlow instance more than once,
but note that `get_data` will then also be called many times. but note that `__iter__` will then also be called many times.
""" """
self.df_lists = df_lists self.df_lists = df_lists
try: try:
self._size = self.df_lists[0].size() self._size = len(self.df_lists[0])
for d in self.df_lists: for d in self.df_lists:
assert d.size() == self._size, \ assert len(d) == self._size, \
"All DataFlow must have the same size! {} != {}".format(d.size(), self._size) "All DataFlow must have the same size! {} != {}".format(len(d), self._size)
except Exception: except Exception:
logger.info("[JoinData] Size check failed for the list of dataflow to be joined!") logger.info("[JoinData] Size check failed for the list of dataflow to be joined!")
...@@ -508,14 +508,14 @@ class JoinData(DataFlow): ...@@ -508,14 +508,14 @@ class JoinData(DataFlow):
for d in set(self.df_lists): for d in set(self.df_lists):
d.reset_state() d.reset_state()
def size(self): def __len__(self):
""" """
Return the minimum size among all. Return the minimum size among all.
""" """
return min([k.size() for k in self.df_lists]) return min([len(k) for k in self.df_lists])
def get_data(self): def __iter__(self):
itrs = [k.get_data() for k in self.df_lists] itrs = [k.__iter__() for k in self.df_lists]
try: try:
while True: while True:
dp = [] dp = []
...@@ -574,7 +574,7 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow): ...@@ -574,7 +574,7 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
def reset_state(self): def reset_state(self):
ProxyDataFlow.reset_state(self) ProxyDataFlow.reset_state(self)
RNGDataFlow.reset_state(self) RNGDataFlow.reset_state(self)
self.ds_itr = RepeatedData(self.ds, -1).get_data() self.ds_itr = RepeatedData(self.ds, -1).__iter__()
self.current_cnt = 0 self.current_cnt = 0
def _add_data(self): def _add_data(self):
...@@ -582,13 +582,13 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow): ...@@ -582,13 +582,13 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
for _ in range(self.nr_reuse): for _ in range(self.nr_reuse):
self.q.append(dp) self.q.append(dp)
def get_data(self): def __iter__(self):
with self._guard: with self._guard:
# fill queue # fill queue
while self.q.maxlen > len(self.q): while self.q.maxlen > len(self.q):
self._add_data() self._add_data()
sz = self.size() sz = self.__len__()
cnt = 0 cnt = 0
while True: while True:
self.rng.shuffle(self.q) self.rng.shuffle(self.q)
...@@ -626,7 +626,7 @@ class CacheData(ProxyDataFlow): ...@@ -626,7 +626,7 @@ class CacheData(ProxyDataFlow):
self.rng = get_rng(self) self.rng = get_rng(self)
self.buffer = [] self.buffer = []
def get_data(self): def __iter__(self):
with self._guard: with self._guard:
if len(self.buffer): if len(self.buffer):
if self.shuffle: if self.shuffle:
...@@ -634,7 +634,7 @@ class CacheData(ProxyDataFlow): ...@@ -634,7 +634,7 @@ class CacheData(ProxyDataFlow):
for dp in self.buffer: for dp in self.buffer:
yield dp yield dp
else: else:
for dp in self.ds.get_data(): for dp in self.ds:
yield dp yield dp
self.buffer.append(dp) self.buffer.append(dp)
...@@ -648,12 +648,12 @@ class PrintData(ProxyDataFlow): ...@@ -648,12 +648,12 @@ class PrintData(ProxyDataFlow):
.. code-block:: python .. code-block:: python
def get_data(): def __iter__():
ds = SomeDataSource('path/to/lmdb') ds = SomeDataSource('path/to/lmdb')
ds = SomeInscrutableMappings(ds) ds = SomeInscrutableMappings(ds)
ds = PrintData(ds, num=2, max_list=2) ds = PrintData(ds, num=2, max_list=2)
return ds return ds
ds = get_data() ds = __iter__()
The output looks like: The output looks like:
...@@ -766,8 +766,8 @@ class PrintData(ProxyDataFlow): ...@@ -766,8 +766,8 @@ class PrintData(ProxyDataFlow):
msg.append(self._analyze_input_data(entry, k, max_depth=self.max_depth, max_list=self.max_list)) msg.append(self._analyze_input_data(entry, k, max_depth=self.max_depth, max_list=self.max_list))
return u'\n'.join(msg) return u'\n'.join(msg)
def get_data(self): def __iter__(self):
for dp in self.ds.get_data(): for dp in self.ds:
# it is important to place this here! otherwise it mixes the output of multiple PrintData # it is important to place this here! otherwise it mixes the output of multiple PrintData
if self.cnt == 0: if self.cnt == 0:
label = ' (%s)' % self.name if self.name is not None else "" label = ' (%s)' % self.name if self.name is not None else ""
......
...@@ -78,10 +78,10 @@ class BSDS500(RNGDataFlow): ...@@ -78,10 +78,10 @@ class BSDS500(RNGDataFlow):
self.data[idx] = im self.data[idx] = im
self.label[idx] = gt self.label[idx] = gt
def size(self): def __len__(self):
return self.data.shape[0] return self.data.shape[0]
def get_data(self): def __iter__(self):
idxs = np.arange(self.data.shape[0]) idxs = np.arange(self.data.shape[0])
if self.shuffle: if self.shuffle:
self.rng.shuffle(idxs) self.rng.shuffle(idxs)
...@@ -99,6 +99,6 @@ except ImportError: ...@@ -99,6 +99,6 @@ except ImportError:
if __name__ == '__main__': if __name__ == '__main__':
a = BSDS500('val') a = BSDS500('val')
a.reset_state() a.reset_state()
for k in a.get_data(): for k in a:
cv2.imshow("haha", k[1].astype('uint8') * 255) cv2.imshow("haha", k[1].astype('uint8') * 255)
cv2.waitKey(1000) cv2.waitKey(1000)
...@@ -106,10 +106,10 @@ class CifarBase(RNGDataFlow): ...@@ -106,10 +106,10 @@ class CifarBase(RNGDataFlow):
self.dir = dir self.dir = dir
self.shuffle = shuffle self.shuffle = shuffle
def size(self): def __len__(self):
return 50000 if self.train_or_test == 'train' else 10000 return 50000 if self.train_or_test == 'train' else 10000
def get_data(self): def __iter__(self):
idxs = np.arange(len(self.data)) idxs = np.arange(len(self.data))
if self.shuffle: if self.shuffle:
self.rng.shuffle(idxs) self.rng.shuffle(idxs)
...@@ -171,7 +171,7 @@ if __name__ == '__main__': ...@@ -171,7 +171,7 @@ if __name__ == '__main__':
import cv2 import cv2
ds.reset_state() ds.reset_state()
for i, dp in enumerate(ds.get_data()): for i, dp in enumerate(ds):
if i == 100: if i == 100:
break break
img = dp[0] img = dp[0]
......
...@@ -163,10 +163,10 @@ class ILSVRC12Files(RNGDataFlow): ...@@ -163,10 +163,10 @@ class ILSVRC12Files(RNGDataFlow):
fname = os.path.join(self.full_dir, fname) fname = os.path.join(self.full_dir, fname)
assert os.path.isfile(fname), fname assert os.path.isfile(fname), fname
def size(self): def __len__(self):
return len(self.imglist) return len(self.imglist)
def get_data(self): def __iter__(self):
idxs = np.arange(len(self.imglist)) idxs = np.arange(len(self.imglist))
if self.shuffle: if self.shuffle:
self.rng.shuffle(idxs) self.rng.shuffle(idxs)
...@@ -251,8 +251,8 @@ class ILSVRC12(ILSVRC12Files): ...@@ -251,8 +251,8 @@ class ILSVRC12(ILSVRC12Files):
There are some CMYK / png images, but cv2 seems robust to them. There are some CMYK / png images, but cv2 seems robust to them.
https://github.com/tensorflow/models/blob/c0cd713f59cfe44fa049b3120c417cc4079c17e3/research/inception/inception/data/build_imagenet_data.py#L264-L300 https://github.com/tensorflow/models/blob/c0cd713f59cfe44fa049b3120c417cc4079c17e3/research/inception/inception/data/build_imagenet_data.py#L264-L300
""" """
def get_data(self): def __iter__(self):
for fname, label in super(ILSVRC12, self).get_data(): for fname, label in super(ILSVRC12, self).__iter__():
im = cv2.imread(fname, cv2.IMREAD_COLOR) im = cv2.imread(fname, cv2.IMREAD_COLOR)
assert im is not None, fname assert im is not None, fname
yield [im, label] yield [im, label]
...@@ -299,7 +299,7 @@ if __name__ == '__main__': ...@@ -299,7 +299,7 @@ if __name__ == '__main__':
ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', shuffle=False) ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', shuffle=False)
ds.reset_state() ds.reset_state()
for k in ds.get_data(): for k in ds:
from IPython import embed from IPython import embed
embed() embed()
break break
...@@ -99,11 +99,11 @@ class Mnist(RNGDataFlow): ...@@ -99,11 +99,11 @@ class Mnist(RNGDataFlow):
't10k-images-idx3-ubyte.gz', 't10k-images-idx3-ubyte.gz',
't10k-labels-idx1-ubyte.gz') 't10k-labels-idx1-ubyte.gz')
def size(self): def __len__(self):
return self.images.shape[0] return self.images.shape[0]
def get_data(self): def __iter__(self):
idxs = list(range(self.size())) idxs = list(range(self.__len__()))
if self.shuffle: if self.shuffle:
self.rng.shuffle(idxs) self.rng.shuffle(idxs)
for k in idxs: for k in idxs:
...@@ -133,7 +133,7 @@ class FashionMnist(Mnist): ...@@ -133,7 +133,7 @@ class FashionMnist(Mnist):
if __name__ == '__main__': if __name__ == '__main__':
ds = Mnist('train') ds = Mnist('train')
ds.reset_state() ds.reset_state()
for (img, label) in ds.get_data(): for (img, label) in ds:
from IPython import embed from IPython import embed
embed() embed()
break break
...@@ -49,10 +49,10 @@ class SVHNDigit(RNGDataFlow): ...@@ -49,10 +49,10 @@ class SVHNDigit(RNGDataFlow):
self.Y[self.Y == 10] = 0 self.Y[self.Y == 10] = 0
SVHNDigit._Cache[name] = (self.X, self.Y) SVHNDigit._Cache[name] = (self.X, self.Y)
def size(self): def __len__(self):
return self.X.shape[0] return self.X.shape[0]
def get_data(self): def __iter__(self):
n = self.X.shape[0] n = self.X.shape[0]
idxs = np.arange(n) idxs = np.arange(n)
if self.shuffle: if self.shuffle:
......
...@@ -43,7 +43,7 @@ def dump_dataflow_to_process_queue(df, size, nr_consumer): ...@@ -43,7 +43,7 @@ def dump_dataflow_to_process_queue(df, size, nr_consumer):
def run(self): def run(self):
self.df.reset_state() self.df.reset_state()
try: try:
for idx, dp in enumerate(self.df.get_data()): for idx, dp in enumerate(self.df):
self.q.put((idx, dp)) self.q.put((idx, dp))
finally: finally:
for _ in range(nr_consumer): for _ in range(nr_consumer):
......
...@@ -50,10 +50,10 @@ class HDF5Data(RNGDataFlow): ...@@ -50,10 +50,10 @@ class HDF5Data(RNGDataFlow):
self._size = lens[0] self._size = lens[0]
self.shuffle = shuffle self.shuffle = shuffle
def size(self): def __len__(self):
return self._size return self._size
def get_data(self): def __iter__(self):
idxs = list(range(self._size)) idxs = list(range(self._size))
if self.shuffle: if self.shuffle:
self.rng.shuffle(idxs) self.rng.shuffle(idxs)
...@@ -132,10 +132,10 @@ class LMDBData(RNGDataFlow): ...@@ -132,10 +132,10 @@ class LMDBData(RNGDataFlow):
super(LMDBData, self).reset_state() super(LMDBData, self).reset_state()
self._open_lmdb() self._open_lmdb()
def size(self): def __len__(self):
return self._size return self._size
def get_data(self): def __iter__(self):
with self._guard: with self._guard:
if not self._shuffle: if not self._shuffle:
c = self._txn.cursor() c = self._txn.cursor()
...@@ -231,11 +231,11 @@ class SVMLightData(RNGDataFlow): ...@@ -231,11 +231,11 @@ class SVMLightData(RNGDataFlow):
self.X = np.asarray(self.X.todense()) self.X = np.asarray(self.X.todense())
self.shuffle = shuffle self.shuffle = shuffle
def size(self): def __len__(self):
return len(self.y) return len(self.y)
def get_data(self): def __iter__(self):
idxs = np.arange(self.size()) idxs = np.arange(self.__len__())
if self.shuffle: if self.shuffle:
self.rng.shuffle(idxs) self.rng.shuffle(idxs)
for id in idxs: for id in idxs:
...@@ -248,12 +248,12 @@ class TFRecordData(DataFlow): ...@@ -248,12 +248,12 @@ class TFRecordData(DataFlow):
self._path = path self._path = path
self._size = int(size) self._size = int(size)
def size(self): def __len__(self):
if self._size: if self._size:
return self._size return self._size
return super(TFRecordData, self).size() return len(super(TFRecordData, self))
def get_data(self): def __iter__(self):
gen = tf.python_io.tf_record_iterator(self._path) gen = tf.python_io.tf_record_iterator(self._path)
for dp in gen: for dp in gen:
yield loads(dp) yield loads(dp)
......
...@@ -63,10 +63,10 @@ class ImageFromFile(RNGDataFlow): ...@@ -63,10 +63,10 @@ class ImageFromFile(RNGDataFlow):
self.resize = resize self.resize = resize
self.shuffle = shuffle self.shuffle = shuffle
def size(self): def __len__(self):
return len(self.files) return len(self.files)
def get_data(self): def __iter__(self):
if self.shuffle: if self.shuffle:
self.rng.shuffle(self.files) self.rng.shuffle(self.files)
for f in self.files: for f in self.files:
......
...@@ -159,7 +159,7 @@ class MultiProcessPrefetchData(ProxyDataFlow): ...@@ -159,7 +159,7 @@ class MultiProcessPrefetchData(ProxyDataFlow):
# reset all ds so each process will produce different data # reset all ds so each process will produce different data
self.ds.reset_state() self.ds.reset_state()
while True: while True:
for dp in self.ds.get_data(): for dp in self.ds:
self.queue.put(dp) self.queue.put(dp)
def __init__(self, ds, nr_prefetch, nr_proc): def __init__(self, ds, nr_prefetch, nr_proc):
...@@ -175,7 +175,7 @@ However, windows requires more strict picklability on processes, which may \ ...@@ -175,7 +175,7 @@ However, windows requires more strict picklability on processes, which may \
lead to failure in some of the code.") lead to failure in some of the code.")
super(MultiProcessPrefetchData, self).__init__(ds) super(MultiProcessPrefetchData, self).__init__(ds)
try: try:
self._size = ds.size() self._size = len(ds)
except NotImplementedError: except NotImplementedError:
self._size = -1 self._size = -1
self.nr_proc = nr_proc self.nr_proc = nr_proc
...@@ -191,7 +191,7 @@ lead of failure on some of the code.") ...@@ -191,7 +191,7 @@ lead of failure on some of the code.")
ensure_proc_terminate(self.procs) ensure_proc_terminate(self.procs)
start_proc_mask_signal(self.procs) start_proc_mask_signal(self.procs)
def get_data(self): def __iter__(self):
for k in itertools.count(): for k in itertools.count():
if self._size > 0 and k >= self._size: if self._size > 0 and k >= self._size:
break break
...@@ -264,7 +264,7 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow): ...@@ -264,7 +264,7 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
socket.connect(self.conn_name) socket.connect(self.conn_name)
try: try:
while True: while True:
for dp in self.ds.get_data(): for dp in self.ds:
socket.send(dumps(dp), copy=False) socket.send(dumps(dp), copy=False)
# sigint could still propagate here, e.g. when nested # sigint could still propagate here, e.g. when nested
except KeyboardInterrupt: except KeyboardInterrupt:
...@@ -291,17 +291,17 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow): ...@@ -291,17 +291,17 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
logger.info("[PrefetchDataZMQ] Will fork a dataflow more than one times. " logger.info("[PrefetchDataZMQ] Will fork a dataflow more than one times. "
"This assumes the datapoints are i.i.d.") "This assumes the datapoints are i.i.d.")
try: try:
self._size = ds.size() self._size = ds.__len__()
except NotImplementedError: except NotImplementedError:
self._size = -1 self._size = -1
def _recv(self): def _recv(self):
return loads(self.socket.recv(copy=False)) return loads(self.socket.recv(copy=False))
def size(self): def __len__(self):
return self.ds.size() return self.ds.__len__()
def get_data(self): def __iter__(self):
with self._guard, _zmq_catch_error('PrefetchDataZMQ'): with self._guard, _zmq_catch_error('PrefetchDataZMQ'):
for k in itertools.count(): for k in itertools.count():
if self._size > 0 and k >= self._size: if self._size > 0 and k >= self._size:
...@@ -360,7 +360,7 @@ class MultiThreadPrefetchData(DataFlow): ...@@ -360,7 +360,7 @@ class MultiThreadPrefetchData(DataFlow):
def run(self): def run(self):
self.df.reset_state() self.df.reset_state()
try: try:
for dp in self.df.get_data(): for dp in self.df:
if self.stopped(): if self.stopped():
return return
self.queue_put_stoppable(self.queue, dp) self.queue_put_stoppable(self.queue, dp)
...@@ -391,10 +391,10 @@ class MultiThreadPrefetchData(DataFlow): ...@@ -391,10 +391,10 @@ class MultiThreadPrefetchData(DataFlow):
th.df.reset_state() th.df.reset_state()
th.start() th.start()
def size(self): def __len__(self):
return self.threads[0].size() return self.threads[0].__len__()
def get_data(self): def __iter__(self):
while True: while True:
yield self.queue.get() yield self.queue.get()
...@@ -419,8 +419,8 @@ class PlasmaPutData(ProxyDataFlow): ...@@ -419,8 +419,8 @@ class PlasmaPutData(ProxyDataFlow):
super(PlasmaPutData, self).reset_state() super(PlasmaPutData, self).reset_state()
self.client = plasma.connect(self._socket, "", 0) self.client = plasma.connect(self._socket, "", 0)
def get_data(self): def __iter__(self):
for dp in self.ds.get_data(): for dp in self.ds:
oid = self.client.put(dp) oid = self.client.put(dp)
yield [oid.binary()] yield [oid.binary()]
...@@ -440,8 +440,8 @@ class PlasmaGetData(ProxyDataFlow): ...@@ -440,8 +440,8 @@ class PlasmaGetData(ProxyDataFlow):
super(PlasmaGetData, self).reset_state() super(PlasmaGetData, self).reset_state()
self.client = plasma.connect(self._socket, "", 0) self.client = plasma.connect(self._socket, "", 0)
def get_data(self): def __iter__(self):
for dp in self.ds.get_data(): for dp in self.ds:
oid = plasma.ObjectID(dp[0]) oid = plasma.ObjectID(dp[0])
dp = self.client.get(oid) dp = self.client.get(oid)
yield dp yield dp
......
...@@ -61,7 +61,7 @@ class _ParallelMapData(ProxyDataFlow): ...@@ -61,7 +61,7 @@ class _ParallelMapData(ProxyDataFlow):
if ret is not None: if ret is not None:
yield ret yield ret
self._iter = self.ds.get_data() # refresh self._iter = self.ds.__iter__() # refresh
for _ in range(self._buffer_size): for _ in range(self._buffer_size):
self._send(next(self._iter)) self._send(next(self._iter))
ret = self._recv() ret = self._recv()
...@@ -73,7 +73,7 @@ class _ParallelMapData(ProxyDataFlow): ...@@ -73,7 +73,7 @@ class _ParallelMapData(ProxyDataFlow):
for dp in self._iter: for dp in self._iter:
self._send(dp) self._send(dp)
yield self._recv_filter_none() yield self._recv_filter_none()
self._iter = self.ds.get_data() # refresh self._iter = self.ds.__iter__() # refresh
# first clear the buffer, then fill # first clear the buffer, then fill
for k in range(self._buffer_size): for k in range(self._buffer_size):
...@@ -100,11 +100,11 @@ class MultiThreadMapData(_ParallelMapData): ...@@ -100,11 +100,11 @@ class MultiThreadMapData(_ParallelMapData):
2. Threads run in parallel and can take different time to run the 2. Threads run in parallel and can take different time to run the
mapping function. Therefore the order of datapoints won't be mapping function. Therefore the order of datapoints won't be
preserved, and datapoints from one pass of `df.get_data()` might get preserved, and datapoints from one pass of `df.__iter__()` might get
mixed with datapoints from the next pass. mixed with datapoints from the next pass.
You can use **strict mode**, where `MultiThreadMapData.get_data()` You can use **strict mode**, where `MultiThreadMapData.__iter__()`
is guaranteed to produce the exact set which `df.get_data()` is guaranteed to produce the exact set which `df.__iter__()`
produces. Although the order of data still isn't preserved. produces. Although the order of data still isn't preserved.
""" """
class _Worker(StoppableThread): class _Worker(StoppableThread):
...@@ -165,7 +165,7 @@ class MultiThreadMapData(_ParallelMapData): ...@@ -165,7 +165,7 @@ class MultiThreadMapData(_ParallelMapData):
for t in self._threads: for t in self._threads:
t.start() t.start()
self._iter = self.ds.get_data() self._iter = self.ds.__iter__()
self._guard = DataFlowReentrantGuard() self._guard = DataFlowReentrantGuard()
# Call once at the beginning, to ensure inq+outq has a total of buffer_size elements # Call once at the beginning, to ensure inq+outq has a total of buffer_size elements
...@@ -177,7 +177,7 @@ class MultiThreadMapData(_ParallelMapData): ...@@ -177,7 +177,7 @@ class MultiThreadMapData(_ParallelMapData):
def _send(self, dp): def _send(self, dp):
self._in_queue.put(dp) self._in_queue.put(dp)
def get_data(self): def __iter__(self):
with self._guard: with self._guard:
if self._strict: if self._strict:
for dp in self.get_data_strict(): for dp in self.get_data_strict():
...@@ -208,11 +208,11 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow): ...@@ -208,11 +208,11 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
Note: Note:
1. Processes run in parallel and can take different time to run the 1. Processes run in parallel and can take different time to run the
mapping function. Therefore the order of datapoints won't be mapping function. Therefore the order of datapoints won't be
preserved, and datapoints from one pass of `df.get_data()` might get preserved, and datapoints from one pass of `df.__iter__()` might get
mixed with datapoints from the next pass. mixed with datapoints from the next pass.
You can use **strict mode**, where `MultiProcessMapData.get_data()` You can use **strict mode**, where `MultiProcessMapData.__iter__()`
is guaranteed to produce the exact set which `df.get_data()` is guaranteed to produce the exact set which `df.__iter__()`
produces. Although the order of data still isn't preserved. produces. Although the order of data still isn't preserved.
""" """
class _Worker(mp.Process): class _Worker(mp.Process):
...@@ -267,7 +267,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow): ...@@ -267,7 +267,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
for k in range(self.nr_proc)] for k in range(self.nr_proc)]
self.ds.reset_state() self.ds.reset_state()
self._iter = self.ds.get_data() self._iter = self.ds.__iter__()
self._start_processes() self._start_processes()
self._fill_buffer() # pre-fill the buffer self._fill_buffer() # pre-fill the buffer
...@@ -284,7 +284,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow): ...@@ -284,7 +284,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
dp = loads(msg[1]) dp = loads(msg[1])
return dp return dp
def get_data(self): def __iter__(self):
with self._guard, _zmq_catch_error('MultiProcessMapData'): with self._guard, _zmq_catch_error('MultiProcessMapData'):
if self._strict: if self._strict:
for dp in self.get_data_strict(): for dp in self.get_data_strict():
...@@ -362,13 +362,13 @@ class MultiProcessMapDataComponentSharedArray(DataFlow): ...@@ -362,13 +362,13 @@ class MultiProcessMapDataComponentSharedArray(DataFlow):
arr = mp.RawArray(ctype, int(np.prod(self.output_shape))) arr = mp.RawArray(ctype, int(np.prod(self.output_shape)))
return arr return arr
def size(self): def __len__(self):
return self.ds.size() return len(self.ds)
def reset_state(self): def reset_state(self):
self.ds.reset_state() self.ds.reset_state()
def get_data(self): def __iter__(self):
ds_itr = _repeat_iter(self.ds.get_data) ds_itr = _repeat_iter(self.ds.get_data)
with self._guard: with self._guard:
while True: while True:
...@@ -392,16 +392,16 @@ if __name__ == '__main__': ...@@ -392,16 +392,16 @@ if __name__ == '__main__':
def __init__(self, size): def __init__(self, size):
self._size = size self._size = size
def get_data(self): def __iter__(self):
for k in range(self._size): for k in range(self._size):
yield [k] yield [k]
def size(self): def __len__(self):
return self._size return self._size
ds = Zero(300) ds = Zero(300)
ds = MultiProcessMapData(ds, 3, lambda x: [x[0] + 1], strict=True) ds = MultiProcessMapData(ds, 3, lambda x: [x[0] + 1], strict=True)
ds.reset_state() ds.reset_state()
for k in ds.get_data(): for k in ds:
print("Bang!", k) print("Bang!", k)
print("END!") print("END!")
...@@ -34,10 +34,10 @@ class FakeData(RNGDataFlow): ...@@ -34,10 +34,10 @@ class FakeData(RNGDataFlow):
assert len(self.dtype) == len(self.shapes) assert len(self.dtype) == len(self.shapes)
assert len(self.domain) == len(self.domain) assert len(self.domain) == len(self.domain)
def size(self): def __len__(self):
return self._size return self._size
def get_data(self): def __iter__(self):
if self.random: if self.random:
for _ in range(self._size): for _ in range(self._size):
val = [] val = []
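With the rename, `FakeData` behaves like any other Python container. A small sketch of the new protocol in action; the shapes and size below are arbitrary:

```python
from tensorpack.dataflow import FakeData

df = FakeData([[28, 28, 3], [8]], size=100)  # two components per datapoint
df.reset_state()

print(len(df))                # 100, via __len__ (was size())
dp = next(iter(df))           # via __iter__ (was get_data())
print(len(dp), dp[0].shape)   # 2 (28, 28, 3)
```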
...@@ -63,7 +63,7 @@ class DataFromQueue(DataFlow): ...@@ -63,7 +63,7 @@ class DataFromQueue(DataFlow):
""" """
self.queue = queue self.queue = queue
def get_data(self): def __iter__(self):
while True: while True:
yield self.queue.get() yield self.queue.get()
...@@ -81,10 +81,10 @@ class DataFromList(RNGDataFlow): ...@@ -81,10 +81,10 @@ class DataFromList(RNGDataFlow):
self.lst = lst self.lst = lst
self.shuffle = shuffle self.shuffle = shuffle
def size(self): def __len__(self):
return len(self.lst) return len(self.lst)
def get_data(self): def __iter__(self):
if not self.shuffle: if not self.shuffle:
for k in self.lst: for k in self.lst:
yield k yield k
...@@ -112,7 +112,7 @@ class DataFromGenerator(DataFlow): ...@@ -112,7 +112,7 @@ class DataFromGenerator(DataFlow):
if size is not None: if size is not None:
log_deprecated("DataFromGenerator(size=)", "It doesn't make much sense.", "2018-03-31") log_deprecated("DataFromGenerator(size=)", "It doesn't make much sense.", "2018-03-31")
def get_data(self): def __iter__(self):
# yield from # yield from
for dp in self._gen(): for dp in self._gen():
yield dp yield dp
...@@ -128,9 +128,9 @@ class DataFromIterable(DataFlow): ...@@ -128,9 +128,9 @@ class DataFromIterable(DataFlow):
self._itr = iterable self._itr = iterable
self._len = len(iterable) self._len = len(iterable)
def size(self): def __len__(self):
return self._len return self._len
def get_data(self): def __iter__(self):
for dp in self._itr: for dp in self._itr:
yield dp yield dp
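Both wrappers follow the same container protocol; a short sketch, with the generator and list invented for illustration. A bare generator has no known length, so only `DataFromIterable` gains a useful `__len__`:

```python
from tensorpack.dataflow import DataFromGenerator, DataFromIterable

def squares():
    for k in range(10):
        yield [k * k]              # each datapoint is a one-component list

df1 = DataFromGenerator(squares)   # wraps a callable returning a generator
df1.reset_state()
for dp in df1:
    pass                           # len(df1) would raise NotImplementedError

df2 = DataFromIterable([[0], [1], [2]])
df2.reset_state()
print(len(df2))                    # 3, taken from len(iterable)
```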
...@@ -58,14 +58,14 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False): ...@@ -58,14 +58,14 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False):
q = deque(maxlen=INTERVAL) q = deque(maxlen=INTERVAL)
try: try:
total = df.size() total = len(df)
except NotImplementedError: except NotImplementedError:
total = 0 total = 0
tqdm_args = get_tqdm_kwargs(leave=True, smoothing=0.8) tqdm_args = get_tqdm_kwargs(leave=True, smoothing=0.8)
tqdm_args['bar_format'] = tqdm_args['bar_format'] + "{postfix}" tqdm_args['bar_format'] = tqdm_args['bar_format'] + "{postfix}"
while True: while True:
with tqdm.trange(total, **tqdm_args) as pbar: with tqdm.trange(total, **tqdm_args) as pbar:
for dp in df.get_data(): for dp in df:
start = time.time() start = time.time()
socket.send(dump_fn(dp), copy=False) socket.send(dump_fn(dp), copy=False)
q.append(time.time() - start) q.append(time.time() - start)

...@@ -117,7 +117,7 @@ class RemoteDataZMQ(DataFlow): ...@@ -117,7 +117,7 @@ class RemoteDataZMQ(DataFlow):
else: else:
socket.connect(addr) socket.connect(addr)
def get_data(self): def __iter__(self):
with self._guard: with self._guard:
try: try:
ctx = zmq.Context() ctx = zmq.Context()
......
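The sender and receiver above are typically run as separate processes. A rough sketch of the pairing, assuming the default bind/connect behaviour of these two helpers; the IPC address is a placeholder:

```python
# --- sender process ---
from tensorpack.dataflow import DataFromList, send_dataflow_zmq

df = DataFromList([[k] for k in range(1000)], shuffle=False)
df.reset_state()
send_dataflow_zmq(df, 'ipc:///tmp/ipc-pipe')   # loops forever; len(df) feeds the progress bar

# --- receiver process ---
from tensorpack.dataflow import RemoteDataZMQ

ds = RemoteDataZMQ('ipc:///tmp/ipc-pipe')
ds.reset_state()
for dp in ds:
    pass
```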
...@@ -20,7 +20,7 @@ __all__ = ['LMDBSerializer', 'NumpySerializer', 'TFRecordSerializer', 'HDF5Seria ...@@ -20,7 +20,7 @@ __all__ = ['LMDBSerializer', 'NumpySerializer', 'TFRecordSerializer', 'HDF5Seria
def _reset_df_and_get_size(df): def _reset_df_and_get_size(df):
df.reset_state() df.reset_state()
try: try:
sz = df.size() sz = len(df)
except NotImplementedError: except NotImplementedError:
sz = 0 sz = 0
return sz return sz
...@@ -57,7 +57,7 @@ class LMDBSerializer(): ...@@ -57,7 +57,7 @@ class LMDBSerializer():
# LMDB transaction is not exception-safe! # LMDB transaction is not exception-safe!
# although it has a context manager interface # although it has a context manager interface
txn = db.begin(write=True) txn = db.begin(write=True)
for idx, dp in enumerate(df.get_data()): for idx, dp in enumerate(df):
txn.put(u'{:08}'.format(idx).encode('ascii'), dumps(dp)) txn.put(u'{:08}'.format(idx).encode('ascii'), dumps(dp))
pbar.update() pbar.update()
if (idx + 1) % write_frequency == 0: if (idx + 1) % write_frequency == 0:
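The serializer round trip itself is unchanged by this commit; a brief sketch mirroring the test block at the bottom of this file, with `'out.lmdb'` as a scratch path:

```python
from tensorpack.dataflow import DataFromList, LMDBSerializer

df = DataFromList([[k] for k in range(1000)], shuffle=False)
LMDBSerializer.save(df, 'out.lmdb')        # internally iterates `for dp in df`

df2 = LMDBSerializer.load('out.lmdb', shuffle=False)
df2.reset_state()
print(len(df2))                            # size restored from the LMDB file
```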
...@@ -101,7 +101,7 @@ class NumpySerializer(): ...@@ -101,7 +101,7 @@ class NumpySerializer():
buffer = [] buffer = []
size = _reset_df_and_get_size(df) size = _reset_df_and_get_size(df)
with get_tqdm(total=size) as pbar: with get_tqdm(total=size) as pbar:
for dp in df.get_data(): for dp in df:
buffer.append(dp) buffer.append(dp)
pbar.update() pbar.update()
np.savez_compressed(path, buffer=np.asarray(buffer, dtype=np.object)) np.savez_compressed(path, buffer=np.asarray(buffer, dtype=np.object))
...@@ -135,7 +135,7 @@ class TFRecordSerializer(): ...@@ -135,7 +135,7 @@ class TFRecordSerializer():
size = _reset_df_and_get_size(df) size = _reset_df_and_get_size(df)
with tf.python_io.TFRecordWriter(path) as writer, get_tqdm(total=size) as pbar: with tf.python_io.TFRecordWriter(path) as writer, get_tqdm(total=size) as pbar:
for dp in df.get_data(): for dp in df:
writer.write(_dumps(dp)) writer.write(_dumps(dp))
pbar.update() pbar.update()
...@@ -143,7 +143,7 @@ class TFRecordSerializer(): ...@@ -143,7 +143,7 @@ class TFRecordSerializer():
def load(path, size=None): def load(path, size=None):
""" """
Args: Args:
size (int): total number of records. If not provided, the returned dataflow will have no `size()`. size (int): total number of records. If not provided, the returned dataflow will have no `__len__()`.
It's needed because this metadata is not stored in the TFRecord file. It's needed because this metadata is not stored in the TFRecord file.
""" """
gen = tf.python_io.tf_record_iterator(path) gen = tf.python_io.tf_record_iterator(path)
...@@ -175,15 +175,16 @@ class HDF5Serializer(): ...@@ -175,15 +175,16 @@ class HDF5Serializer():
buffer = defaultdict(list) buffer = defaultdict(list)
with get_tqdm(total=size) as pbar: with get_tqdm(total=size) as pbar:
for dp in df.get_data(): for dp in df:
assert len(dp) == len(data_paths), "Datapoint has {} components!".format(len(dp)) assert len(dp) == len(data_paths), "Datapoint has {} components!".format(len(dp))
for k, el in zip(data_paths, dp): for k, el in zip(data_paths, dp):
buffer[k].append(el) buffer[k].append(el)
pbar.update() pbar.update()
with h5py.File(path, 'w') as hf, get_tqdm(total=size) as pbar: with h5py.File(path, 'w') as hf, get_tqdm(total=len(data_paths)) as pbar:
for data_path in data_paths: for data_path in data_paths:
hf.create_dataset(data_path, data=buffer[data_path]) hf.create_dataset(data_path, data=buffer[data_path])
pbar.update()
@staticmethod @staticmethod
def load(path, data_paths, shuffle=True): def load(path, data_paths, shuffle=True):
...@@ -221,7 +222,7 @@ if __name__ == '__main__': ...@@ -221,7 +222,7 @@ if __name__ == '__main__':
print(time.time()) print(time.time())
df = TFRecordSerializer.load('out.tfrecords', size=1000) df = TFRecordSerializer.load('out.tfrecords', size=1000)
df.reset_state() df.reset_state()
for idx, dp in enumerate(df.get_data()): for idx, dp in enumerate(df):
pass pass
print("TF Finished, ", idx) print("TF Finished, ", idx)
print(time.time()) print(time.time())
...@@ -230,7 +231,7 @@ if __name__ == '__main__': ...@@ -230,7 +231,7 @@ if __name__ == '__main__':
print(time.time()) print(time.time())
df = LMDBSerializer.load('out.lmdb') df = LMDBSerializer.load('out.lmdb')
df.reset_state() df.reset_state()
for idx, dp in enumerate(df.get_data()): for idx, dp in enumerate(df):
pass pass
print("LMDB Finished, ", idx) print("LMDB Finished, ", idx)
print(time.time()) print(time.time())
...@@ -239,7 +240,7 @@ if __name__ == '__main__': ...@@ -239,7 +240,7 @@ if __name__ == '__main__':
print(time.time()) print(time.time())
df = NumpySerializer.load('out.npz') df = NumpySerializer.load('out.npz')
df.reset_state() df.reset_state()
for idx, dp in enumerate(df.get_data()): for idx, dp in enumerate(df):
pass pass
print("Numpy Finished, ", idx) print("Numpy Finished, ", idx)
print(time.time()) print(time.time())
...@@ -248,7 +249,7 @@ if __name__ == '__main__': ...@@ -248,7 +249,7 @@ if __name__ == '__main__':
print(time.time()) print(time.time())
df = HDF5Serializer.load('out.h5') df = HDF5Serializer.load('out.h5')
df.reset_state() df.reset_state()
for idx, dp in enumerate(df.get_data()): for idx, dp in enumerate(df):
pass pass
print("HDF5 Finished, ", idx) print("HDF5 Finished, ", idx)
print(time.time()) print(time.time())
...@@ -77,7 +77,7 @@ class FeedInput(InputSource): ...@@ -77,7 +77,7 @@ class FeedInput(InputSource):
class _FeedCallback(Callback): class _FeedCallback(Callback):
def __init__(self, ds, placeholders): def __init__(self, ds, placeholders):
self._ds = ds self._ds = ds
self._itr = self._ds.get_data() self._itr = self._ds.__iter__()
self._placeholders = placeholders self._placeholders = placeholders
def _before_run(self, _): def _before_run(self, _):
...@@ -87,7 +87,7 @@ class FeedInput(InputSource): ...@@ -87,7 +87,7 @@ class FeedInput(InputSource):
return tf.train.SessionRunArgs(fetches=[], feed_dict=feed) return tf.train.SessionRunArgs(fetches=[], feed_dict=feed)
def _reset(self): def _reset(self):
self._itr = self._ds.get_data() self._itr = self._ds.__iter__()
def __init__(self, ds, infinite=True): def __init__(self, ds, infinite=True):
""" """
...@@ -105,7 +105,7 @@ class FeedInput(InputSource): ...@@ -105,7 +105,7 @@ class FeedInput(InputSource):
self._iter_ds = self.ds self._iter_ds = self.ds
def _size(self): def _size(self):
return self.ds.size() return len(self.ds)
def _setup(self, inputs): def _setup(self, inputs):
# placeholders as input are always safe to reuse. # placeholders as input are always safe to reuse.
...@@ -175,7 +175,7 @@ class EnqueueThread(ShareSessionThread): ...@@ -175,7 +175,7 @@ class EnqueueThread(ShareSessionThread):
logger.info("{} Exited.".format(self.name)) logger.info("{} Exited.".format(self.name))
def reinitialize_dataflow(self): def reinitialize_dataflow(self):
self._itr = self.dataflow.get_data() self._itr = self.dataflow.__iter__()
def pause(self): def pause(self):
self._running.clear() self._running.clear()
...@@ -207,7 +207,7 @@ class QueueInput(FeedfreeInput): ...@@ -207,7 +207,7 @@ class QueueInput(FeedfreeInput):
self._started = False self._started = False
def _size(self): def _size(self):
return self.ds.size() return len(self.ds)
def _setup(self, inputs): def _setup(self, inputs):
self._input_placehdrs = [v.build_placeholder_reuse() for v in inputs] self._input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
...@@ -225,7 +225,7 @@ class QueueInput(FeedfreeInput): ...@@ -225,7 +225,7 @@ class QueueInput(FeedfreeInput):
def refill_queue(self): def refill_queue(self):
""" """
Clear the queue, then call dataflow.get_data() again and fill into the queue. Clear the queue, then call dataflow.__iter__() again and fill into the queue.
""" """
self.thread.pause() # pause enqueue self.thread.pause() # pause enqueue
...@@ -292,7 +292,7 @@ class BatchQueueInput(QueueInput): ...@@ -292,7 +292,7 @@ class BatchQueueInput(QueueInput):
self.batch_size = int(batch_size) self.batch_size = int(batch_size)
def _size(self): def _size(self):
return self.ds.size() // self.batch_size return len(self.ds) // self.batch_size
def _setup(self, inputs): def _setup(self, inputs):
logger.info("Setting up the queue for CPU prefetching ...") logger.info("Setting up the queue for CPU prefetching ...")
......
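An `InputSource` built on a dataflow now derives its size from `len()`. A minimal sketch, with the model/trainer wiring omitted:

```python
from tensorpack.dataflow import FakeData
from tensorpack.input_source import QueueInput

df = FakeData([[28, 28, 3], [1]], size=256)
inp = QueueInput(df)      # _size() now returns len(df) == 256
```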
...@@ -67,11 +67,11 @@ class SimpleDatasetPredictor(DatasetPredictorBase): ...@@ -67,11 +67,11 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
def get_result(self): def get_result(self):
self.dataset.reset_state() self.dataset.reset_state()
try: try:
sz = self.dataset.size() sz = len(self.dataset)
except NotImplementedError: except NotImplementedError:
sz = 0 sz = 0
with get_tqdm(total=sz, disable=(sz == 0)) as pbar: with get_tqdm(total=sz, disable=(sz == 0)) as pbar:
for dp in self.dataset.get_data(): for dp in self.dataset:
res = self.predictor(*dp) res = self.predictor(*dp)
yield res yield res
pbar.update() pbar.update()
...@@ -146,7 +146,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase): ...@@ -146,7 +146,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
def get_result(self): def get_result(self):
try: try:
sz = self.dataset.size() sz = len(self.dataset)
except NotImplementedError: except NotImplementedError:
sz = 0 sz = 0
with get_tqdm(total=sz, disable=(sz == 0)) as pbar: with get_tqdm(total=sz, disable=(sz == 0)) as pbar:
......
...@@ -131,13 +131,13 @@ class TrainConfig(object): ...@@ -131,13 +131,13 @@ class TrainConfig(object):
if steps_per_epoch is None: if steps_per_epoch is None:
try: try:
if dataflow is not None: if dataflow is not None:
steps_per_epoch = dataflow.size() steps_per_epoch = len(dataflow)
elif data is not None: elif data is not None:
steps_per_epoch = data.size() steps_per_epoch = data.size()
else: else:
raise NotImplementedError() raise NotImplementedError()
except NotImplementedError: except NotImplementedError:
logger.error("You must set `TrainConfig(steps_per_epoch)` if data.size() is not available.") logger.error("You must set `TrainConfig(steps_per_epoch)` if the size of your input is not available.")
raise raise
else: else:
steps_per_epoch = int(steps_per_epoch) steps_per_epoch = int(steps_per_epoch)
......
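When the input cannot report a size, `len()` raises `NotImplementedError` (the base `DataFlow` behaviour this block relies on), and `steps_per_epoch` must then be supplied explicitly. A small sketch of that fallback:

```python
from tensorpack.dataflow import DataFromGenerator

def endless():
    while True:
        yield [0]

df = DataFromGenerator(endless)    # no known size
try:
    steps_per_epoch = len(df)      # base DataFlow.__len__ raises NotImplementedError
except NotImplementedError:
    steps_per_epoch = 1000         # set it manually, e.g. TrainConfig(..., steps_per_epoch=1000)
print(steps_per_epoch)
```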
...@@ -214,7 +214,7 @@ def get_tqdm_kwargs(**kwargs): ...@@ -214,7 +214,7 @@ def get_tqdm_kwargs(**kwargs):
return default return default
def get_tqdm(**kwargs): def get_tqdm(*args, **kwargs):
""" Similar to :func:`get_tqdm_kwargs`, """ Similar to :func:`tqdm.tqdm()`,
but returns the tqdm object directly. """ but use tensorpack's default options to have consistent style. """
return tqdm(**get_tqdm_kwargs(**kwargs)) return tqdm(*args, **get_tqdm_kwargs(**kwargs))
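With `*args` forwarded, `get_tqdm` can now be used both ways; a short sketch:

```python
from tensorpack.utils.utils import get_tqdm

# keyword style, as used elsewhere in this commit
with get_tqdm(total=100) as pbar:
    for _ in range(100):
        pbar.update()

# positional iterable, forwarded straight to tqdm.tqdm()
for _ in get_tqdm(range(100)):
    pass
```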
...@@ -308,7 +308,7 @@ def dump_dataflow_images(df, index=0, batched=True, ...@@ -308,7 +308,7 @@ def dump_dataflow_images(df, index=0, batched=True,
df.reset_state() df.reset_state()
cnt = 0 cnt = 0
while True: while True:
for dp in df.get_data(): for dp in df:
if not batched: if not batched:
imgbatch = [dp[index]] imgbatch = [dp[index]]
else: else:
......
...@@ -31,10 +31,10 @@ class SeededFakeDataFlow(DataFlow): ...@@ -31,10 +31,10 @@ class SeededFakeDataFlow(DataFlow):
img = np.random.randn(28, 28, 3) img = np.random.randn(28, 28, 3)
self.cache.append([label, img]) self.cache.append([label, img])
def size(self): def __len__(self):
return self._size return self._size
def get_data(self): def __iter__(self):
for dp in self.cache: for dp in self.cache:
yield dp yield dp
...@@ -52,7 +52,7 @@ class SerializerTest(unittest.TestCase): ...@@ -52,7 +52,7 @@ class SerializerTest(unittest.TestCase):
ds_actual.reset_state() ds_actual.reset_state()
ds_expected.reset_state() ds_expected.reset_state()
for dp_expected, dp_actual in zip(ds_expected.get_data(), ds_actual.get_data()): for dp_expected, dp_actual in zip(ds_expected.__iter__(), ds_actual.__iter__()):
self.assertEqual(dp_expected[0], dp_actual[0]) self.assertEqual(dp_expected[0], dp_actual[0])
self.assertTrue(np.allclose(dp_expected[1], dp_actual[1])) self.assertTrue(np.allclose(dp_expected[1], dp_actual[1]))
except ImportError: except ImportError:
......