Commit b5ac2443 authored by Yuxin Wu

Use pyarrow instead of msgpack.

parent 8b4d4f77
......@@ -8,6 +8,8 @@ so you won't need to look at here very often.
Here is a list of things that were changed, starting from an early version.
TensorFlow itself also changed APIs before 1.0 and those are not listed here.
+ [2018/04/05] msgpack is replaced by pyarrow. If you want compatibility with old serialized data,
manually uninstall pyarrow, and msgpack will be used as a fallback.
+ [2018/03/20] `ModelDesc` starts to use simplified interfaces:
+ `_get_inputs()` renamed to `inputs()` and returns `tf.placeholder`s.
+ `build_graph(self, tensor1, tensor2)` returns the cost tensor directly.
......
......@@ -25,13 +25,13 @@ ON_RTD = (os.environ.get('READTHEDOCS') == 'True')
MOCK_MODULES = ['tabulate', 'h5py',
'cv2', 'zmq', 'subprocess32', 'lmdb',
'cv2', 'zmq', 'lmdb',
'sklearn', 'sklearn.datasets',
'scipy', 'scipy.misc', 'scipy.io',
'tornado', 'tornado.concurrent',
'horovod', 'horovod.tensorflow',
'pyarrow', 'msgpack', 'msgpack_numpy',
'functools32']
'pyarrow',
'subprocess32', 'functools32']
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock(name=mod_name)
sys.modules['cv2'].__version__ = '3.2.1' # fake version
......
......@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np
import sys
import os
import uuid
import argparse
......
......@@ -3,8 +3,7 @@ six
termcolor>=1.1
tabulate>=0.7.7
tqdm>4.11.1
msgpack>=0.5.2
msgpack-numpy>=0.4.0
pyarrow>=0.9.0
pyzmq>=16
subprocess32; python_version < '3.0'
functools32; python_version < '3.0'
......@@ -106,19 +106,28 @@ class ILSVRCMeta(object):
arr = cv2.resize(arr, size[::-1])
return arr
@staticmethod
def guess_dir_structure(dir):
"""
Return the directory structure of "dir".
Args:
dir(str): something like '/path/to/imagenet/val'
def _guess_dir_structure(dir):
subdir = os.listdir(dir)[0]
# find a subdir starting with 'n'
if subdir.startswith('n') and \
os.path.isdir(os.path.join(dir, subdir)):
dir_structure = 'train'
else:
dir_structure = 'original'
logger.info(
"[ILSVRC12] Assuming directory {} has '{}' structure.".format(
dir, dir_structure))
return dir_structure
Returns:
either 'train' or 'original'
"""
subdir = os.listdir(dir)[0]
# find a subdir starting with 'n'
if subdir.startswith('n') and \
os.path.isdir(os.path.join(dir, subdir)):
dir_structure = 'train'
else:
dir_structure = 'original'
logger.info(
"[ILSVRC12] Assuming directory {} has '{}' structure.".format(
dir, dir_structure))
return dir_structure
class ILSVRC12Files(RNGDataFlow):
......@@ -145,7 +154,7 @@ class ILSVRC12Files(RNGDataFlow):
if name == 'train':
dir_structure = 'train'
if dir_structure is None:
dir_structure = _guess_dir_structure(self.full_dir)
dir_structure = ILSVRCMeta.guess_dir_structure(self.full_dir)
meta = ILSVRCMeta(meta_dir)
self.imglist = meta.get_image_list(name, dir_structure)
......
......@@ -400,6 +400,8 @@ class MultiThreadPrefetchData(DataFlow):
class PlasmaPutData(ProxyDataFlow):
"""
Put each data point to plasma shared memory object store, and yield the object id instead.
Experimental.
"""
def __init__(self, ds):
super(PlasmaPutData, self).__init__(ds)
......
......@@ -98,7 +98,8 @@ def set_logger_dir(dirname, action=None):
# unload and close the old file handler, so that we may safely delete the logger directory
_logger.removeHandler(_FILE_HANDLER)
del _FILE_HANDLER
if os.path.isdir(dirname) and len(os.listdir(dirname)):
# If directory exists and nonempty (ignore hidden files), prompt for action
if os.path.isdir(dirname) and len([x for x in os.listdir(dirname) if x[0] != '.']):
if not action:
_logger.warn("""\
Log directory {} exists! Use 'd' to delete it. """.format(dirname))
......
......@@ -2,23 +2,7 @@
# -*- coding: utf-8 -*-
# File: serialize.py
import sys
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
# https://github.com/apache/arrow/pull/1223#issuecomment-359895666
old_mod = sys.modules.get('torch', None)
sys.modules['torch'] = None
try:
import pyarrow as pa
except ImportError:
pa = None
if old_mod is not None:
sys.modules['torch'] = old_mod
else:
del sys.modules['torch']
from .develop import create_dummy_func
__all__ = ['loads', 'dumps']
......@@ -58,6 +42,25 @@ def loads_pyarrow(buf):
return pa.deserialize(buf)
try:
# fixed in pyarrow 0.9: https://github.com/apache/arrow/pull/1223#issuecomment-359895666
import pyarrow as pa
except ImportError:
pa = None
dumps_pyarrow = create_dummy_func('dumps_pyarrow', ['pyarrow']) # noqa
loads_pyarrow = create_dummy_func('loads_pyarrow', ['pyarrow']) # noqa
try:
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
except ImportError:
assert pa is not None, "pyarrow is a dependency of tensorpack!"
loads_msgpack = create_dummy_func( # noqa
'loads_msgpack', ['msgpack', 'msgpack_numpy'])
dumps_msgpack = create_dummy_func( # noqa
'dumps_msgpack', ['msgpack', 'msgpack_numpy'])
if pa is None:
loads = loads_msgpack
dumps = dumps_msgpack
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment