Commit 27841032 authored by Yuxin Wu's avatar Yuxin Wu

add TFRecordData (#174)

parent f3d290cc
......@@ -129,13 +129,6 @@ def dump_dataflow_to_lmdb(df, lmdb_path, write_frequency=5000):
db.close()
from ..utils.develop import create_dummy_func # noqa
try:
import lmdb
except ImportError:
dump_dataflow_to_lmdb = create_dummy_func('dump_dataflow_to_lmdb', 'lmdb') # noqa
def dump_dataflow_to_tfrecord(df, path):
"""
Dump all datapoints of a Dataflow to a TensorFlow TFRecord file,
......@@ -151,6 +144,12 @@ def dump_dataflow_to_tfrecord(df, path):
writer.write(dumps(dp))
from ..utils.develop import create_dummy_func # noqa
try:
import lmdb
except ImportError:
dump_dataflow_to_lmdb = create_dummy_func('dump_dataflow_to_lmdb', 'lmdb') # noqa
try:
import tensorflow as tf
except ImportError:
......
......@@ -12,11 +12,11 @@ from ..utils.timer import timed_operation
from ..utils.loadcaffe import get_caffe_pb
from ..utils.serialize import loads
from ..utils.argtools import log_once
from .base import RNGDataFlow
from .base import RNGDataFlow, DataFlow
from .common import MapData
__all__ = ['HDF5Data', 'LMDBData', 'LMDBDataDecoder', 'LMDBDataPoint',
'CaffeLMDB', 'SVMLightData']
'CaffeLMDB', 'SVMLightData', 'TFRecordData']
"""
Adapters for different data format.
......@@ -228,6 +228,25 @@ class SVMLightData(RNGDataFlow):
yield [self.X[id, :], self.y[id]]
class TFRecordData(DataFlow):
"""
Produce datapoints from a TFRecord file, assuming each record is
serialized by :func:`serialize.dumps`.
This class works with :func:`dftools.dump_dataflow_to_tfrecord`.
"""
def __init__(self, path, size=None):
self._gen = tf.python_io.tf_record_iterator(path)
self._size = size
def size(self):
if self._size:
return self._size
return super(TFRecordData, self).size()
def get_data(self):
for dp in self._gen:
yield loads(dp)
from ..utils.develop import create_dummy_class # noqa
try:
import h5py
......@@ -244,3 +263,8 @@ try:
import sklearn.datasets
except ImportError:
SVMLightData = create_dummy_class('SVMLightData', 'sklearn') # noqa
try:
import tensorflow as tf
except ImportError:
TFRecordData = create_dummy_class('TFRecordData', 'tensorflow') # noqa
......@@ -81,3 +81,22 @@ class DataFromList(RNGDataFlow):
self.rng.shuffle(idxs)
for k in idxs:
yield self.lst[k]
class DataFromGenerator(DataFlow):
"""
Wrap a generator to a DataFlow
"""
def __init__(self, gen, size=None):
self._gen = gen
self._size = size
def size(self):
if self._size:
return self._size
return super(DataFromGenerator, self).size()
def get_data(self):
# yield from
for dp in self._gen:
yield dp
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment