Commit f3d290cc authored by Yuxin Wu

dump to TFRecord (#174)

parent ee7dcd0d
......@@ -151,10 +151,10 @@ class Cifar100(CifarBase):
if __name__ == '__main__':
ds = Cifar10('train')
from tensorpack.dataflow.dftools import dump_dataset_images
from tensorpack.dataflow.dftools import dump_dataflow_images
mean = ds.get_per_channel_mean()
print(mean)
dump_dataset_images(ds, '/tmp/cifar', 100)
dump_dataflow_images(ds, '/tmp/cifar', 100)
# for (img, label) in ds.get_data():
# from IPython import embed; embed()
......
......@@ -14,15 +14,15 @@ from ..utils.concurrency import DIE
from ..utils.serialize import dumps
from ..utils.fs import mkdir_p
__all__ = ['dump_dataset_images', 'dataflow_to_process_queue',
'dump_dataflow_to_lmdb']
__all__ = ['dump_dataflow_images', 'dump_dataflow_to_process_queue',
'dump_dataflow_to_lmdb', 'dump_dataflow_to_tfrecord']
def dump_dataset_images(ds, dirname, max_count=None, index=0):
def dump_dataflow_images(df, dirname, max_count=None, index=0):
""" Dump images from a DataFlow to a directory.
Args:
ds (DataFlow): the DataFlow to dump.
df (DataFlow): the DataFlow to dump.
dirname (str): name of the directory.
max_count (int): limit max number of images to dump. Defaults to unlimited.
index (int): the index of the image component in the data point.
......@@ -31,8 +31,8 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
mkdir_p(dirname)
if max_count is None:
max_count = sys.maxsize
ds.reset_state()
for i, dp in enumerate(ds.get_data()):
df.reset_state()
for i, dp in enumerate(df.get_data()):
if i % 100 == 0:
print(i)
if i > max_count:
......@@ -41,7 +41,46 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
cv2.imwrite(os.path.join(dirname, "{}.jpg".format(i)), img)
def dump_dataflow_to_lmdb(ds, lmdb_path, write_frequency=5000):
def dump_dataflow_to_process_queue(df, size, nr_consumer):
"""
Convert a DataFlow to a :class:`multiprocessing.Queue`.
The DataFlow will only be reset in the spawned process.
Args:
df (DataFlow): the DataFlow to dump.
size (int): size of the queue
nr_consumer (int): number of consumers of the queue.
The producer will add this many ``DIE`` sentinels to the end of the queue.
Returns:
tuple(queue, process):
The process will take data from ``df`` and fill the queue once you start it.
Each element in the queue is ``(idx, dp)``, where ``idx`` can be the ``DIE``
sentinel when ``df`` is exhausted.
"""
q = mp.Queue(size)
class EnqueProc(mp.Process):
def __init__(self, df, q, nr_consumer):
super(EnqueProc, self).__init__()
self.df = df
self.q = q
def run(self):
self.df.reset_state()
try:
for idx, dp in enumerate(self.df.get_data()):
self.q.put((idx, dp))
finally:
for _ in range(nr_consumer):
self.q.put((DIE, None))
proc = EnqueProc(df, q, nr_consumer)
return q, proc
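# Usage sketch (illustrative, not part of this commit): one way a parent
# process could drain the queue returned by dump_dataflow_to_process_queue.
# The helper name `_example_consume_queue` is hypothetical; `df` is any DataFlow.
def _example_consume_queue(df):
    q, proc = dump_dataflow_to_process_queue(df, size=100, nr_consumer=1)
    proc.start()
    while True:
        idx, dp = q.get()
        if idx == DIE:   # the producer puts one DIE sentinel per consumer once df is exhausted
            break
        # ... process the datapoint `dp` here ...
    proc.join()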
def dump_dataflow_to_lmdb(df, lmdb_path, write_frequency=5000):
"""
Dump a DataFlow to an lmdb database, where the keys are indices and the values
are serialized datapoints.
......@@ -49,22 +88,22 @@ def dump_dataflow_to_lmdb(ds, lmdb_path, write_frequency=5000):
:class:`tensorpack.dataflow.LMDBDataPoint`.
Args:
ds (DataFlow): the DataFlow to dump.
df (DataFlow): the DataFlow to dump.
lmdb_path (str): output path. Either a directory or an mdb file.
write_frequency (int): how often to write back data to disk; a transaction is committed every ``write_frequency`` datapoints.
"""
assert isinstance(ds, DataFlow), type(ds)
assert isinstance(df, DataFlow), type(df)
isdir = os.path.isdir(lmdb_path)
if isdir:
assert not os.path.isfile(os.path.join(lmdb_path, 'data.mdb')), "LMDB file exists!"
else:
assert not os.path.isfile(lmdb_path), "LMDB file exists!"
ds.reset_state()
df.reset_state()
db = lmdb.open(lmdb_path, subdir=isdir,
map_size=1099511627776 * 2, readonly=False,
meminit=False, map_async=True) # need sync() at the end
try:
sz = ds.size()
sz = df.size()
except NotImplementedError:
sz = 0
with get_tqdm(total=sz) as pbar:
......@@ -73,7 +112,7 @@ def dump_dataflow_to_lmdb(ds, lmdb_path, write_frequency=5000):
# lmdb transaction is not exception-safe!
# although it has a contextmanager interface
txn = db.begin(write=True)
for idx, dp in enumerate(ds.get_data()):
for idx, dp in enumerate(df.get_data()):
txn.put(u'{}'.format(idx).encode('ascii'), dumps(dp))
pbar.update()
if (idx + 1) % write_frequency == 0:
......@@ -90,47 +129,30 @@ def dump_dataflow_to_lmdb(ds, lmdb_path, write_frequency=5000):
db.close()
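# Usage sketch (illustrative, not part of this commit): round-trip a DataFlow
# through LMDB and read it back with tensorpack.dataflow.LMDBDataPoint, as the
# docstring above suggests. The output path, helper name, and keyword
# arguments here are assumptions for illustration.
def _example_lmdb_roundtrip(df):
    from tensorpack.dataflow import LMDBDataPoint
    dump_dataflow_to_lmdb(df, '/tmp/example.mdb')
    out = LMDBDataPoint('/tmp/example.mdb', shuffle=False)
    out.reset_state()
    for dp in out.get_data():
        pass   # each dp is one deserialized datapoint originally produced by `df`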
from ..utils.develop import create_dummy_func # noqa
try:
import lmdb
except ImportError:
from ..utils.develop import create_dummy_func
dump_dataflow_to_lmdb = create_dummy_func('dump_dataflow_to_lmdb', 'lmdb') # noqa
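# Hedged illustration (not the actual tensorpack implementation):
# create_dummy_func above is assumed to behave roughly like this sketch, i.e.
# when an optional dependency is missing the exported name becomes a stub that
# only fails when it is actually called.
def _create_dummy_func_sketch(func, dependency):
    def _dummy(*args, **kwargs):
        raise ImportError(
            "Cannot use '{}' because '{}' is not installed.".format(func, dependency))
    return _dummy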
def dataflow_to_process_queue(ds, size, nr_consumer):
def dump_dataflow_to_tfrecord(df, path):
"""
Convert a DataFlow to a :class:`multiprocessing.Queue`.
The DataFlow will only be reset in the spawned process.
Dump all datapoints of a Dataflow to a TensorFlow TFRecord file,
using :func:`serialize.dumps` to serialize.
Args:
ds (DataFlow): the DataFlow to dump.
size (int): size of the queue
nr_consumer (int): number of consumer of the queue.
The producer will add this many of ``DIE`` sentinel to the end of the queue.
Returns:
tuple(queue, process):
The process will take data from ``ds`` and fill
the queue, once you start it. Each element in the queue is (idx,
dp). idx can be the ``DIE`` sentinel when ``ds`` is exhausted.
df (DataFlow):
path (str): the output file path
"""
q = mp.Queue(size)
df.reset_state()
with tf.python_io.TFRecordWriter(path) as writer:
for dp in df.get_data():
writer.write(dumps(dp))
class EnqueProc(mp.Process):
def __init__(self, ds, q, nr_consumer):
super(EnqueProc, self).__init__()
self.ds = ds
self.q = q
def run(self):
self.ds.reset_state()
try:
for idx, dp in enumerate(self.ds.get_data()):
self.q.put((idx, dp))
finally:
for _ in range(nr_consumer):
self.q.put((DIE, None))
proc = EnqueProc(ds, q, nr_consumer)
return q, proc
try:
import tensorflow as tf
except ImportError:
dump_dataflow_to_tfrecord = create_dummy_func( # noqa
'dump_dataflow_to_tfrecord', 'tensorflow')
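# Usage sketch (illustrative, not part of this commit): each record written by
# dump_dataflow_to_tfrecord is a raw dumps(dp) blob, so it can be read back
# with the matching loads(). The path and helper name are placeholders.
def _example_read_tfrecord(path='/tmp/example.tfrecord'):
    from tensorpack.utils.serialize import loads
    for record in tf.python_io.tf_record_iterator(path):
        dp = loads(record)   # recovers the original datapoint (a list of components)
        # ... use dp ...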
......@@ -10,7 +10,7 @@ import os
import six
from ..dataflow import DataFlow
from ..dataflow.dftools import dataflow_to_process_queue
from ..dataflow.dftools import dump_dataflow_to_process_queue
from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
from ..utils import logger, get_tqdm
from ..utils.gpu import change_gpu
......@@ -105,7 +105,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
self.nr_proc = nr_proc
self.ordered = ordered
self.inqueue, self.inqueue_proc = dataflow_to_process_queue(
self.inqueue, self.inqueue_proc = dump_dataflow_to_process_queue(
self.dataset, nr_proc * 2, self.nr_proc) # put (idx, dp) to inqueue
if use_gpu:
......