Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
f6313a07
Commit
f6313a07
authored
May 15, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Move dataflow reentrant guard to reset_state() since it's not pickleable
parent
e25cbf1a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
8 additions
and
8 deletions
+8
-8
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+4
-4
tensorpack/dataflow/format.py
tensorpack/dataflow/format.py
+1
-1
tensorpack/dataflow/parallel.py
tensorpack/dataflow/parallel.py
+1
-1
tensorpack/dataflow/parallel_map.py
tensorpack/dataflow/parallel_map.py
+2
-2
No files found.
tensorpack/dataflow/common.py
View file @
f6313a07
...
@@ -193,11 +193,11 @@ class BatchDataByShape(BatchData):
...
@@ -193,11 +193,11 @@ class BatchDataByShape(BatchData):
"""
"""
super
(
BatchDataByShape
,
self
)
.
__init__
(
ds
,
batch_size
,
remainder
=
False
)
super
(
BatchDataByShape
,
self
)
.
__init__
(
ds
,
batch_size
,
remainder
=
False
)
self
.
idx
=
idx
self
.
idx
=
idx
self
.
_guard
=
DataFlowReentrantGuard
()
def
reset_state
(
self
):
def
reset_state
(
self
):
super
(
BatchDataByShape
,
self
)
.
reset_state
()
super
(
BatchDataByShape
,
self
)
.
reset_state
()
self
.
holder
=
defaultdict
(
list
)
self
.
holder
=
defaultdict
(
list
)
self
.
_guard
=
DataFlowReentrantGuard
()
def
__iter__
(
self
):
def
__iter__
(
self
):
with
self
.
_guard
:
with
self
.
_guard
:
...
@@ -235,7 +235,6 @@ class FixedSizeData(ProxyDataFlow):
...
@@ -235,7 +235,6 @@ class FixedSizeData(ProxyDataFlow):
super
(
FixedSizeData
,
self
)
.
__init__
(
ds
)
super
(
FixedSizeData
,
self
)
.
__init__
(
ds
)
self
.
_size
=
int
(
size
)
self
.
_size
=
int
(
size
)
self
.
itr
=
None
self
.
itr
=
None
self
.
_guard
=
DataFlowReentrantGuard
()
self
.
_keep
=
keep_state
self
.
_keep
=
keep_state
def
__len__
(
self
):
def
__len__
(
self
):
...
@@ -244,6 +243,7 @@ class FixedSizeData(ProxyDataFlow):
...
@@ -244,6 +243,7 @@ class FixedSizeData(ProxyDataFlow):
def
reset_state
(
self
):
def
reset_state
(
self
):
super
(
FixedSizeData
,
self
)
.
reset_state
()
super
(
FixedSizeData
,
self
)
.
reset_state
()
self
.
itr
=
self
.
ds
.
__iter__
()
self
.
itr
=
self
.
ds
.
__iter__
()
self
.
_guard
=
DataFlowReentrantGuard
()
def
__iter__
(
self
):
def
__iter__
(
self
):
with
self
.
_guard
:
with
self
.
_guard
:
...
@@ -625,9 +625,9 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
...
@@ -625,9 +625,9 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
self
.
shuffle_interval
=
shuffle_interval
self
.
shuffle_interval
=
shuffle_interval
self
.
nr_reuse
=
nr_reuse
self
.
nr_reuse
=
nr_reuse
self
.
_inf_ds
=
RepeatedData
(
ds
,
-
1
)
self
.
_inf_ds
=
RepeatedData
(
ds
,
-
1
)
self
.
_guard
=
DataFlowReentrantGuard
()
def
reset_state
(
self
):
def
reset_state
(
self
):
self
.
_guard
=
DataFlowReentrantGuard
()
ProxyDataFlow
.
reset_state
(
self
)
ProxyDataFlow
.
reset_state
(
self
)
RNGDataFlow
.
reset_state
(
self
)
RNGDataFlow
.
reset_state
(
self
)
self
.
_iter_cnt
=
0
self
.
_iter_cnt
=
0
...
@@ -664,11 +664,11 @@ class CacheData(ProxyDataFlow):
...
@@ -664,11 +664,11 @@ class CacheData(ProxyDataFlow):
shuffle (bool): whether to shuffle the datapoints before producing them.
shuffle (bool): whether to shuffle the datapoints before producing them.
"""
"""
self
.
shuffle
=
shuffle
self
.
shuffle
=
shuffle
self
.
_guard
=
DataFlowReentrantGuard
()
super
(
CacheData
,
self
)
.
__init__
(
ds
)
super
(
CacheData
,
self
)
.
__init__
(
ds
)
def
reset_state
(
self
):
def
reset_state
(
self
):
super
(
CacheData
,
self
)
.
reset_state
()
super
(
CacheData
,
self
)
.
reset_state
()
self
.
_guard
=
DataFlowReentrantGuard
()
if
self
.
shuffle
:
if
self
.
shuffle
:
self
.
rng
=
get_rng
(
self
)
self
.
rng
=
get_rng
(
self
)
self
.
buffer
=
[]
self
.
buffer
=
[]
...
...
tensorpack/dataflow/format.py
View file @
f6313a07
...
@@ -90,7 +90,6 @@ class LMDBData(RNGDataFlow):
...
@@ -90,7 +90,6 @@ class LMDBData(RNGDataFlow):
self
.
_size
=
self
.
_txn
.
stat
()[
'entries'
]
self
.
_size
=
self
.
_txn
.
stat
()[
'entries'
]
self
.
_set_keys
(
keys
)
self
.
_set_keys
(
keys
)
logger
.
info
(
"Found {} entries in {}"
.
format
(
self
.
_size
,
self
.
_lmdb_path
))
logger
.
info
(
"Found {} entries in {}"
.
format
(
self
.
_size
,
self
.
_lmdb_path
))
self
.
_guard
=
DataFlowReentrantGuard
()
def
_set_keys
(
self
,
keys
=
None
):
def
_set_keys
(
self
,
keys
=
None
):
def
find_keys
(
txn
,
size
):
def
find_keys
(
txn
,
size
):
...
@@ -128,6 +127,7 @@ class LMDBData(RNGDataFlow):
...
@@ -128,6 +127,7 @@ class LMDBData(RNGDataFlow):
self
.
_txn
=
self
.
_lmdb
.
begin
()
self
.
_txn
=
self
.
_lmdb
.
begin
()
def
reset_state
(
self
):
def
reset_state
(
self
):
self
.
_guard
=
DataFlowReentrantGuard
()
self
.
_lmdb
.
close
()
self
.
_lmdb
.
close
()
super
(
LMDBData
,
self
)
.
reset_state
()
super
(
LMDBData
,
self
)
.
reset_state
()
self
.
_open_lmdb
()
self
.
_open_lmdb
()
...
...
tensorpack/dataflow/parallel.py
View file @
f6313a07
...
@@ -306,7 +306,6 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
...
@@ -306,7 +306,6 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
self
.
nr_proc
=
nr_proc
self
.
nr_proc
=
nr_proc
self
.
_hwm
=
hwm
self
.
_hwm
=
hwm
self
.
_guard
=
DataFlowReentrantGuard
()
if
nr_proc
>
1
:
if
nr_proc
>
1
:
logger
.
info
(
"[PrefetchDataZMQ] Will fork a dataflow more than one times. "
logger
.
info
(
"[PrefetchDataZMQ] Will fork a dataflow more than one times. "
"This assumes the datapoints are i.i.d."
)
"This assumes the datapoints are i.i.d."
)
...
@@ -330,6 +329,7 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
...
@@ -330,6 +329,7 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
def
reset_state
(
self
):
def
reset_state
(
self
):
super
(
PrefetchDataZMQ
,
self
)
.
reset_state
()
super
(
PrefetchDataZMQ
,
self
)
.
reset_state
()
self
.
_guard
=
DataFlowReentrantGuard
()
self
.
context
=
zmq
.
Context
()
self
.
context
=
zmq
.
Context
()
self
.
socket
=
self
.
context
.
socket
(
zmq
.
PULL
)
self
.
socket
=
self
.
context
.
socket
(
zmq
.
PULL
)
self
.
socket
.
set_hwm
(
self
.
_hwm
)
self
.
socket
.
set_hwm
(
self
.
_hwm
)
...
...
tensorpack/dataflow/parallel_map.py
View file @
f6313a07
...
@@ -279,11 +279,11 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
...
@@ -279,11 +279,11 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
self
.
map_func
=
map_func
self
.
map_func
=
map_func
self
.
_strict
=
strict
self
.
_strict
=
strict
self
.
_procs
=
[]
self
.
_procs
=
[]
self
.
_guard
=
DataFlowReentrantGuard
()
def
reset_state
(
self
):
def
reset_state
(
self
):
_MultiProcessZMQDataFlow
.
reset_state
(
self
)
_MultiProcessZMQDataFlow
.
reset_state
(
self
)
_ParallelMapData
.
reset_state
(
self
)
_ParallelMapData
.
reset_state
(
self
)
self
.
_guard
=
DataFlowReentrantGuard
()
self
.
context
=
zmq
.
Context
()
self
.
context
=
zmq
.
Context
()
self
.
socket
=
self
.
context
.
socket
(
zmq
.
DEALER
)
self
.
socket
=
self
.
context
.
socket
(
zmq
.
DEALER
)
...
@@ -369,7 +369,6 @@ class MultiProcessMapDataComponentSharedArray(DataFlow):
...
@@ -369,7 +369,6 @@ class MultiProcessMapDataComponentSharedArray(DataFlow):
processes
=
nr_proc
,
processes
=
nr_proc
,
initializer
=
_init_pool
,
initializer
=
_init_pool
,
initargs
=
(
self
.
_shared_mem
,
id_queue
,
map_func
))
initargs
=
(
self
.
_shared_mem
,
id_queue
,
map_func
))
self
.
_guard
=
DataFlowReentrantGuard
()
def
_create_shared_arr
(
self
):
def
_create_shared_arr
(
self
):
TYPE
=
{
TYPE
=
{
...
@@ -388,6 +387,7 @@ class MultiProcessMapDataComponentSharedArray(DataFlow):
...
@@ -388,6 +387,7 @@ class MultiProcessMapDataComponentSharedArray(DataFlow):
def
reset_state
(
self
):
def
reset_state
(
self
):
self
.
ds
.
reset_state
()
self
.
ds
.
reset_state
()
self
.
_guard
=
DataFlowReentrantGuard
()
def
__iter__
(
self
):
def
__iter__
(
self
):
ds_itr
=
_repeat_iter
(
self
.
ds
.
get_data
)
ds_itr
=
_repeat_iter
(
self
.
ds
.
get_data
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment