Shashank Suhas / seminar-breakout / Commits

Commit f3d290cc, authored Mar 31, 2017 by Yuxin Wu
dump to TFRecord (#174)

parent ee7dcd0d
Showing 3 changed files with 70 additions and 48 deletions (+70 / -48)

tensorpack/dataflow/dataset/cifar.py    +2  / -2
tensorpack/dataflow/dftools.py          +66 / -44
tensorpack/predict/dataset.py           +2  / -2
tensorpack/dataflow/dataset/cifar.py (view file @ f3d290cc)

@@ -151,10 +151,10 @@ class Cifar100(CifarBase):
 if __name__ == '__main__':
     ds = Cifar10('train')
-    from tensorpack.dataflow.dftools import dump_dataset_images
+    from tensorpack.dataflow.dftools import dump_dataflow_images
     mean = ds.get_per_channel_mean()
     print(mean)
-    dump_dataset_images(ds, '/tmp/cifar', 100)
+    dump_dataflow_images(ds, '/tmp/cifar', 100)
     # for (img, label) in ds.get_data():
     #     from IPython import embed; embed()
tensorpack/dataflow/dftools.py (view file @ f3d290cc)

@@ -14,15 +14,15 @@ from ..utils.concurrency import DIE
 from ..utils.serialize import dumps
 from ..utils.fs import mkdir_p

-__all__ = ['dump_dataset_images', 'dataflow_to_process_queue',
-           'dump_dataflow_to_lmdb']
+__all__ = ['dump_dataflow_images', 'dump_dataflow_to_process_queue',
+           'dump_dataflow_to_lmdb', 'dump_dataflow_to_tfrecord']


-def dump_dataset_images(ds, dirname, max_count=None, index=0):
+def dump_dataflow_images(df, dirname, max_count=None, index=0):
     """ Dump images from a DataFlow to a directory.

     Args:
-        ds (DataFlow): the DataFlow to dump.
+        df (DataFlow): the DataFlow to dump.
         dirname (str): name of the directory.
         max_count (int): limit max number of images to dump. Defaults to unlimited.
         index (int): the index of the image component in the data point.

@@ -31,8 +31,8 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
     mkdir_p(dirname)
     if max_count is None:
         max_count = sys.maxint
-    ds.reset_state()
-    for i, dp in enumerate(ds.get_data()):
+    df.reset_state()
+    for i, dp in enumerate(df.get_data()):
         if i % 100 == 0:
             print(i)
         if i > max_count:

@@ -41,7 +41,46 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
         cv2.imwrite(os.path.join(dirname, "{}.jpg".format(i)), img)


-def dump_dataflow_to_lmdb(ds, lmdb_path, write_frequency=5000):
+def dump_dataflow_to_process_queue(df, size, nr_consumer):
+    """
+    Convert a DataFlow to a :class:`multiprocessing.Queue`.
+    The DataFlow will only be reset in the spawned process.
+
+    Args:
+        df (DataFlow): the DataFlow to dump.
+        size (int): size of the queue
+        nr_consumer (int): number of consumer of the queue.
+            The producer will add this many of ``DIE`` sentinel to the end of the queue.
+
+    Returns:
+        tuple(queue, process):
+            The process will take data from ``df`` and fill
+            the queue, once you start it. Each element in the queue is (idx,
+            dp). idx can be the ``DIE`` sentinel when ``df`` is exhausted.
+    """
+    q = mp.Queue(size)
+
+    class EnqueProc(mp.Process):
+
+        def __init__(self, df, q, nr_consumer):
+            super(EnqueProc, self).__init__()
+            self.df = df
+            self.q = q
+
+        def run(self):
+            self.df.reset_state()
+            try:
+                for idx, dp in enumerate(self.df.get_data()):
+                    self.q.put((idx, dp))
+            finally:
+                for _ in range(nr_consumer):
+                    self.q.put((DIE, None))
+
+    proc = EnqueProc(df, q, nr_consumer)
+    return q, proc
+
+
+def dump_dataflow_to_lmdb(df, lmdb_path, write_frequency=5000):
     """
     Dump a Dataflow to a lmdb database, where the keys are indices and values
     are serialized datapoints.

@@ -49,22 +88,22 @@ def dump_dataflow_to_lmdb(ds, lmdb_path, write_frequency=5000):
         :class:`tensorpack.dataflow.LMDBDataPoint`.

     Args:
-        ds (DataFlow): the DataFlow to dump.
+        df (DataFlow): the DataFlow to dump.
         lmdb_path (str): output path. Either a directory or a mdb file.
         write_frequency (int): the frequency to write back data to disk.
     """
-    assert isinstance(ds, DataFlow), type(ds)
+    assert isinstance(df, DataFlow), type(df)
     isdir = os.path.isdir(lmdb_path)
     if isdir:
         assert not os.path.isfile(os.path.join(lmdb_path, 'data.mdb')), "LMDB file exists!"
     else:
         assert not os.path.isfile(lmdb_path), "LMDB file exists!"
-    ds.reset_state()
+    df.reset_state()
     db = lmdb.open(lmdb_path, subdir=isdir,
                    map_size=1099511627776 * 2, readonly=False,
                    meminit=False, map_async=True)    # need sync() at the end
     try:
-        sz = ds.size()
+        sz = df.size()
     except NotImplementedError:
         sz = 0
     with get_tqdm(total=sz) as pbar:

@@ -73,7 +112,7 @@ def dump_dataflow_to_lmdb(ds, lmdb_path, write_frequency=5000):
         # lmdb transaction is not exception-safe!
         # although it has a contextmanager interface
         txn = db.begin(write=True)
-        for idx, dp in enumerate(ds.get_data()):
+        for idx, dp in enumerate(df.get_data()):
             txn.put(u'{}'.format(idx).encode('ascii'), dumps(dp))
             pbar.update()
             if (idx + 1) % write_frequency == 0:

@@ -90,47 +129,30 @@ def dump_dataflow_to_lmdb(ds, lmdb_path, write_frequency=5000):
     db.close()


+from ..utils.develop import create_dummy_func  # noqa
 try:
     import lmdb
 except ImportError:
-    from ..utils.develop import create_dummy_func
     dump_dataflow_to_lmdb = create_dummy_func('dump_dataflow_to_lmdb', 'lmdb')  # noqa


-def dataflow_to_process_queue(ds, size, nr_consumer):
+def dump_dataflow_to_tfrecord(df, path):
     """
-    Convert a DataFlow to a :class:`multiprocessing.Queue`.
-    The DataFlow will only be reset in the spawned process.
+    Dump all datapoints of a Dataflow to a TensorFlow TFRecord file,
+    using :func:`serialize.dumps` to serialize.

     Args:
-        ds (DataFlow): the DataFlow to dump.
-        size (int): size of the queue
-        nr_consumer (int): number of consumer of the queue.
-            The producer will add this many of ``DIE`` sentinel to the end of the queue.
-
-    Returns:
-        tuple(queue, process):
-            The process will take data from ``ds`` and fill
-            the queue, once you start it. Each element in the queue is (idx,
-            dp). idx can be the ``DIE`` sentinel when ``ds`` is exhausted.
+        df (DataFlow):
+        path (str): the output file path
     """
-    q = mp.Queue(size)
-
-    class EnqueProc(mp.Process):
-
-        def __init__(self, ds, q, nr_consumer):
-            super(EnqueProc, self).__init__()
-            self.ds = ds
-            self.q = q
-
-        def run(self):
-            self.ds.reset_state()
-            try:
-                for idx, dp in enumerate(self.ds.get_data()):
-                    self.q.put((idx, dp))
-            finally:
-                for _ in range(nr_consumer):
-                    self.q.put((DIE, None))
-
-    proc = EnqueProc(ds, q, nr_consumer)
-    return q, proc
+    df.reset_state()
+    with tf.python_io.TFRecordWriter(path) as writer:
+        for dp in df.get_data():
+            writer.write(dumps(dp))
+
+
+try:
+    import tensorflow as tf
+except ImportError:
+    dump_dataflow_to_tfrecord = create_dummy_func(  # noqa
+        'dump_dataflow_to_tfrecord', 'tensorflow')
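
For readers of this commit, a minimal usage sketch of the new dump_dataflow_to_tfrecord helper. This is not part of the diff; the Cifar10 dataflow and the /tmp output path are only illustrative.

    from tensorpack.dataflow.dataset import Cifar10
    from tensorpack.dataflow.dftools import dump_dataflow_to_tfrecord

    # Any DataFlow works here; Cifar10 is the same example used in cifar.py above.
    ds = Cifar10('train')

    # Each datapoint is serialized with tensorpack's serialize.dumps and written
    # as one TFRecord record. TensorFlow must be importable, otherwise the dummy
    # function created by create_dummy_func raises when called.
    dump_dataflow_to_tfrecord(ds, '/tmp/cifar-train.tfrecord')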
tensorpack/predict/dataset.py (view file @ f3d290cc)

@@ -10,7 +10,7 @@ import os
 import six

 from ..dataflow import DataFlow
-from ..dataflow.dftools import dataflow_to_process_queue
+from ..dataflow.dftools import dump_dataflow_to_process_queue
 from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
 from ..utils import logger, get_tqdm
 from ..utils.gpu import change_gpu

@@ -105,7 +105,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
         self.nr_proc = nr_proc
         self.ordered = ordered

-        self.inqueue, self.inqueue_proc = dataflow_to_process_queue(
+        self.inqueue, self.inqueue_proc = dump_dataflow_to_process_queue(
             self.dataset, nr_proc * 2, self.nr_proc)    # put (idx, dp) to inqueue

         if use_gpu:
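
To clarify the contract of the renamed helper used here (the (idx, dp) pairs and the DIE sentinel described in its docstring), a standalone consumer sketch, assuming a single consumer; this is not how MultiProcessDatasetPredictor itself consumes the queue, just a minimal illustration.

    from tensorpack.dataflow.dataset import Cifar10
    from tensorpack.dataflow.dftools import dump_dataflow_to_process_queue
    from tensorpack.utils.concurrency import DIE

    ds = Cifar10('train')                      # any DataFlow
    # The producer process fills the queue with (idx, dp); one DIE sentinel is
    # enqueued per consumer once the DataFlow is exhausted.
    q, proc = dump_dataflow_to_process_queue(ds, size=8, nr_consumer=1)
    proc.start()                               # the DataFlow is reset inside this process

    while True:
        idx, dp = q.get()
        if idx == DIE:                         # end-of-data sentinel
            break
        print(idx, dp[0].shape)                # a Cifar10 datapoint is [image, label]
    proc.join()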