Commit b5ac2443 authored by Yuxin Wu

Use pyarrow instead of msgpack.

parent 8b4d4f77
...@@ -8,6 +8,8 @@ so you won't need to look at here very often. ...@@ -8,6 +8,8 @@ so you won't need to look at here very often.
Here is a list of things that were changed, starting from an early version. Here is a list of things that were changed, starting from an early version.
TensorFlow itself also changed APIs before 1.0 and those are not listed here. TensorFlow itself also changed APIs before 1.0 and those are not listed here.
+ [2018/04/05] msgpack is replaced by pyarrow. If you want compatibility with old serialized data,
manually uninstall pyarrow, and msgpack will be used as a fallback.
+ [2018/03/20] `ModelDesc` starts to use simplified interfaces: + [2018/03/20] `ModelDesc` starts to use simplified interfaces:
+ `_get_inputs()` renamed to `inputs()` and returns `tf.placeholder`s. + `_get_inputs()` renamed to `inputs()` and returns `tf.placeholder`s.
+ `build_graph(self, tensor1, tensor2)` returns the cost tensor directly. + `build_graph(self, tensor1, tensor2)` returns the cost tensor directly.
......
...@@ -25,13 +25,13 @@ ON_RTD = (os.environ.get('READTHEDOCS') == 'True') ...@@ -25,13 +25,13 @@ ON_RTD = (os.environ.get('READTHEDOCS') == 'True')
MOCK_MODULES = ['tabulate', 'h5py', MOCK_MODULES = ['tabulate', 'h5py',
'cv2', 'zmq', 'subprocess32', 'lmdb', 'cv2', 'zmq', 'lmdb',
'sklearn', 'sklearn.datasets', 'sklearn', 'sklearn.datasets',
'scipy', 'scipy.misc', 'scipy.io', 'scipy', 'scipy.misc', 'scipy.io',
'tornado', 'tornado.concurrent', 'tornado', 'tornado.concurrent',
'horovod', 'horovod.tensorflow', 'horovod', 'horovod.tensorflow',
'pyarrow', 'msgpack', 'msgpack_numpy', 'pyarrow',
'functools32'] 'subprocess32', 'functools32']
for mod_name in MOCK_MODULES: for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock(name=mod_name) sys.modules[mod_name] = mock.Mock(name=mod_name)
sys.modules['cv2'].__version__ = '3.2.1' # fake version sys.modules['cv2'].__version__ = '3.2.1' # fake version
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np import numpy as np
import sys
import os import os
import uuid import uuid
import argparse import argparse
......
...@@ -3,8 +3,7 @@ six ...@@ -3,8 +3,7 @@ six
termcolor>=1.1 termcolor>=1.1
tabulate>=0.7.7 tabulate>=0.7.7
tqdm>4.11.1 tqdm>4.11.1
msgpack>=0.5.2 pyarrow>=0.9.0
msgpack-numpy>=0.4.0
pyzmq>=16 pyzmq>=16
subprocess32; python_version < '3.0' subprocess32; python_version < '3.0'
functools32; python_version < '3.0' functools32; python_version < '3.0'
...@@ -106,8 +106,17 @@ class ILSVRCMeta(object): ...@@ -106,8 +106,17 @@ class ILSVRCMeta(object):
arr = cv2.resize(arr, size[::-1]) arr = cv2.resize(arr, size[::-1])
return arr return arr
@staticmethod
def guess_dir_structure(dir):
"""
Return the directory structure of "dir".
Args:
dir(str): something like '/path/to/imagenet/val'
def _guess_dir_structure(dir): Returns:
either 'train' or 'original'
"""
subdir = os.listdir(dir)[0] subdir = os.listdir(dir)[0]
# find a subdir starting with 'n' # find a subdir starting with 'n'
if subdir.startswith('n') and \ if subdir.startswith('n') and \
...@@ -145,7 +154,7 @@ class ILSVRC12Files(RNGDataFlow): ...@@ -145,7 +154,7 @@ class ILSVRC12Files(RNGDataFlow):
if name == 'train': if name == 'train':
dir_structure = 'train' dir_structure = 'train'
if dir_structure is None: if dir_structure is None:
dir_structure = _guess_dir_structure(self.full_dir) dir_structure = ILSVRCMeta.guess_dir_structure(self.full_dir)
meta = ILSVRCMeta(meta_dir) meta = ILSVRCMeta(meta_dir)
self.imglist = meta.get_image_list(name, dir_structure) self.imglist = meta.get_image_list(name, dir_structure)
......
...@@ -400,6 +400,8 @@ class MultiThreadPrefetchData(DataFlow): ...@@ -400,6 +400,8 @@ class MultiThreadPrefetchData(DataFlow):
class PlasmaPutData(ProxyDataFlow): class PlasmaPutData(ProxyDataFlow):
""" """
Put each data point to plasma shared memory object store, and yield the object id instead. Put each data point to plasma shared memory object store, and yield the object id instead.
Experimental.
""" """
def __init__(self, ds): def __init__(self, ds):
super(PlasmaPutData, self).__init__(ds) super(PlasmaPutData, self).__init__(ds)
......
...@@ -98,7 +98,8 @@ def set_logger_dir(dirname, action=None): ...@@ -98,7 +98,8 @@ def set_logger_dir(dirname, action=None):
# unload and close the old file handler, so that we may safely delete the logger directory # unload and close the old file handler, so that we may safely delete the logger directory
_logger.removeHandler(_FILE_HANDLER) _logger.removeHandler(_FILE_HANDLER)
del _FILE_HANDLER del _FILE_HANDLER
if os.path.isdir(dirname) and len(os.listdir(dirname)): # If directory exists and nonempty (ignore hidden files), prompt for action
if os.path.isdir(dirname) and len([x for x in os.listdir(dirname) if x[0] != '.']):
if not action: if not action:
_logger.warn("""\ _logger.warn("""\
Log directory {} exists! Use 'd' to delete it. """.format(dirname)) Log directory {} exists! Use 'd' to delete it. """.format(dirname))
......
...@@ -2,23 +2,7 @@ ...@@ -2,23 +2,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: serialize.py # File: serialize.py
import sys from .develop import create_dummy_func
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
# https://github.com/apache/arrow/pull/1223#issuecomment-359895666
old_mod = sys.modules.get('torch', None)
sys.modules['torch'] = None
try:
import pyarrow as pa
except ImportError:
pa = None
if old_mod is not None:
sys.modules['torch'] = old_mod
else:
del sys.modules['torch']
__all__ = ['loads', 'dumps'] __all__ = ['loads', 'dumps']
...@@ -58,6 +42,25 @@ def loads_pyarrow(buf): ...@@ -58,6 +42,25 @@ def loads_pyarrow(buf):
return pa.deserialize(buf) return pa.deserialize(buf)
try:
# fixed in pyarrow 0.9: https://github.com/apache/arrow/pull/1223#issuecomment-359895666
import pyarrow as pa
except ImportError:
pa = None
dumps_pyarrow = create_dummy_func('dumps_pyarrow', ['pyarrow']) # noqa
loads_pyarrow = create_dummy_func('loads_pyarrow', ['pyarrow']) # noqa
try:
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
except ImportError:
assert pa is not None, "pyarrow is a dependency of tensorpack!"
loads_msgpack = create_dummy_func( # noqa
'loads_msgpack', ['msgpack', 'msgpack_numpy'])
dumps_msgpack = create_dummy_func( # noqa
'dumps_msgpack', ['msgpack', 'msgpack_numpy'])
if pa is None: if pa is None:
loads = loads_msgpack loads = loads_msgpack
dumps = dumps_msgpack dumps = dumps_msgpack
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment