Commit cabf6af1 authored by Yuxin Wu's avatar Yuxin Wu

fix h5py tests

parent a83a1824
......@@ -64,7 +64,7 @@ setup(
],
tests_require=['flake8', 'scikit-image'],
extras_require={
'all': ['scipy', 'h5py', 'lmdb>=0.92', 'matplotlib', 'scikit-learn'],
'all': ['scipy', 'h5py>=2.1', 'lmdb>=0.92', 'matplotlib', 'scikit-learn'],
'all: "linux" in sys_platform': ['python-prctl'],
},
......
......@@ -30,7 +30,7 @@ def _global_import(name):
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
if lst:
del globals()[name]
globals().pop(name, None)
for k in lst:
if not k.startswith('__'):
globals()[k] = p.__dict__[k]
......
......@@ -43,7 +43,7 @@ class HDF5Data(RNGDataFlow):
"""
self.f = h5py.File(filename, 'r')
logger.info("Loading {} to memory...".format(filename))
self.dps = [self.f[k].value for k in data_paths]
self.dps = [self.f[k][...] for k in data_paths]
lens = [len(k) for k in self.dps]
assert all(k == lens[0] for k in lens)
self._size = lens[0]
......
......@@ -231,7 +231,7 @@ except ImportError:
LMDBSerializer = create_dummy_class('LMDBSerializer', 'lmdb') # noqa
try:
import tensorflow as tf
from tensorpack.compat import tfv1 as tf
except ImportError:
TFRecordSerializer = create_dummy_class('TFRecordSerializer', 'tensorflow') # noqa
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment