Commit 985236c3 authored by Yuxin Wu

be clear about "reinitialize" iterator and "reset" dataflow.

parent 09c45084
 ### Write a DataFlow
-There are several existing DataFlow, e.g. ImageFromFile, DataFromList, which you can
-use if your data format is simple.
-However in general, you will probably need to write a new DataFlow to produce data for your task.
+There are several existing DataFlow, e.g. [ImageFromFile](../../modules/dataflow.html#tensorpack.dataflow.ImageFromFile),
+[DataFromList](../../http://tensorpack.readthedocs.io/en/latest/modules/dataflow.html#tensorpack.dataflow.DataFromList),
+which you can use if your data format is simple.
+In general, you probably need to write a source DataFlow to produce data for your task,
+and then compose it with existing modules (e.g. mapping, batching, prefetching, ...).
 Usually, you just need to implement the `get_data()` method which yields a datapoint every time.
 ```python
@@ -17,7 +19,7 @@ class MyDataFlow(DataFlow):
 Optionally, you can implement the following two methods:
-+ `size()`. Return the number of elements the generator can produce. Certain tensorpack features might require this.
++ `size()`. Return the number of elements the generator can produce. Certain tensorpack features might use it.
 + `reset_state()`. It is guaranteed that the actual process which runs a DataFlow will invoke this method before using it.
   So if this DataFlow needs to do something after a `fork()`, you should put it here.
@@ -26,9 +28,9 @@ Optionally, you can implement the following two methods:
   Otherwise, child processes will have the same random seed. The `RNGDataFlow` base class does this for you.
   You can subclass `RNGDataFlow` to access `self.rng` whose seed has been taken care of.
-With a "low-level" DataFlow defined like above, you can then compose it with existing modules (e.g. batching, prefetching, ...).
+The convention is that, `reset_state()` must be called once and usually only once for each DataFlow instance.
+To reinitialize the dataflow (i.e. get a new iterator from the beginning), simply call `get_data()` again.
 DataFlow implementations for several well-known datasets are provided in the
 [dataflow.dataset](../../modules/dataflow.dataset.html)
 module, you can take them as a reference.
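The distinction this documentation change draws between "reset" and "reinitialize" can be sketched as follows. This is an illustration under stated assumptions, not tensorpack code: the `DataFlow` class below is a hypothetical stand-in so the snippet runs without tensorpack installed, and `CountingDataFlow` and its `reset_calls` counter are invented for the example. It shows `reset_state()` doing per-process setup once, while each `get_data()` call returns a fresh iterator from the beginning.

```python
import random


class DataFlow:
    """Hypothetical stand-in for the DataFlow base class (not the real one)."""

    def get_data(self):
        raise NotImplementedError

    def size(self):
        raise NotImplementedError

    def reset_state(self):
        pass


class CountingDataFlow(DataFlow):
    """Illustrative source DataFlow that yields [index, noise] datapoints."""

    def __init__(self, n):
        self._n = n
        self.rng = None        # created in reset_state(), i.e. after any fork()
        self.reset_calls = 0   # bookkeeping for this example only

    def reset_state(self):
        # Per-process setup: a new RNG, so forked workers do not share a seed.
        self.rng = random.Random()
        self.reset_calls += 1

    def size(self):
        return self._n

    def get_data(self):
        # Each call builds a new generator that starts from the first datapoint.
        for i in range(self._n):
            yield [i, self.rng.random()]


df = CountingDataFlow(3)
df.reset_state()                               # "reset": once per instance
first_pass = [dp[0] for dp in df.get_data()]   # first iterator
second_pass = [dp[0] for dp in df.get_data()]  # "reinitialize": call get_data() again
```

Both passes cover the same indices, and `reset_state()` runs only once, which matches the convention stated in the new documentation text.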
@@ -71,7 +71,6 @@ class FeedInput(InputSource):
             return tf.train.SessionRunArgs(fetches=[], feed_dict=feed)
         def _reset(self):
-            self._ds.reset_state()
             self._itr = self._ds.get_data()
     def __init__(self, ds, infinite=True):
@@ -132,7 +131,7 @@ class EnqueueThread(ShareSessionThread):
     def run(self):
         with self.default_sess():
             try:
-                self._itr = self.dataflow.get_data()
+                self.reinitialize_dataflow()
                 while True:
                     # pausable loop
                     self._lock.acquire()
@@ -155,8 +154,7 @@ class EnqueueThread(ShareSessionThread):
                     pass
                 logger.info("{} Exited.".format(self.name))
-    def reset_dataflow(self):
-        self.dataflow.reset_state()
+    def reinitialize_dataflow(self):
         self._itr = self.dataflow.get_data()
     def pause(self):
@@ -217,7 +215,7 @@ class QueueInput(FeedfreeInput):
             pass
         # reset dataflow, start thread
-        self.thread.reset_dataflow()
+        self.thread.reinitialize_dataflow()
         self.thread.resume()
     def _create_ema_callback(self):
......