Commit 1567c2dc authored by Yuxin Wu

Let LocallyShuffleData not depend on __len__

parent 68b8a7b7
@@ -5,7 +5,8 @@
 DataFlow is a library to build Python iterators for efficient data loading.

-**Definition**: A DataFlow is a idiomatic Python container object that has a `__iter__()` generator method, which yields `datapoints` and a `__len__()` method returning the size of the flow.
+**Definition**: A DataFlow is an idiomatic Python container object that has a `__iter__()` generator method,
+which yields `datapoints` and optionally a `__len__()` method returning the size of the flow.
 A datapoint is a **list** of Python objects which are called the `components` of a datapoint.

 **Example**: to train on MNIST dataset, you may need a DataFlow with a `__iter__()` method
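As a concrete illustration of the definition above, here is a minimal sketch of an object that qualifies as a DataFlow; the class name and the random data are purely illustrative:

```python
import numpy as np

class FakeMnistFlow:
    """A minimal DataFlow: __iter__() yields datapoints; __len__() is optional."""
    def __init__(self, size=100):
        self.size = size

    def __iter__(self):
        for _ in range(self.size):
            img = np.random.rand(28, 28)    # first component
            label = np.random.randint(10)   # second component
            yield [img, label]              # a datapoint is a list of components

    def __len__(self):
        # optional, per the updated definition above
        return self.size
```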
@@ -40,6 +41,7 @@ df = PrefetchDataZMQ(df, 3)
 You can find more complicated DataFlow in the [ImageNet training script](../examples/ImageNetModels/imagenet_utils.py)
 with all the data preprocessing.

+### Work with Your Data
 Unless you are working with standard data types (image folders, LMDB, etc),
 you would usually want to write the source DataFlow (`MyDataFlow` in the above example) for your data format.
 See [another tutorial](extend/dataflow.html) for simple instructions on writing a DataFlow.
@@ -58,7 +60,7 @@ the rest of the data pipeline.
 Nevertheless, tensorpack supports data loading with native TF operators / TF datasets as well.

-### Use DataFlow (outside Tensorpack)
+### Use DataFlow outside Tensorpack
 Normally, tensorpack `InputSource` interface links DataFlow to the graph for training.
 If you use DataFlow in other places such as your custom code, call `reset_state()` first to initialize it,
 and then use the generator however you like:
...
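A minimal sketch of such standalone usage, with the built-in `FakeData` dataflow standing in for a real source (shapes and sizes are illustrative):

```python
from tensorpack.dataflow import FakeData

df = FakeData([[28, 28], [10]], size=100)   # yields datapoints with two random components
df.reset_state()                            # initialize before iterating outside a trainer
for dp in df:
    # dp is a list of components; use it however you like
    print(dp[0].shape, dp[1].shape)
```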
@@ -13,14 +13,14 @@ The easiest way to create a DataFlow to load custom data, is to wrap a custom ge
 ```python
 def my_data_loader():
     while True:
-        # load data from somewhere
+        # load data from somewhere with Python
         yield [my_array, my_label]

 dataflow = DataFromGenerator(my_data_loader)
 ```

 To write more complicated DataFlow, you need to inherit the base `DataFlow` class.
-Usually, you just need to implement the `get_data()` method which yields a datapoint every time.
+Usually, you just need to implement the `__iter__()` method which yields a datapoint every time.
 ```python
 class MyDataFlow(DataFlow):
     def __iter__(self):

@@ -32,7 +32,9 @@ class MyDataFlow(DataFlow):
 Optionally, you can implement the following two methods:

-+ `size()`. Return the number of elements the generator can produce. Certain tensorpack features might use it.
++ `__len__()`. Return the number of elements the generator can produce. Certain tensorpack features might need it.
+  This is optional, and even when implemented, it is
+  not guaranteed to be an accurate length because it's impossible to know the length of certain generators.

 + `reset_state()`. It is guaranteed that the actual process which runs a DataFlow will invoke this method before using it.
   So if this DataFlow needs to do something after a `fork()`, you should put it here.
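Putting this section together, a sketch of a custom source DataFlow that also implements the two optional methods (the random data is purely illustrative):

```python
import numpy as np
from tensorpack.dataflow import DataFlow

class MyDataFlow(DataFlow):
    def __init__(self, size=100):
        self._size = size

    def reset_state(self):
        # runs in the process that actually consumes the dataflow (e.g. after a fork());
        # a good place to create per-process state such as an RNG
        self.rng = np.random.RandomState()

    def __iter__(self):
        for _ in range(self._size):
            yield [self.rng.rand(28, 28), self.rng.randint(10)]

    def __len__(self):
        # optional; only meaningful here because the size is actually known
        return self._size
```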
@@ -61,3 +63,8 @@ class ProcessingDataFlow(DataFlow):
             # do something
             yield new_datapoint
 ```
+
+Some built-in dataflows, e.g.
+[MapData](../../modules/dataflow.html#tensorpack.dataflow.MapData) and
+[MapDataComponent](../../modules/dataflow.html#tensorpack.dataflow.MapDataComponent)
+can do the above type of data processing for you.
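For instance, per-datapoint processing like the ProcessingDataFlow sketch above can usually be written with these built-ins instead (here `FakeData` stands in for a real source, and the transforms are illustrative):

```python
from tensorpack.dataflow import FakeData, MapData, MapDataComponent

ds = FakeData([[28, 28], [10]], size=100)                     # stand-in source dataflow
ds = MapData(ds, lambda dp: [dp[0] * 255, dp[1]])             # transform the whole datapoint
ds = MapDataComponent(ds, lambda lbl: lbl.argmax(), index=1)  # transform a single component
```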
@@ -79,7 +79,7 @@ Note some __common problems__ when using these trainers:
 1. In each iteration, instead of taking one tensor for all GPUs and split,
    all GPUs take tensors from the `InputSource`.
-   So the total batch size would become ``(batch size of InputSource) * #GPU``.
+   So the total batch size across all GPUs would become ``(batch size of InputSource) * #GPU``.
    Splitting a tensor for data-parallel training makes no sense at all. First,
    it wastes time because typically data is concatenated into batches by the user.
...
@@ -252,12 +252,19 @@ class FixedSizeData(ProxyDataFlow):

 class MapData(ProxyDataFlow):
     """
-    Apply a mapper/filter on the DataFlow.
+    Apply a mapper/filter on the datapoints of a DataFlow.

     Note:
-        1. Please make sure func doesn't modify the components
+        1. Please make sure func doesn't modify its arguments in place,
            unless you're certain it's safe.
         2. If you discard some datapoints, ``len(ds)`` will be incorrect.
+
+    Example:
+
+        .. code-block:: none
+
+            ds = Mnist('train')
+            ds = MapData(ds, lambda dp: [dp[0] * 255, dp[1]])
     """

     def __init__(self, ds, func):
@@ -283,9 +290,16 @@ class MapDataComponent(MapData):
     Note:
         1. This dataflow itself doesn't modify the datapoints.
-           But please make sure func doesn't modify the components
+           But please make sure func doesn't modify its arguments in place,
            unless you're certain it's safe.
         2. If you discard some datapoints, ``len(ds)`` will be incorrect.
+
+    Example:
+
+        .. code-block:: none
+
+            ds = Mnist('train')
+            ds = MapDataComponent(ds, lambda img: img * 255, 0)
     """

     def __init__(self, ds, func, index=0):
         """
@@ -556,10 +570,10 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
     Args:
         ds (DataFlow): input DataFlow.
         buffer_size (int): size of the buffer.
-        nr_reuse (int): reuse each datapoints several times to improve
+        nr_reuse (int): duplicate each datapoint several times into the buffer to improve
             speed, but may hurt your model.
         shuffle_interval (int): shuffle the buffer after this many
-            datapoints went through it. Frequent shuffle on large buffer
+            datapoints were produced from the given dataflow. Frequent shuffle on large buffer
             may affect speed, but infrequent shuffle may affect
             randomness. Defaults to buffer_size / 3
     """
@@ -574,32 +588,23 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
     def reset_state(self):
         ProxyDataFlow.reset_state(self)
         RNGDataFlow.reset_state(self)
-        self.ds_itr = RepeatedData(self.ds, -1).__iter__()
         self.current_cnt = 0

-    def _add_data(self):
-        dp = next(self.ds_itr)
-        for _ in range(self.nr_reuse):
-            self.q.append(dp)
+    def __len__(self):
+        return len(self.ds) * self.nr_reuse

     def __iter__(self):
         with self._guard:
-            # fill queue
-            while self.q.maxlen > len(self.q):
-                self._add_data()
-
-            sz = self.__len__()
-            cnt = 0
-            while True:
-                self.rng.shuffle(self.q)
-                for _ in range(self.shuffle_interval):
-                    # the inner loop maintains the queue size (almost) unchanged
-                    for _ in range(self.nr_reuse):
-                        yield self.q.popleft()
-                    cnt += self.nr_reuse
-                    if cnt >= sz:
-                        return
-                    self._add_data()
+            for i, dp in enumerate(self.ds):
+                # fill queue
+                if i % self.shuffle_interval == 0:
+                    self.rng.shuffle(self.q)
+                if self.q.maxlen > len(self.q):
+                    self.q.extend([dp] * self.nr_reuse)
+                    continue
+                for _ in range(self.nr_reuse):
+                    yield self.q.popleft()
+                    self.q.append(dp)


 class CacheData(ProxyDataFlow):
...
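The rewritten `__iter__` above drives the shuffle buffer directly from the input iterator, which is exactly why the dataflow no longer depends on `__len__`. A self-contained sketch of the same buffering idea in plain Python (function name and defaults are illustrative, not the library API):

```python
import random
from collections import deque

def locally_shuffled(iterable, buffer_size, nr_reuse=1, shuffle_interval=None, rng=random):
    """Yield items of `iterable` in a locally-shuffled order using a bounded buffer.

    The buffer is filled and refreshed from the input iterator itself, so the
    total length of the input never needs to be known.
    """
    if shuffle_interval is None:
        shuffle_interval = max(1, buffer_size // 3)
    q = deque(maxlen=buffer_size)
    for i, item in enumerate(iterable):
        if i % shuffle_interval == 0:
            rng.shuffle(q)                 # periodically re-shuffle the buffer
        if len(q) < q.maxlen:              # still filling the buffer
            q.extend([item] * nr_reuse)
            continue
        for _ in range(nr_reuse):
            yield q.popleft()              # hand out a buffered item...
            q.append(item)                 # ...and replace it with the new one

# e.g. list(locally_shuffled(range(20), buffer_size=5))
```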