Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
79b9d0eb
Commit
79b9d0eb
authored
Aug 10, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Guard some stateful dataflow with non-reentrancy
parent
6c905896
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
62 additions
and
37 deletions
+62
-37
tensorpack/dataflow/base.py
tensorpack/dataflow/base.py
+18
-1
tensorpack/dataflow/dataset/ilsvrc.py
tensorpack/dataflow/dataset/ilsvrc.py
+1
-0
tensorpack/dataflow/prefetch.py
tensorpack/dataflow/prefetch.py
+43
-36
No files found.
tensorpack/dataflow/base.py
View file @
79b9d0eb
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
# File: base.py
# File: base.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
threading
from
abc
import
abstractmethod
,
ABCMeta
from
abc
import
abstractmethod
,
ABCMeta
import
six
import
six
from
..utils.utils
import
get_rng
from
..utils.utils
import
get_rng
...
@@ -20,6 +20,23 @@ class DataFlowTerminated(BaseException):
...
@@ -20,6 +20,23 @@ class DataFlowTerminated(BaseException):
pass
pass
class
DataFlowReentrantGuard
(
object
):
"""
A tool to enforce thread-level non-reentrancy on DataFlow.
"""
def
__init__
(
self
):
self
.
_lock
=
threading
.
Lock
()
def
__enter__
(
self
):
self
.
_succ
=
self
.
_lock
.
acquire
(
blocking
=
False
)
if
not
self
.
_succ
:
raise
threading
.
ThreadError
(
"This DataFlow cannot be reused under different threads!"
)
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
self
.
_lock
.
release
()
return
False
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
class
DataFlow
(
object
):
class
DataFlow
(
object
):
""" Base class for all DataFlow """
""" Base class for all DataFlow """
...
...
tensorpack/dataflow/dataset/ilsvrc.py
View file @
79b9d0eb
...
@@ -128,6 +128,7 @@ class ILSVRC12Files(RNGDataFlow):
...
@@ -128,6 +128,7 @@ class ILSVRC12Files(RNGDataFlow):
self
.
imglist
=
meta
.
get_image_list
(
name
,
dir_structure
)
self
.
imglist
=
meta
.
get_image_list
(
name
,
dir_structure
)
for
fname
,
_
in
self
.
imglist
[:
10
]:
for
fname
,
_
in
self
.
imglist
[:
10
]:
fname
=
os
.
path
.
join
(
self
.
full_dir
,
fname
)
assert
os
.
path
.
isfile
(
fname
),
fname
assert
os
.
path
.
isfile
(
fname
),
fname
def
size
(
self
):
def
size
(
self
):
...
...
tensorpack/dataflow/prefetch.py
View file @
79b9d0eb
...
@@ -11,7 +11,7 @@ import uuid
...
@@ -11,7 +11,7 @@ import uuid
import
os
import
os
import
zmq
import
zmq
from
.base
import
ProxyDataFlow
,
DataFlowTerminated
from
.base
import
ProxyDataFlow
,
DataFlowTerminated
,
DataFlowReentrantGuard
from
..utils.concurrency
import
(
ensure_proc_terminate
,
from
..utils.concurrency
import
(
ensure_proc_terminate
,
mask_sigint
,
start_proc_mask_signal
,
mask_sigint
,
start_proc_mask_signal
,
StoppableThread
)
StoppableThread
)
...
@@ -74,6 +74,8 @@ class PrefetchData(ProxyDataFlow):
...
@@ -74,6 +74,8 @@ class PrefetchData(ProxyDataFlow):
self
.
_size
=
-
1
self
.
_size
=
-
1
self
.
nr_proc
=
nr_proc
self
.
nr_proc
=
nr_proc
self
.
nr_prefetch
=
nr_prefetch
self
.
nr_prefetch
=
nr_prefetch
self
.
_guard
=
DataFlowReentrantGuard
()
self
.
queue
=
mp
.
Queue
(
self
.
nr_prefetch
)
self
.
queue
=
mp
.
Queue
(
self
.
nr_prefetch
)
self
.
procs
=
[
PrefetchProcess
(
self
.
ds
,
self
.
queue
)
self
.
procs
=
[
PrefetchProcess
(
self
.
ds
,
self
.
queue
)
for
_
in
range
(
self
.
nr_proc
)]
for
_
in
range
(
self
.
nr_proc
)]
...
@@ -81,6 +83,7 @@ class PrefetchData(ProxyDataFlow):
...
@@ -81,6 +83,7 @@ class PrefetchData(ProxyDataFlow):
start_proc_mask_signal
(
self
.
procs
)
start_proc_mask_signal
(
self
.
procs
)
def
get_data
(
self
):
def
get_data
(
self
):
with
self
.
_guard
:
for
k
in
itertools
.
count
():
for
k
in
itertools
.
count
():
if
self
.
_size
>
0
and
k
>=
self
.
_size
:
if
self
.
_size
>
0
and
k
>=
self
.
_size
:
break
break
...
@@ -155,9 +158,11 @@ class PrefetchDataZMQ(ProxyDataFlow):
...
@@ -155,9 +158,11 @@ class PrefetchDataZMQ(ProxyDataFlow):
self
.
nr_proc
=
nr_proc
self
.
nr_proc
=
nr_proc
self
.
_hwm
=
hwm
self
.
_hwm
=
hwm
self
.
_guard
=
DataFlowReentrantGuard
()
self
.
_setup_done
=
False
self
.
_setup_done
=
False
def
get_data
(
self
):
def
get_data
(
self
):
with
self
.
_guard
:
try
:
try
:
for
k
in
itertools
.
count
():
for
k
in
itertools
.
count
():
if
self
.
_size
>
0
and
k
>=
self
.
_size
:
if
self
.
_size
>
0
and
k
>=
self
.
_size
:
...
@@ -315,6 +320,7 @@ class ThreadedMapData(ProxyDataFlow):
...
@@ -315,6 +320,7 @@ class ThreadedMapData(ProxyDataFlow):
t
.
start
()
t
.
start
()
self
.
_iter
=
self
.
_iter_ds
.
get_data
()
self
.
_iter
=
self
.
_iter_ds
.
get_data
()
self
.
_guard
=
DataFlowReentrantGuard
()
# only call once, to ensure inq+outq has a total of buffer_size elements
# only call once, to ensure inq+outq has a total of buffer_size elements
self
.
_fill_buffer
()
self
.
_fill_buffer
()
...
@@ -332,6 +338,7 @@ class ThreadedMapData(ProxyDataFlow):
...
@@ -332,6 +338,7 @@ class ThreadedMapData(ProxyDataFlow):
raise
raise
def
get_data
(
self
):
def
get_data
(
self
):
with
self
.
_guard
:
for
dp
in
self
.
_iter
:
for
dp
in
self
.
_iter
:
self
.
_in_queue
.
put
(
dp
)
self
.
_in_queue
.
put
(
dp
)
yield
self
.
_out_queue
.
get
()
yield
self
.
_out_queue
.
get
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment