Commit 7eb08df1 authored by Patrick Wieschollek, committed by Yuxin Wu

Add unit test for Serializer (#830)

* add unit test for serializer and complete new interface for hdf5 files

* test both versions msgpack and pyarrow

* use np.asarray

* changes from review

* update
parent 536c7587
@@ -41,7 +41,7 @@ matrix:
 install:
   - pip install -U pip  # the pip version on travis is too old
-  - pip install flake8 scikit-image opencv-python
+  - pip install flake8 scikit-image opencv-python lmdb h5py pyarrow msgpack
   - pip install .
   # check that dataflow can be imported alone
   - python -c "import tensorpack.dataflow"
...
@@ -3,17 +3,18 @@
 import os
 import numpy as np
+from collections import defaultdict

 from ..utils.utils import get_tqdm
 from ..utils import logger
 from ..utils.serialize import dumps, loads
 from .base import DataFlow
-from .format import LMDBData
+from .format import LMDBData, HDF5Data
 from .common import MapData, FixedSizeData
 from .raw import DataFromList, DataFromGenerator

-__all__ = ['LMDBSerializer', 'NumpySerializer', 'TFRecordSerializer']
+__all__ = ['LMDBSerializer', 'NumpySerializer', 'TFRecordSerializer', 'HDF5Serializer']


 def _reset_df_and_get_size(df):
@@ -87,8 +88,9 @@ class NumpySerializer():
     """
     Serialize the entire dataflow to a npz dict.
     Note that this would have to store the entire dataflow in memory,
-    and is also >10x slower than the other serializers.
+    and is also >10x slower than LMDB/TFRecord serializers.
     """
+
     @staticmethod
     def save(df, path):
         """
@@ -102,7 +104,7 @@ class NumpySerializer():
             for dp in df.get_data():
                 buffer.append(dp)
                 pbar.update()
-        np.savez_compressed(path, buffer=buffer)
+        np.savez_compressed(path, buffer=np.asarray(buffer, dtype=np.object))

     @staticmethod
     def load(path, shuffle=True):
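
Why the np.asarray change above: a datapoint mixes components of different shapes and dtypes, so the buffer cannot be coerced into a regular numeric array; an explicit object array keeps each datapoint intact inside the npz archive instead of letting np.savez_compressed guess a layout. A minimal standalone sketch (the data here is made up for illustration):

    import numpy as np

    # Each datapoint is e.g. [label, image]; the components differ in shape
    # and dtype, so the buffer is stored as an object array.
    buffer = [[3, np.random.randn(28, 28, 3)],
              [7, np.random.randn(28, 28, 3)]]
    np.savez_compressed('out.npz', buffer=np.asarray(buffer, dtype=object))

    # Object arrays are pickled inside the archive; recent NumPy versions
    # need allow_pickle=True to read them back.
    loaded = np.load('out.npz', allow_pickle=True)['buffer']
    print(loaded[0][0], loaded[0][1].shape)  # -> 3 (28, 28, 3)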
@@ -152,6 +154,46 @@ class TFRecordSerializer():
         return ds


+class HDF5Serializer():
+    """
+    Write datapoints to a HDF5 file.
+    Note that HDF5 files are in fact not very performant and currently do not support lazy loading.
+    It's better to use :class:`LMDBSerializer`.
+    """
+    @staticmethod
+    def save(df, path, data_paths):
+        """
+        Args:
+            df (DataFlow): the DataFlow to serialize.
+            path (str): output hdf5 file.
+            data_paths (list[str]): list of h5 paths. It should have the same
+                length as each datapoint, and each path should correspond to one
+                component of the datapoint.
+        """
+        size = _reset_df_and_get_size(df)
+        buffer = defaultdict(list)
+
+        with get_tqdm(total=size) as pbar:
+            for dp in df.get_data():
+                assert len(dp) == len(data_paths), "Datapoint has {} components!".format(len(dp))
+                for k, el in zip(data_paths, dp):
+                    buffer[k].append(el)
+                pbar.update()
+
+        with h5py.File(path, 'w') as hf, get_tqdm(total=size) as pbar:
+            for data_path in data_paths:
+                hf.create_dataset(data_path, data=buffer[data_path])
+                pbar.update()  # one tick per dataset written
+
+    @staticmethod
+    def load(path, data_paths, shuffle=True):
+        """
+        Args:
+            data_paths (list): list of h5 paths to be zipped.
+        """
+        return HDF5Data(path, data_paths, shuffle)
 from ..utils.develop import create_dummy_class  # noqa
 try:
     import lmdb
@@ -163,6 +205,11 @@ try:
 except ImportError:
     TFRecordSerializer = create_dummy_class('TFRecordSerializer', 'tensorflow')  # noqa

+try:
+    import h5py
+except ImportError:
+    HDF5Serializer = create_dummy_class('HDF5Serializer', 'h5py')  # noqa
+
 if __name__ == '__main__':
     from .raw import FakeData
@@ -196,3 +243,12 @@ if __name__ == '__main__':
         pass
     print("Numpy Finished, ", idx)
     print(time.time())
+
+    HDF5Serializer.save(ds, 'out.h5', ['p1', 'p2'])  # save() requires one h5 path per component; names are arbitrary
+    print(time.time())
+    df = HDF5Serializer.load('out.h5', ['p1', 'p2'])
+    df.reset_state()
+    for idx, dp in enumerate(df.get_data()):
+        pass
+    print("HDF5 Finished, ", idx)
+    print(time.time())
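
For reference, a minimal usage sketch of the new HDF5Serializer; the dataflow and the component names here are illustrative (the pattern mirrors the unit test added below):

    from tensorpack.dataflow import HDF5Serializer

    # df is assumed to be any DataFlow whose datapoints are [label, image].
    HDF5Serializer.save(df, 'data.h5', ['label', 'image'])  # one h5 path per component

    df2 = HDF5Serializer.load('data.h5', ['label', 'image'], shuffle=False)
    df2.reset_state()
    for label, image in df2.get_data():
        pass  # components come back zipped, in the same order as data_paths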
@@ -11,6 +11,7 @@ import functools
 from datetime import datetime
 import importlib
 import types
+import six

 from . import logger
@@ -26,9 +27,18 @@ def create_dummy_class(klass, dependency):
     Returns:
         class: a class object
     """
+    class _DummyMetaClass(type):
+        # throw error on class attribute access
+        def __getattr__(_, __):
+            raise ImportError("Cannot import '{}', therefore '{}' is not available".format(dependency, klass))
+
+    @six.add_metaclass(_DummyMetaClass)
     class _Dummy(object):
+
+        # throw error on constructor
         def __init__(self, *args, **kwargs):
             raise ImportError("Cannot import '{}', therefore '{}' is not available".format(dependency, klass))
+
     return _Dummy
...
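
The metaclass here matters because the serializers are used through class attribute access (e.g. HDF5Serializer.save(...)) without ever being instantiated; a plain dummy class only raises in __init__, so attribute access on the class itself went unguarded. Roughly, assuming h5py is absent:

    Dummy = create_dummy_class('HDF5Serializer', 'h5py')
    Dummy()      # ImportError: Cannot import 'h5py', therefore 'HDF5Serializer' is not available
    Dummy.save   # now also ImportError (previously an unhelpful AttributeError)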
@@ -8,8 +8,16 @@ export TF_CPP_MIN_LOG_LEVEL=2
 # test import (#471)
 python -c 'from tensorpack.dataflow.imgaug import transform'

-python -m unittest discover -v
+# python -m unittest discover -v
 # python -m tensorpack.models._test
 # segfault for no reason (https://travis-ci.org/ppwwyyxx/tensorpack/jobs/217702985)
 # python ../tensorpack/user_ops/test-recv-op.py
+
+python test_char_rnn.py
+python test_infogan.py
+python test_mnist.py
+python test_mnist_similarity.py
+
+TENSORPACK_SERIALIZE=pyarrow python test_serializer.py
+TENSORPACK_SERIALIZE=msgpack python test_serializer.py
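
The last two lines run the serializer test once per serialization backend, selected through the TENSORPACK_SERIALIZE environment variable. A rough sketch of the dispatch this relies on; the module layout and defaults shown here are assumptions, not the exact tensorpack code:

    import os

    # Hypothetical sketch: pick dumps/loads at import time from the env var.
    if os.environ.get('TENSORPACK_SERIALIZE', 'msgpack') == 'msgpack':
        import msgpack
        import msgpack_numpy
        msgpack_numpy.patch()  # teach msgpack to handle numpy arrays

        def dumps(obj):
            return msgpack.dumps(obj, use_bin_type=True)

        def loads(buf):
            return msgpack.loads(buf, raw=False)
    else:
        import pyarrow as pa

        def dumps(obj):
            return pa.serialize(obj).to_buffer()

        def loads(buf):
            return pa.deserialize(buf)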
New file, test_serializer.py (run by the script above):

#! /usr/bin/env python
# -*- coding: utf-8 -*-

import os
import unittest

import numpy as np

from tensorpack.dataflow import LMDBSerializer, TFRecordSerializer, NumpySerializer, HDF5Serializer
from tensorpack.dataflow.base import DataFlow


def delete_file_if_exists(fn):
    try:
        os.remove(fn)
    except OSError:
        pass


class SeededFakeDataFlow(DataFlow):
    """A DataFlow yielding a deterministic, seeded sequence of [label, img] datapoints."""

    def __init__(self, seed=42, size=32):
        super(SeededFakeDataFlow, self).__init__()
        self.seed = seed
        self._size = size
        self.cache = []

    def reset_state(self):
        # rebuild the cache from scratch so repeated resets yield identical data
        self.cache = []
        np.random.seed(self.seed)
        for _ in range(self._size):
            label = np.random.randint(low=0, high=10)
            img = np.random.randn(28, 28, 3)
            self.cache.append([label, img])

    def size(self):
        return self._size

    def get_data(self):
        for dp in self.cache:
            yield dp


class SerializerTest(unittest.TestCase):

    def run_write_read_test(self, file, serializer, w_args, w_kwargs, r_args, r_kwargs, error_msg):
        try:
            delete_file_if_exists(file)

            ds_expected = SeededFakeDataFlow()
            serializer.save(ds_expected, file, *w_args, **w_kwargs)
            ds_actual = serializer.load(file, *r_args, **r_kwargs)

            ds_actual.reset_state()
            ds_expected.reset_state()

            for dp_expected, dp_actual in zip(ds_expected.get_data(), ds_actual.get_data()):
                self.assertEqual(dp_expected[0], dp_actual[0])
                self.assertTrue(np.allclose(dp_expected[1], dp_actual[1]))
        except ImportError:
            print(error_msg)

    def test_lmdb(self):
        self.run_write_read_test('test.lmdb', LMDBSerializer,
                                 (), {},
                                 (), {'shuffle': False},
                                 'Skip test_lmdb, no lmdb available')

    def test_tfrecord(self):
        self.run_write_read_test('test.tfrecord', TFRecordSerializer,
                                 (), {},
                                 (), {'size': 32},
                                 'Skip test_tfrecord, no tensorflow available')

    def test_numpy(self):
        self.run_write_read_test('test.npz', NumpySerializer,
                                 (), {},
                                 (), {'shuffle': False},
                                 'Skip test_numpy, no numpy available')

    def test_hdf5(self):
        args = [['label', 'image']]
        self.run_write_read_test('test.h5', HDF5Serializer,
                                 args, {},
                                 args, {'shuffle': False},
                                 'Skip test_hdf5, no h5py available')


if __name__ == '__main__':
    unittest.main()