Commit f3d290cc authored by Yuxin Wu

dump to TFRecord (#174)

parent ee7dcd0d
@@ -151,10 +151,10 @@ class Cifar100(CifarBase):
 if __name__ == '__main__':
     ds = Cifar10('train')
-    from tensorpack.dataflow.dftools import dump_dataset_images
+    from tensorpack.dataflow.dftools import dump_dataflow_images
     mean = ds.get_per_channel_mean()
     print(mean)
-    dump_dataset_images(ds, '/tmp/cifar', 100)
+    dump_dataflow_images(ds, '/tmp/cifar', 100)
     # for (img, label) in ds.get_data():
     #     from IPython import embed; embed()
...
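As a usage reference (not part of the diff): the renamed helper can be called on any image-producing DataFlow. A minimal sketch, assuming only the post-rename signature shown above; the split, output directory and limit are arbitrary.

# Hypothetical usage of the renamed dump_dataflow_images (sketch, not in this commit).
from tensorpack.dataflow.dataset import Cifar10
from tensorpack.dataflow.dftools import dump_dataflow_images

ds = Cifar10('test')
# Write at most 50 images to /tmp/cifar-test as {i}.jpg; index=0 selects the
# image component of each datapoint, matching the signature in the diff above.
dump_dataflow_images(ds, '/tmp/cifar-test', max_count=50, index=0)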
@@ -14,15 +14,15 @@ from ..utils.concurrency import DIE
 from ..utils.serialize import dumps
 from ..utils.fs import mkdir_p

-__all__ = ['dump_dataset_images', 'dataflow_to_process_queue',
-           'dump_dataflow_to_lmdb']
+__all__ = ['dump_dataflow_images', 'dump_dataflow_to_process_queue',
+           'dump_dataflow_to_lmdb', 'dump_dataflow_to_tfrecord']


-def dump_dataset_images(ds, dirname, max_count=None, index=0):
+def dump_dataflow_images(df, dirname, max_count=None, index=0):
     """ Dump images from a DataFlow to a directory.

     Args:
-        ds (DataFlow): the DataFlow to dump.
+        df (DataFlow): the DataFlow to dump.
         dirname (str): name of the directory.
         max_count (int): limit max number of images to dump. Defaults to unlimited.
         index (int): the index of the image component in the data point.
@@ -31,8 +31,8 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
     mkdir_p(dirname)
     if max_count is None:
         max_count = sys.maxint
-    ds.reset_state()
-    for i, dp in enumerate(ds.get_data()):
+    df.reset_state()
+    for i, dp in enumerate(df.get_data()):
         if i % 100 == 0:
             print(i)
         if i > max_count:
@@ -41,7 +41,46 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
         cv2.imwrite(os.path.join(dirname, "{}.jpg".format(i)), img)


-def dump_dataflow_to_lmdb(ds, lmdb_path, write_frequency=5000):
+def dump_dataflow_to_process_queue(df, size, nr_consumer):
+    """
+    Convert a DataFlow to a :class:`multiprocessing.Queue`.
+    The DataFlow will only be reset in the spawned process.
+
+    Args:
+        df (DataFlow): the DataFlow to dump.
+        size (int): size of the queue
+        nr_consumer (int): number of consumer of the queue.
+            The producer will add this many of ``DIE`` sentinel to the end of the queue.
+
+    Returns:
+        tuple(queue, process):
+            The process will take data from ``df`` and fill
+            the queue, once you start it. Each element in the queue is (idx,
+            dp). idx can be the ``DIE`` sentinel when ``df`` is exhausted.
+    """
+    q = mp.Queue(size)
+
+    class EnqueProc(mp.Process):
+        def __init__(self, df, q, nr_consumer):
+            super(EnqueProc, self).__init__()
+            self.df = df
+            self.q = q
+
+        def run(self):
+            self.df.reset_state()
+            try:
+                for idx, dp in enumerate(self.df.get_data()):
+                    self.q.put((idx, dp))
+            finally:
+                for _ in range(nr_consumer):
+                    self.q.put((DIE, None))
+
+    proc = EnqueProc(df, q, nr_consumer)
+    return q, proc
+
+
+def dump_dataflow_to_lmdb(df, lmdb_path, write_frequency=5000):
     """
     Dump a Dataflow to a lmdb database, where the keys are indices and values
     are serialized datapoints.
@@ -49,22 +88,22 @@ def dump_dataflow_to_lmdb(ds, lmdb_path, write_frequency=5000):
     :class:`tensorpack.dataflow.LMDBDataPoint`.

     Args:
-        ds (DataFlow): the DataFlow to dump.
+        df (DataFlow): the DataFlow to dump.
         lmdb_path (str): output path. Either a directory or a mdb file.
         write_frequency (int): the frequency to write back data to disk.
     """
-    assert isinstance(ds, DataFlow), type(ds)
+    assert isinstance(df, DataFlow), type(df)
     isdir = os.path.isdir(lmdb_path)
     if isdir:
         assert not os.path.isfile(os.path.join(lmdb_path, 'data.mdb')), "LMDB file exists!"
     else:
         assert not os.path.isfile(lmdb_path), "LMDB file exists!"
-    ds.reset_state()
+    df.reset_state()
     db = lmdb.open(lmdb_path, subdir=isdir,
                    map_size=1099511627776 * 2, readonly=False,
                    meminit=False, map_async=True)    # need sync() at the end
     try:
-        sz = ds.size()
+        sz = df.size()
     except NotImplementedError:
         sz = 0
     with get_tqdm(total=sz) as pbar:
@@ -73,7 +112,7 @@ def dump_dataflow_to_lmdb(ds, lmdb_path, write_frequency=5000):
         # lmdb transaction is not exception-safe!
         # although it has a contextmanager interface
         txn = db.begin(write=True)
-        for idx, dp in enumerate(ds.get_data()):
+        for idx, dp in enumerate(df.get_data()):
             txn.put(u'{}'.format(idx).encode('ascii'), dumps(dp))
             pbar.update()
             if (idx + 1) % write_frequency == 0:
@@ -90,47 +129,30 @@ def dump_dataflow_to_lmdb(ds, lmdb_path, write_frequency=5000):
     db.close()


+from ..utils.develop import create_dummy_func   # noqa
 try:
     import lmdb
 except ImportError:
-    from ..utils.develop import create_dummy_func
     dump_dataflow_to_lmdb = create_dummy_func('dump_dataflow_to_lmdb', 'lmdb')   # noqa


-def dataflow_to_process_queue(ds, size, nr_consumer):
+def dump_dataflow_to_tfrecord(df, path):
     """
-    Convert a DataFlow to a :class:`multiprocessing.Queue`.
-    The DataFlow will only be reset in the spawned process.
+    Dump all datapoints of a Dataflow to a TensorFlow TFRecord file,
+    using :func:`serialize.dumps` to serialize.

     Args:
-        ds (DataFlow): the DataFlow to dump.
-        size (int): size of the queue
-        nr_consumer (int): number of consumer of the queue.
-            The producer will add this many of ``DIE`` sentinel to the end of the queue.
-
-    Returns:
-        tuple(queue, process):
-            The process will take data from ``ds`` and fill
-            the queue, once you start it. Each element in the queue is (idx,
-            dp). idx can be the ``DIE`` sentinel when ``ds`` is exhausted.
+        df (DataFlow):
+        path (str): the output file path
     """
-    q = mp.Queue(size)
-
-    class EnqueProc(mp.Process):
-        def __init__(self, ds, q, nr_consumer):
-            super(EnqueProc, self).__init__()
-            self.ds = ds
-            self.q = q
-
-        def run(self):
-            self.ds.reset_state()
-            try:
-                for idx, dp in enumerate(self.ds.get_data()):
-                    self.q.put((idx, dp))
-            finally:
-                for _ in range(nr_consumer):
-                    self.q.put((DIE, None))
-
-    proc = EnqueProc(ds, q, nr_consumer)
-    return q, proc
+    df.reset_state()
+    with tf.python_io.TFRecordWriter(path) as writer:
+        for dp in df.get_data():
+            writer.write(dumps(dp))
+
+
+try:
+    import tensorflow as tf
+except ImportError:
+    dump_dataflow_to_tfrecord = create_dummy_func(   # noqa
+        'dump_dataflow_to_tfrecord', 'tensorflow')
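For reference, a hedged sketch (not part of the diff) of round-tripping a DataFlow through the two dump formats touched in dftools here. Assumptions: the TF 1.x tf.python_io API, tensorpack.utils.serialize.loads as the counterpart of the dumps used above, an LMDBDataPoint constructor that accepts a path (it is only named in the lmdb docstring), and FakeData as a stand-in DataFlow; check the signatures in your version.

# Sketch only; names flagged above as assumptions are not taken from this commit.
import tensorflow as tf
from tensorpack.dataflow import FakeData, LMDBDataPoint
from tensorpack.dataflow.dftools import dump_dataflow_to_lmdb, dump_dataflow_to_tfrecord
from tensorpack.utils.serialize import loads

df = FakeData([[32, 32, 3]], 100)    # placeholder DataFlow of 100 random "images"

# 1. lmdb: dump, then read back with the class the docstring points to.
dump_dataflow_to_lmdb(df, '/tmp/dataset.lmdb')
lmdb_ds = LMDBDataPoint('/tmp/dataset.lmdb', shuffle=False)
lmdb_ds.reset_state()
for dp in lmdb_ds.get_data():
    pass    # each dp is a deserialized datapoint

# 2. TFRecord: dump with the new function, read back record by record.
dump_dataflow_to_tfrecord(df, '/tmp/dataset.tfrecord')
for record in tf.python_io.tf_record_iterator('/tmp/dataset.tfrecord'):
    dp = loads(record)    # recover the original datapoint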
@@ -10,7 +10,7 @@ import os
 import six

 from ..dataflow import DataFlow
-from ..dataflow.dftools import dataflow_to_process_queue
+from ..dataflow.dftools import dump_dataflow_to_process_queue
 from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
 from ..utils import logger, get_tqdm
 from ..utils.gpu import change_gpu
@@ -105,7 +105,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
         self.nr_proc = nr_proc
         self.ordered = ordered

-        self.inqueue, self.inqueue_proc = dataflow_to_process_queue(
+        self.inqueue, self.inqueue_proc = dump_dataflow_to_process_queue(
             self.dataset, nr_proc * 2, self.nr_proc)    # put (idx, dp) to inqueue
         if use_gpu:
...
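The predictor change above only follows the rename. For completeness, a hedged standalone sketch of the producer/consumer pattern that dump_dataflow_to_process_queue implements; the FakeData dataflow and queue size are placeholders, not from this commit.

# Single-consumer sketch; with nr_consumer=N, each of the N consumers reads until
# it sees one DIE sentinel, mirroring how MultiProcessDatasetPredictor uses it.
from tensorpack.dataflow import FakeData
from tensorpack.dataflow.dftools import dump_dataflow_to_process_queue
from tensorpack.utils.concurrency import DIE

df = FakeData([[4]], 100)            # placeholder DataFlow (assumed helper)
q, proc = dump_dataflow_to_process_queue(df, size=50, nr_consumer=1)
proc.start()
while True:
    idx, dp = q.get()
    if idx == DIE:                   # producer enqueues one DIE per consumer at the end
        break
    # ... handle (idx, dp) here ...
proc.join()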