Commit 985236c3 authored by Yuxin Wu

be clear about "reinitialize" iterator and "reset" dataflow.

parent 09c45084
```diff
 ### Write a DataFlow
-There are several existing DataFlow, e.g. ImageFromFile, DataFromList, which you can
-use if your data format is simple.
-However in general, you will probably need to write a new DataFlow to produce data for your task.
+There are several existing DataFlow, e.g. [ImageFromFile](../../modules/dataflow.html#tensorpack.dataflow.ImageFromFile),
+[DataFromList](../../modules/dataflow.html#tensorpack.dataflow.DataFromList),
+which you can use if your data format is simple.
+In general, you will probably need to write a source DataFlow to produce data for your task,
+and then compose it with existing modules (e.g. mapping, batching, prefetching, ...).
 Usually, you just need to implement the `get_data()` method which yields a datapoint every time.
```
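The `get_data()` contract above can be sketched without tensorpack itself. The base class below is a minimal stand-in for tensorpack's `DataFlow` interface (names and defaults are assumptions for illustration, not tensorpack's actual implementation):

```python
class DataFlow:
    """Minimal stand-in for tensorpack's DataFlow interface (sketch only)."""
    def get_data(self):
        raise NotImplementedError
    def size(self):
        raise NotImplementedError
    def reset_state(self):
        pass  # optional; see reset_state() below

class MyDataFlow(DataFlow):
    """Yields datapoints [x, label]; each call to get_data() returns a fresh iterator."""
    def get_data(self):
        for k in range(5):
            yield [k, k % 2]
    def size(self):
        return 5

df = MyDataFlow()
df.reset_state()                  # call once before use
points = list(df.get_data())
assert len(points) == df.size()   # size() matches what the generator produces
```

The key point is that `get_data()` is an ordinary Python generator method: calling it again yields a brand-new iterator over the data.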
```diff
@@ -17,7 +19,7 @@ class MyDataFlow(DataFlow):
 Optionally, you can implement the following two methods:
-+ `size()`. Return the number of elements the generator can produce. Certain tensorpack features might require this.
++ `size()`. Return the number of elements the generator can produce. Certain tensorpack features might use it.
 + `reset_state()`. It is guaranteed that the actual process which runs a DataFlow will invoke this method before using it.
   So if this DataFlow needs to do something after a `fork()`, you should put it here.
```
```diff
@@ -26,9 +28,9 @@ Optionally, you can implement the following two methods:
 Otherwise, child processes will have the same random seed. The `RNGDataFlow` base class does this for you.
 You can subclass `RNGDataFlow` to access `self.rng`, whose seed has been taken care of.
-With a "low-level" DataFlow defined like above, you can then compose it with existing modules (e.g. batching, prefetching, ...).
+The convention is that `reset_state()` must be called once, and usually only once, for each DataFlow instance.
+To reinitialize the dataflow (i.e. get a new iterator from the beginning), simply call `get_data()` again.
```
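The reset-once / reinitialize-often convention can be shown with a toy stand-in for `RNGDataFlow` (the class and its per-process reseeding below are a sketch under assumed semantics, not tensorpack's code):

```python
import os
import random

class RNGDataFlow:
    """Sketch of the RNGDataFlow idea: give each process its own RNG seed,
    so children do not inherit identical random state after fork()."""
    def reset_state(self):
        self.rng = random.Random(os.getpid())  # assumed seeding scheme, for illustration

class RandomDigits(RNGDataFlow):
    def get_data(self):
        for _ in range(3):
            yield [self.rng.randint(0, 9)]

df = RandomDigits()
df.reset_state()                # once, and only once, per instance
first = list(df.get_data())
second = list(df.get_data())    # reinitialize: just ask for a new iterator
assert len(first) == len(second) == 3
```

Note that the second epoch does not call `reset_state()` again; only `get_data()` is re-invoked.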
DataFlow implementations for several well-known datasets are provided in the
[dataflow.dataset](../../modules/dataflow.dataset.html)
module; you can take them as a reference.
```diff
@@ -71,7 +71,6 @@ class FeedInput(InputSource):
         return tf.train.SessionRunArgs(fetches=[], feed_dict=feed)

     def _reset(self):
-        self._ds.reset_state()
         self._itr = self._ds.get_data()

     def __init__(self, ds, infinite=True):
```
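The `infinite=True` behavior that `FeedInput` implies can be sketched as a small helper: when the iterator is exhausted, call `get_data()` again (reinitialize), never `reset_state()` (the helper name and the stub dataflow are hypothetical, for illustration only):

```python
from itertools import islice

class CountDown:
    """Hypothetical finite dataflow stub with 4 datapoints."""
    def reset_state(self):
        pass
    def get_data(self):
        for k in range(4):
            yield [k]

def infinite_datapoints(df):
    """Loop a finite dataflow forever by re-calling get_data() on exhaustion."""
    itr = df.get_data()
    while True:
        try:
            yield next(itr)
        except StopIteration:
            itr = df.get_data()  # reinitialize the iterator; reset_state() is NOT called again

df = CountDown()
df.reset_state()                               # once, before use
ten = list(islice(infinite_datapoints(df), 10))
assert ten[4] == [0]                           # wrapped around after 4 points
```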
```diff
@@ -132,7 +131,7 @@ class EnqueueThread(ShareSessionThread):
     def run(self):
         with self.default_sess():
             try:
-                self._itr = self.dataflow.get_data()
+                self.reinitialize_dataflow()
                 while True:
                     # pausable loop
                     self._lock.acquire()
```
```diff
@@ -155,8 +154,7 @@ class EnqueueThread(ShareSessionThread):
                 pass
             logger.info("{} Exited.".format(self.name))

     def reset_dataflow(self):
         self.dataflow.reset_state()
-        self._itr = self.dataflow.get_data()
+
+    def reinitialize_dataflow(self):
+        self._itr = self.dataflow.get_data()

     def pause(self):
```
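The split this commit introduces can be sketched in isolation: `reset_dataflow()` performs the once-only `reset_state()`, while `reinitialize_dataflow()` merely grabs a fresh iterator. The stub dataflow and the class below are simplified stand-ins, not the real `EnqueueThread`:

```python
class ToyDataFlow:
    """Hypothetical dataflow stub that counts how often it is reset."""
    def __init__(self):
        self.resets = 0
    def reset_state(self):
        self.resets += 1
    def get_data(self):
        for k in range(3):
            yield [k]

class EnqueueThreadSketch:
    """Sketch of the renamed pair: reset once, reinitialize per epoch/resume."""
    def __init__(self, dataflow):
        self.dataflow = dataflow
        self._itr = None
    def reset_dataflow(self):
        self.dataflow.reset_state()            # heavyweight, once per instance
    def reinitialize_dataflow(self):
        self._itr = self.dataflow.get_data()   # cheap: just a new iterator

th = EnqueueThreadSketch(ToyDataFlow())
th.reset_dataflow()           # once, at setup
th.reinitialize_dataflow()    # first epoch
th.reinitialize_dataflow()    # later epoch / resume
assert th.dataflow.resets == 1            # reset_state() ran exactly once
assert list(th._itr) == [[0], [1], [2]]   # iterator starts from the beginning
```

This is why `QueueInput` below now calls `reinitialize_dataflow()` before resuming the thread, instead of resetting the dataflow again.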
```diff
@@ -217,7 +215,7 @@ class QueueInput(FeedfreeInput):
             pass

         # reset dataflow, start thread
-        self.thread.reset_dataflow()
+        self.thread.reinitialize_dataflow()
         self.thread.resume()

     def _create_ema_callback(self):
```
......