Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
985236c3
Commit
985236c3
authored
Nov 19, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
be clear about "reinitialize" iterator and "reset" dataflow.
parent
09c45084
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
11 deletions
+11
-11
docs/tutorial/extend/dataflow.md
docs/tutorial/extend/dataflow.md
+8
-6
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+3
-5
No files found.
docs/tutorial/extend/dataflow.md
View file @
985236c3
### Write a DataFlow
There are several existing DataFlow, e.g. ImageFromFile, DataFromList, which you can
use if your data format is simple.
However in general, you will probably need to write a new DataFlow to produce data for your task.
There are several existing DataFlow, e.g.
[
ImageFromFile
](
../../modules/dataflow.html#tensorpack.dataflow.ImageFromFile
)
,
[
DataFromList
](
../../http://tensorpack.readthedocs.io/en/latest/modules/dataflow.html#tensorpack.dataflow.DataFromList
)
,
which you can use if your data format is simple.
In general, you probably need to write a source DataFlow to produce data for your task,
and then compose it with existing modules (e.g. mapping, batching, prefetching, ...).
Usually, you just need to implement the
`get_data()`
method which yields a datapoint every time.
```
python
...
...
@@ -17,7 +19,7 @@ class MyDataFlow(DataFlow):
Optionally, you can implement the following two methods:
+
`size()`
. Return the number of elements the generator can produce. Certain tensorpack features might
require this
.
+
`size()`
. Return the number of elements the generator can produce. Certain tensorpack features might
use it
.
+
`reset_state()`
. It is guaranteed that the actual process which runs a DataFlow will invoke this method before using it.
So if this DataFlow needs to do something after a
`fork()`
, you should put it here.
...
...
@@ -26,9 +28,9 @@ Optionally, you can implement the following two methods:
Otherwise, child processes will have the same random seed. The `RNGDataFlow` base class does this for you.
You can subclass `RNGDataFlow` to access `self.rng` whose seed has been taken care of.
With a "low-level" DataFlow defined like above, you can then compose it with existing modules (e.g. batching, prefetching, ...).
The convention is that,
`reset_state()`
must be called once and usually only once for each DataFlow instance.
To reinitialize the dataflow (i.e. get a new iterator from the beginning), simply call
`get_data()`
again.
DataFlow implementations for several well-known datasets are provided in the
[
dataflow.dataset
](
../../modules/dataflow.dataset.html
)
module, you can take them as a reference.
tensorpack/input_source/input_source.py
View file @
985236c3
...
...
@@ -71,7 +71,6 @@ class FeedInput(InputSource):
return
tf
.
train
.
SessionRunArgs
(
fetches
=
[],
feed_dict
=
feed
)
def
_reset
(
self
):
self
.
_ds
.
reset_state
()
self
.
_itr
=
self
.
_ds
.
get_data
()
def
__init__
(
self
,
ds
,
infinite
=
True
):
...
...
@@ -132,7 +131,7 @@ class EnqueueThread(ShareSessionThread):
def
run
(
self
):
with
self
.
default_sess
():
try
:
self
.
_itr
=
self
.
dataflow
.
get_data
()
self
.
reinitialize_dataflow
()
while
True
:
# pausable loop
self
.
_lock
.
acquire
()
...
...
@@ -155,8 +154,7 @@ class EnqueueThread(ShareSessionThread):
pass
logger
.
info
(
"{} Exited."
.
format
(
self
.
name
))
def
reset_dataflow
(
self
):
self
.
dataflow
.
reset_state
()
def
reinitialize_dataflow
(
self
):
self
.
_itr
=
self
.
dataflow
.
get_data
()
def
pause
(
self
):
...
...
@@ -217,7 +215,7 @@ class QueueInput(FeedfreeInput):
pass
# reset dataflow, start thread
self
.
thread
.
re
set
_dataflow
()
self
.
thread
.
re
initialize
_dataflow
()
self
.
thread
.
resume
()
def
_create_ema_callback
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment