Commit e00ec36b authored by Yuxin Wu's avatar Yuxin Wu

Clean-up imports so that dataflow can be imported without TF.

parent 72c7684a
......@@ -38,8 +38,10 @@ install:
- pip install flake8 scikit-image opencv-python pypandoc
# here we use opencv-python, but users are in general not recommended to use this package,
# because it brings issues working with tensorflow on gpu
- ./tests/install-tensorflow.sh
- pip install .
# check that dataflow can be imported alone
- python -c "import tensorpack.dataflow"
- ./tests/install-tensorflow.sh
before_script:
- flake8 --version
......
......@@ -3,15 +3,18 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from tensorpack.libinfo import __version__
from tensorpack.libinfo import __version__, _HAS_TF
from tensorpack.utils import *
from tensorpack.models import *
from tensorpack.dataflow import *
from tensorpack.callbacks import *
from tensorpack.tfutils import *
# dataflow can be used alone without installing tensorflow
if _HAS_TF:
from tensorpack.models import *
from tensorpack.train import *
from tensorpack.graph_builder import *
from tensorpack.predict import *
from tensorpack.callbacks import *
from tensorpack.tfutils import *
from tensorpack.train import *
from tensorpack.graph_builder import *
from tensorpack.predict import *
......@@ -8,7 +8,7 @@ from collections import deque
from .base import DataFlow, DataFlowReentrantGuard
from ..utils import logger
from ..utils.utils import get_tqdm
from ..utils.serialize import dumps, loads, dumps_for_tfop
from ..utils.serialize import dumps, loads
try:
import zmq
except ImportError:
......@@ -30,7 +30,8 @@ def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format=None):
"""
# format (str): The serialization format. ZMQ Op is still not publicly usable now
# Default format would use :mod:`tensorpack.utils.serialize`.
dump_fn = dumps if format is None else dumps_for_tfop
# dump_fn = dumps if format is None else dumps_for_tfop
dump_fn = dumps
ctx = zmq.Context()
socket = ctx.socket(zmq.PUSH)
socket.set_hwm(hwm)
......
......@@ -25,7 +25,12 @@ os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1' # issue#9339
os.environ['TF_AUTOTUNE_THRESHOLD'] = '3' # use more warm-up
os.environ['TF_AVGPOOL_USE_CUDNN'] = '1' # issue#8566
import tensorflow as tf # noqa
assert int(tf.__version__.split('.')[0]) >= 1, "TF>=1.0 is required!"
try:
import tensorflow as tf # noqa
assert int(tf.__version__.split('.')[0]) >= 1, "TF>=1.0 is required!"
_HAS_TF = True
except ImportError:
_HAS_TF = False
__version__ = '0.4.0'
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from __future__ import print_function
import tensorflow as tf
import os
__all__ = ['zmq_recv']
include_dir = tf.sysconfig.get_include()
file_dir = os.path.dirname(os.path.abspath(__file__))
compile_cmd = 'INCLUDE_DIR="-isystem {}" make -C "{}"'.format(include_dir, file_dir)
print("Compiling user ops ...")
ret = os.system(compile_cmd)
if ret != 0:
print("Failed to compile user ops!")
zmq_recv = None
else:
recv_mod = tf.load_op_library(os.path.join(file_dir, 'zmq_recv_op.so'))
# TODO trigger recompile when load fails
zmq_recv = recv_mod.zmq_recv
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: common.py
from __future__ import print_function
import tensorflow as tf
import os
def compile():
# TODO check modtime?
include_dir = tf.sysconfig.get_include()
file_dir = os.path.dirname(os.path.abspath(__file__))
compile_cmd = 'INCLUDE_DIR="-isystem {}" make -C "{}"'.format(include_dir, file_dir)
ret = os.system(compile_cmd)
return ret
if __name__ == '__main__':
compile()
......@@ -10,8 +10,8 @@ import multiprocessing as mp
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf # noqa
from tensorpack.utils.serialize import dump_tensor_protos, to_tensor_proto # noqa
from tensorpack.user_ops import zmq_recv # noqa
from tensorpack.user_ops.zmq_recv import ( # noqa
zmq_recv, dump_tensor_protos, to_tensor_proto)
try:
num = int(sys.argv[1])
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: zmq_recv.py
import tensorflow as tf
import struct
import numpy as np
import os
from tensorflow.core.framework.tensor_pb2 import TensorProto
from tensorflow.core.framework import types_pb2 as DataType
# have to import like this: https://github.com/tensorflow/tensorflow/commit/955f038afbeb81302cea43058078e68574000bce
from .common import compile
__all__ = ['zmq_recv', 'dumps_for_tfop',
'dump_tensor_protos', 'to_tensor_proto']
def build():
global zmq_recv
ret = compile()
if ret != 0:
zmq_recv = None
else:
file_dir = os.path.dirname(os.path.abspath(__file__))
recv_mod = tf.load_op_library(
os.path.join(file_dir, 'zmq_recv_op.so'))
zmq_recv = recv_mod.zmq_recv
build()
_DTYPE_DICT = {
np.float32: DataType.DT_FLOAT,
np.float64: DataType.DT_DOUBLE,
np.int32: DataType.DT_INT32,
np.int8: DataType.DT_INT8,
np.uint8: DataType.DT_UINT8,
}
_DTYPE_DICT = {np.dtype(k): v for k, v in _DTYPE_DICT.items()}
# TODO support string tensor and scalar
def to_tensor_proto(arr):
"""
Convert a numpy array to TensorProto
Args:
arr: numpy.ndarray. only supports common numerical types
"""
dtype = _DTYPE_DICT[arr.dtype]
ret = TensorProto()
shape = ret.tensor_shape
for s in arr.shape:
d = shape.dim.add()
d.size = s
ret.dtype = dtype
buf = arr.tobytes()
ret.tensor_content = buf
return ret
def dump_tensor_protos(protos):
"""
Serialize a list of :class:`TensorProto`, for communication between custom TensorFlow ops.
Args:
protos (list): list of :class:`TensorProto` instance
Notes:
The format is:
[#tensors(int32)]
[tensor1][tensor2]...
Where each tensor is:
[dtype(int32)][ndims(int32)][shape[0](int32)]...[shape[n](int32)]
[len(buffer)(int32)][buffer]
"""
# TODO use int64
s = struct.pack('=i', len(protos))
for p in protos:
tensor_content = p.tensor_content
s += struct.pack('=i', int(p.dtype))
dims = p.tensor_shape.dim
s += struct.pack('=i', len(dims))
for k in dims:
s += struct.pack('=i', k.size)
s += struct.pack('=i', len(tensor_content)) # won't send stuff over 2G
s += tensor_content
return s
def dumps_for_tfop(dp):
"""
Dump a datapoint (list of nparray) into a format that the ZMQRecv op in tensorpack would accept.
"""
protos = [to_tensor_proto(arr) for arr in dp]
return dump_tensor_protos(protos)
......@@ -6,17 +6,10 @@
import msgpack
import msgpack_numpy
import struct
import numpy as np
from tensorflow.core.framework.tensor_pb2 import TensorProto
from tensorflow.core.framework import types_pb2 as DataType
# have to import like this: https://github.com/tensorflow/tensorflow/commit/955f038afbeb81302cea43058078e68574000bce
msgpack_numpy.patch()
__all__ = ['loads', 'dumps', 'dumps_for_tfop', 'dump_tensor_protos',
'to_tensor_proto']
__all__ = ['loads', 'dumps']
def dumps(obj):
......@@ -35,75 +28,3 @@ def loads(buf):
buf (str): serialized object.
"""
return msgpack.loads(buf)
_DTYPE_DICT = {
np.float32: DataType.DT_FLOAT,
np.float64: DataType.DT_DOUBLE,
np.int32: DataType.DT_INT32,
np.int8: DataType.DT_INT8,
np.uint8: DataType.DT_UINT8,
}
_DTYPE_DICT = {np.dtype(k): v for k, v in _DTYPE_DICT.items()}
# TODO support string tensor and scalar
def to_tensor_proto(arr):
"""
Convert a numpy array to TensorProto
Args:
arr: numpy.ndarray. only supports common numerical types
"""
dtype = _DTYPE_DICT[arr.dtype]
ret = TensorProto()
shape = ret.tensor_shape
for s in arr.shape:
d = shape.dim.add()
d.size = s
ret.dtype = dtype
buf = arr.tobytes()
ret.tensor_content = buf
return ret
def dump_tensor_protos(protos):
"""
Serialize a list of :class:`TensorProto`, for communication between custom TensorFlow ops.
Args:
protos (list): list of :class:`TensorProto` instance
Notes:
The format is:
[#tensors(int32)]
[tensor1][tensor2]...
Where each tensor is:
[dtype(int32)][ndims(int32)][shape[0](int32)]...[shape[n](int32)]
[len(buffer)(int32)][buffer]
"""
# TODO use int64
s = struct.pack('=i', len(protos))
for p in protos:
tensor_content = p.tensor_content
s += struct.pack('=i', int(p.dtype))
dims = p.tensor_shape.dim
s += struct.pack('=i', len(dims))
for k in dims:
s += struct.pack('=i', k.size)
s += struct.pack('=i', len(tensor_content)) # won't send stuff over 2G
s += tensor_content
return s
def dumps_for_tfop(dp):
protos = [to_tensor_proto(arr) for arr in dp]
return dump_tensor_protos(protos)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment