Commit 2e7dd9c7 authored Aug 07, 2017 by Yuxin Wu
Add a third note in prefetch about forking
parent d6f0c57a
Showing 3 changed files with 28 additions and 8 deletions.
tensorpack/callbacks/inference_runner.py    +3 -3
tensorpack/dataflow/base.py                 +5 -0
tensorpack/dataflow/prefetch.py             +20 -5
tensorpack/callbacks/inference_runner.py
@@ -16,7 +16,7 @@ from six.moves import range
 from ..utils import logger
 from ..utils.utils import get_tqdm_kwargs
 from ..utils.develop import deprecated
-from ..dataflow.base import DataFlow, DataFlowTerminated
+from ..dataflow.base import DataFlow
 from ..graph_builder.input_source_base import InputSource
 from ..graph_builder.input_source import (
@@ -118,8 +118,8 @@ class InferenceRunnerBase(Callback):
         try:
             for _ in tqdm.trange(self._size, **get_tqdm_kwargs()):
                 self._hooked_sess.run(fetches=[])
-        except (StopIteration, DataFlowTerminated, tf.errors.CancelledError, tf.errors.OutOfRangeError):
+        except (StopIteration, tf.errors.CancelledError, tf.errors.OutOfRangeError):
             logger.error("[InferenceRunner] input stopped before reaching its size()! " + msg)
             raise
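For context on this hunk: InferenceRunnerBase trusts the DataFlow's size() and runs exactly that many steps, aborting with an explicit error if the input stops early. Below is a minimal standalone sketch of that guard in plain Python; it is not tensorpack's actual code, and run_step / input_iter are hypothetical stand-ins for the hooked session and the InputSource.

def run_inference(size, input_iter, run_step):
    # Run `size` steps, but fail loudly if the input ends early.
    try:
        for _ in range(size):
            run_step(next(input_iter))
    except StopIteration:
        # mirrors the logger.error(...) + raise in the hunk above
        print("[InferenceRunner] input stopped before reaching its size()!")
        raise

run_inference(3, iter([1, 2, 3]), run_step=print)    # fine: exactly 3 datapoints
# run_inference(5, iter([1, 2, 3]), run_step=print)  # would print the error and re-raise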
tensorpack/dataflow/base.py
@@ -12,6 +12,11 @@ __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow', 'DataFlowTerminated']
 class DataFlowTerminated(BaseException):
+    """
+    An exception indicating that the DataFlow is unable to produce any more data:
+    calling :meth:`get_data` will not give a valid iterator any more.
+    In most DataFlow this will not be raised.
+    """
     pass
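The docstring added above separates a terminated DataFlow from one that has merely finished an epoch: the latter simply stops iterating, while the former cannot produce a valid iterator ever again. A rough self-contained sketch of that distinction, using a local stand-in class rather than tensorpack's implementation:

class DataFlowTerminated(BaseException):
    """Stand-in for tensorpack's exception of the same name."""
    pass


class OneShotFlow:
    """A toy dataflow that can only be iterated once,
    e.g. because its worker processes have exited."""

    def __init__(self, data):
        self._data = list(data)
        self._used = False

    def get_data(self):
        if self._used:
            raise DataFlowTerminated()  # no valid iterator any more
        self._used = True
        yield from self._data


flow = OneShotFlow([1, 2, 3])
print(list(flow.get_data()))        # [1, 2, 3]: a normal pass, ends with StopIteration
try:
    list(flow.get_data())           # a second pass is impossible
except DataFlowTerminated:
    print("terminated: get_data() will not give a valid iterator any more")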
tensorpack/dataflow/prefetch.py
@@ -50,8 +50,13 @@ class PrefetchData(ProxyDataFlow):
     Note:
         1. This is significantly slower than :class:`PrefetchDataZMQ` when data is large.
-        2. When nesting like this: ``PrefetchDataZMQ(PrefetchData(df, a), b)``.
+        2. When nesting like this: ``PrefetchDataZMQ(PrefetchData(df, nr_proc=a), nr_proc=b)``.
            A total of ``a`` instances of ``df`` worker processes will be created.
+           This is different from the behavior of :class`PrefetchDataZMQ`
+        3. The underlying dataflow worker will be forked multiple times When ``nr_proc>1``.
+           As a result, unless the underlying dataflow is fully shuffled, the data distribution
+           produced by this dataflow will be wrong.
+           (e.g. you are likely to see duplicated datapoints at the beginning)
     """

     def __init__(self, ds, nr_prefetch, nr_proc=1):
         """
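This new note 3 is the point of the commit: every forked worker starts iterating its own copy of the underlying dataflow from the same initial state. The snippet below is a hedged, tensorpack-free illustration of why the merged stream then begins with duplicated datapoints unless the dataflow is shuffled; fake_prefetch and make_df are made-up names, and the round-robin merge is only an approximation of the real scheduling.

import itertools

def fake_prefetch(make_dataflow, nr_proc, take):
    # Each "worker" gets its own copy of the dataflow, as a forked
    # PrefetchData worker would, and datapoints are merged round-robin.
    workers = [make_dataflow() for _ in range(nr_proc)]
    merged = itertools.chain.from_iterable(zip(*workers))
    return list(itertools.islice(merged, take))

def make_df():
    return iter(range(5))   # an unshuffled "dataflow"

print(fake_prefetch(make_df, nr_proc=3, take=6))
# -> [0, 0, 0, 1, 1, 1]: the first datapoints show up nr_proc times,
# which is the wrong distribution the note warns about.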
@@ -115,6 +120,10 @@ class PrefetchDataZMQ(ProxyDataFlow):
         1. Once :meth:`reset_state` is called, this dataflow becomes not fork-safe.
         2. When nesting like this: ``PrefetchDataZMQ(PrefetchDataZMQ(df, a), b)``.
            A total of ``a * b`` instances of ``df`` worker processes will be created.
+        3. The underlying dataflow worker will be forked multiple times When ``nr_proc>1``.
+           As a result, unless the underlying dataflow is fully shuffled, the data distribution
+           produced by this dataflow will be wrong.
+           (e.g. you are likely to see duplicated datapoints at the beginning)
     """

     def __init__(self, ds, nr_proc=1, hwm=50):
         """
@@ -234,6 +243,7 @@ class ThreadedMapData(ProxyDataFlow):
             self.inq = inq
             self.outq = outq
             self.func = map_func
+            self.daemon = True

         def run(self):
             while not self.stopped():
@@ -251,7 +261,8 @@ class ThreadedMapData(ProxyDataFlow):
             buffer_size (int): number of datapoints in the buffer
         """
         super(ThreadedMapData, self).__init__(ds)
-        self.infinite_ds = RepeatedData(ds, -1)
+        self._iter_ds = RepeatedData(ds, -1)
         self.nr_thread = nr_thread
         self.buffer_size = buffer_size
         self.map_func = map_func
@@ -271,15 +282,19 @@ class ThreadedMapData(ProxyDataFlow):
             t.start()
         # fill the buffer
-        self._itr = self.infinite_ds.get_data()
+        self._itr = self._iter_ds.get_data()
         self._fill_buffer()

     def _fill_buffer(self):
         n = self.buffer_size - self._in_queue.qsize() - self._out_queue.qsize()
         if n <= 0:
             return
-        for _ in range(n):
-            self._in_queue.put(next(self._itr))
+        try:
+            for _ in range(n):
+                self._in_queue.put(next(self._itr))
+        except StopIteration:
+            logger.error("[ThreadedMapData] buffer_size cannot be larger than the size of the DataFlow!")
+            raise

     def get_data(self):
         self._fill_buffer()
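The last hunk wraps buffer filling in try/except so that a buffer_size larger than what the DataFlow can provide fails with a clear message instead of a bare StopIteration bubbling up from next(). A standalone sketch of that guard, simplified to a single queue and therefore not ThreadedMapData's actual code:

from queue import Queue

def fill_buffer(in_queue, itr, buffer_size):
    n = buffer_size - in_queue.qsize()
    if n <= 0:
        return
    try:
        for _ in range(n):
            in_queue.put(next(itr))
    except StopIteration:
        print("[ThreadedMapData] buffer_size cannot be larger than the size of the DataFlow!")
        raise

fill_buffer(Queue(), iter(range(10)), buffer_size=5)      # fine: 5 items buffered
try:
    fill_buffer(Queue(), iter(range(3)), buffer_size=5)   # buffer too large
except StopIteration:
    pass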