Commit cabf6af1 authored by Yuxin Wu's avatar Yuxin Wu

fix h5py tests

parent a83a1824
...@@ -64,7 +64,7 @@ setup( ...@@ -64,7 +64,7 @@ setup(
], ],
tests_require=['flake8', 'scikit-image'], tests_require=['flake8', 'scikit-image'],
extras_require={ extras_require={
'all': ['scipy', 'h5py', 'lmdb>=0.92', 'matplotlib', 'scikit-learn'], 'all': ['scipy', 'h5py>=2.1', 'lmdb>=0.92', 'matplotlib', 'scikit-learn'],
'all: "linux" in sys_platform': ['python-prctl'], 'all: "linux" in sys_platform': ['python-prctl'],
}, },
......
...@@ -30,7 +30,7 @@ def _global_import(name): ...@@ -30,7 +30,7 @@ def _global_import(name):
p = __import__(name, globals(), locals(), level=1) p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p) lst = p.__all__ if '__all__' in dir(p) else dir(p)
if lst: if lst:
del globals()[name] globals().pop(name, None)
for k in lst: for k in lst:
if not k.startswith('__'): if not k.startswith('__'):
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
......
...@@ -43,7 +43,7 @@ class HDF5Data(RNGDataFlow): ...@@ -43,7 +43,7 @@ class HDF5Data(RNGDataFlow):
""" """
self.f = h5py.File(filename, 'r') self.f = h5py.File(filename, 'r')
logger.info("Loading {} to memory...".format(filename)) logger.info("Loading {} to memory...".format(filename))
self.dps = [self.f[k].value for k in data_paths] self.dps = [self.f[k][...] for k in data_paths]
lens = [len(k) for k in self.dps] lens = [len(k) for k in self.dps]
assert all(k == lens[0] for k in lens) assert all(k == lens[0] for k in lens)
self._size = lens[0] self._size = lens[0]
......
...@@ -231,7 +231,7 @@ except ImportError: ...@@ -231,7 +231,7 @@ except ImportError:
LMDBSerializer = create_dummy_class('LMDBSerializer', 'lmdb') # noqa LMDBSerializer = create_dummy_class('LMDBSerializer', 'lmdb') # noqa
try: try:
import tensorflow as tf from tensorpack.compat import tfv1 as tf
except ImportError: except ImportError:
TFRecordSerializer = create_dummy_class('TFRecordSerializer', 'tensorflow') # noqa TFRecordSerializer = create_dummy_class('TFRecordSerializer', 'tensorflow') # noqa
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment