Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
8bdd9c85
Commit
8bdd9c85
authored
May 02, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
speed up prefetch
parent
c59586b2
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
41 additions
and
41 deletions
+41
-41
examples/cifar10-convnet.py
examples/cifar10-convnet.py
+1
-1
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+2
-1
tensorpack/dataflow/format.py
tensorpack/dataflow/format.py
+1
-1
tensorpack/dataflow/prefetch.py
tensorpack/dataflow/prefetch.py
+29
-30
tensorpack/utils/serialize.py
tensorpack/utils/serialize.py
+8
-8
No files found.
examples/cifar10-convnet.py
View file @
8bdd9c85
...
...
@@ -97,7 +97,7 @@ def get_data(train_or_test):
ds
=
AugmentImageComponent
(
ds
,
augmentors
)
ds
=
BatchData
(
ds
,
128
,
remainder
=
not
isTrain
)
if
isTrain
:
ds
=
PrefetchData
(
ds
,
10
,
5
)
ds
=
PrefetchData
ZMQ
(
ds
,
5
)
return
ds
def
get_config
():
...
...
tensorpack/dataflow/common.py
View file @
8bdd9c85
...
...
@@ -145,7 +145,8 @@ class FakeData(DataFlow):
def
get_data
(
self
):
for
_
in
range
(
self
.
_size
):
yield
[
self
.
rng
.
random_sample
(
k
)
for
k
in
self
.
shapes
]
yield
[
self
.
rng
.
random_sample
(
k
)
.
astype
(
'float32'
)
for
k
in
self
.
shapes
]
#yield [self.rng.random_sample(k) for k in self.shapes]
class
MapData
(
ProxyDataFlow
):
""" Apply map/filter a function on the datapoint"""
...
...
tensorpack/dataflow/format.py
View file @
8bdd9c85
...
...
@@ -11,7 +11,7 @@ from six.moves import range
try
:
import
h5py
except
ImportError
:
logger
.
error
(
"Error in 'import h5py'. HDF5Data won't be
imported
."
)
logger
.
error
(
"Error in 'import h5py'. HDF5Data won't be
available
."
)
__all__
=
[]
else
:
__all__
=
[
'HDF5Data'
]
...
...
tensorpack/dataflow/prefetch.py
View file @
8bdd9c85
...
...
@@ -7,7 +7,6 @@ from threading import Thread
from
six.moves
import
range
from
six.moves.queue
import
Queue
import
uuid
import
zmq
import
os
from
.base
import
ProxyDataFlow
...
...
@@ -15,7 +14,14 @@ from ..utils.concurrency import ensure_proc_terminate
from
..utils.serialize
import
*
from
..utils
import
logger
__all__
=
[
'PrefetchData'
,
'PrefetchDataZMQ'
]
try
:
import
zmq
except
ImportError
:
logger
.
error
(
"Error in 'import zmq'. PrefetchDataZMQ won't be available."
)
__all__
=
[
'PrefetchData'
]
else
:
__all__
=
[
'PrefetchData'
,
'PrefetchDataZMQ'
]
class
PrefetchProcess
(
multiprocessing
.
Process
):
def
__init__
(
self
,
ds
,
queue
):
...
...
@@ -69,64 +75,57 @@ class PrefetchData(ProxyDataFlow):
logger
.
info
(
"Prefetch process exited."
)
class
PrefetchProcessZMQ
(
multiprocessing
.
Process
):
def
__init__
(
self
,
ds
,
conn_name
,
qsize
=
1
):
def
__init__
(
self
,
ds
,
conn_name
):
"""
:param ds: a `DataFlow` instance.
:param conn_name: the name of the IPC connection
"""
super
(
PrefetchProcessZMQ
,
self
)
.
__init__
()
self
.
ds
=
ds
self
.
qsize
=
qsize
self
.
conn_name
=
conn_name
def
run
(
self
):
self
.
ds
.
reset_state
()
self
.
context
=
zmq
.
Context
()
self
.
socket
=
self
.
context
.
socket
(
zmq
.
PUSH
)
self
.
socket
.
set_hwm
(
self
.
qsize
)
self
.
socket
.
set_hwm
(
1
)
self
.
socket
.
connect
(
self
.
conn_name
)
self
.
id
=
os
.
getpid
()
cnt
=
0
while
True
:
for
dp
in
self
.
ds
.
get_data
():
self
.
socket
.
send
(
dumps
(
dp
))
cnt
+=
1
print
(
"Proc {} send {}"
.
format
(
self
.
id
,
cnt
))
self
.
socket
.
send
(
dumps
(
dp
),
copy
=
False
)
class
PrefetchDataZMQ
(
ProxyDataFlow
):
""" Work the same as `PrefetchData`, but faster. """
def
__init__
(
self
,
ds
,
nr_prefetch
,
nr_proc
=
1
):
def
__init__
(
self
,
ds
,
nr_proc
=
1
):
"""
:param ds: a `DataFlow` instance.
:param nr_proc: number of processes to use. When larger than 1, order
of datapoints will be random.
"""
super
(
PrefetchDataZMQ
,
self
)
.
__init__
(
ds
)
self
.
ds
=
ds
self
.
_size
=
ds
.
size
()
self
.
nr_proc
=
nr_proc
self
.
context
=
zmq
.
Context
()
self
.
socket
=
self
.
context
.
socket
(
zmq
.
PULL
)
name
=
"ipc://whatever-"
+
str
(
uuid
.
uuid1
())[:
6
]
self
.
socket
.
bind
(
name
)
self
.
pipename
=
"ipc://dataflow-pipe-"
+
str
(
uuid
.
uuid1
())[:
6
]
self
.
socket
.
set_hwm
(
5
)
# a little bit faster than default, don't know why
self
.
socket
.
bind
(
self
.
pipename
)
# TODO local queue again? probably don't affect training
self
.
queue
=
Queue
(
maxsize
=
nr_prefetch
)
def
enque
():
while
True
:
self
.
queue
.
put
(
loads
(
self
.
socket
.
recv
(
copy
=
False
)))
self
.
th
=
Thread
(
target
=
enque
)
self
.
th
.
daemon
=
True
self
.
th
.
start
()
self
.
procs
=
[
PrefetchProcessZMQ
(
self
.
ds
,
name
)
self
.
procs
=
[
PrefetchProcessZMQ
(
self
.
ds
,
self
.
pipename
)
for
_
in
range
(
self
.
nr_proc
)]
for
x
in
self
.
procs
:
x
.
start
()
def
get_data
(
self
):
for
_
in
range
(
self
.
_size
):
dp
=
self
.
queue
.
get
(
)
dp
=
loads
(
self
.
socket
.
recv
(
copy
=
False
)
)
yield
dp
#print(self.queue.qsize())
def
__del__
(
self
):
logger
.
info
(
"Prefetch process exiting..."
)
self
.
queue
.
close
(
)
self
.
context
.
destroy
(
0
)
for
x
in
self
.
procs
:
x
.
terminate
()
self
.
th
.
terminate
()
logger
.
info
(
"Prefetch process exited."
)
tensorpack/utils/serialize.py
View file @
8bdd9c85
...
...
@@ -3,17 +3,17 @@
# File: serialize.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
msgpack
import
msgpack_numpy
msgpack_numpy
.
patch
()
#
import dill
#
import msgpack
#
import msgpack_numpy
#
msgpack_numpy.patch()
import
dill
__all__
=
[
'loads'
,
'dumps'
]
def
dumps
(
obj
):
#
return dill.dumps(obj)
return
msgpack
.
dumps
(
obj
,
use_bin_type
=
True
)
return
dill
.
dumps
(
obj
)
#
return msgpack.dumps(obj, use_bin_type=True)
def
loads
(
buf
):
#
return dill.loads(buf)
return
msgpack
.
loads
(
buf
)
return
dill
.
loads
(
buf
)
#
return msgpack.loads(buf)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment