Commit 0c67af01 authored by Yuxin Wu's avatar Yuxin Wu

Move zmq ops to a separate project

parent b20f615d
......@@ -33,13 +33,13 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None):
hwm (int): ZMQ high-water mark (buffer size)
format (str): The serialization format.
Default format would use :mod:`tensorpack.utils.serialize` (i.e. msgpack).
An alternate format is 'zmq_op'.
An alternate format is 'zmq_op', used by https://github.com/tensorpack/zmq_ops.
"""
assert format in [None, 'zmq_op']
if format is None:
dump_fn = dumps
else:
from ..user_ops.zmq_recv import dumps_zmq_op
from zmq_ops import dumps_zmq_op
dump_fn = dumps_zmq_op
ctx = zmq.Context()
......
......@@ -370,7 +370,7 @@ class DummyConstantInput(TensorInput):
class ZMQInput(TensorInput):
"""
Recv tensors from a ZMQ endpoint.
Recv tensors from a ZMQ endpoint, with ops from https://github.com/tensorpack/zmq_ops.
It works with :meth:`dataflow.remote.send_dataflow_zmq(format='zmq_op')`.
"""
def __init__(self, end_point, hwm):
......@@ -395,7 +395,7 @@ class ZMQInput(TensorInput):
"ZMQInput has to be used with InputDesc!"
self._desc = inputs_desc
from ..user_ops import zmq_ops
import zmq_ops
self._zmq_pull_socket = zmq_ops.ZMQPullSocket(
self._end_point,
[x.type for x in inputs_desc],
......
import tensorflow as tf
flags = [
'-Wall',
'-Wextra',
'-Werror',
'-Wno-long-long',
'-Wno-variadic-macros',
'-fexceptions',
'-std=c++11',
'-x',
'c++',
'-isystem',
tf.sysconfig.get_include()
]
def FlagsForFile(filename, **kwargs):
return {
'flags': flags,
'do_cache': True
}
# $File: Makefile
# $Date: Thu Dec 21 14:12:30 2017 -0800
OBJ_DIR = obj
PYTHON = python
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S),Linux)
CXX ?= g++
endif
ifeq ($(UNAME_S),Darwin)
CXX ?= clang++
endif
OPTFLAGS ?= -O3 -march=native
#OPTFLAGS ?= -g3 -fsanitize=address,undefined -O2 -lasan
#OPTFLAGS ?= -g3 -fsanitize=leak -O2 -lubsan
# libraries: TF preceeds others, so g++ looks for protobuf among TF headers
ifneq ($(MAKECMDGOALS), clean)
TF_CXXFLAGS ?= $(shell $(PYTHON) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))')
TF_LDFLAGS ?= $(shell $(PYTHON) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')
endif
CXXFLAGS += $(TF_CXXFLAGS)
LDFLAGS += $(TF_LDFLAGS)
# extra packages from pkg-config
LIBS = libzmq
CXXFLAGS += $(shell pkg-config --cflags $(LIBS))
LDFLAGS += $(shell pkg-config $(LIBS) --libs)
CXXFLAGS += -Wall -Wextra -Wno-unused-parameter -Wno-sign-compare
CXXFLAGS += $(DEFINES) -std=c++11 $(OPTFLAGS) -fPIC
LDFLAGS += $(OPTFLAGS)
LDFLAGS += -shared -fPIC
ifeq ($(UNAME_S),Darwin)
LDFLAGS += -Wl,-undefined -Wl,dynamic_lookup
endif
SHELL = bash
# sources to include
ccSOURCES = $(shell find $(SRCDIRS) -name "*.cc" | sed 's/^\.\///g')
OBJS = $(addprefix $(OBJ_DIR)/,$(ccSOURCES:.cc=.o))
DEPFILES = $(OBJS:.o=.d)
EXT_SUFFIX ?= $(shell $(PYTHON) -c 'import sysconfig; print(sysconfig.get_config_var("EXT_SUFFIX"))')
SO = zmq_ops$(EXT_SUFFIX)
.PHONY: all clean
all: $(SO)
ifneq ($(MAKECMDGOALS), clean)
sinclude $(DEPFILES)
endif
%$(EXT_SUFFIX): $(OBJ_DIR)/%.o
@echo "Linking $@ ..."
@$(CXX) $^ -o $@ $(LDFLAGS)
@echo "done."
$(OBJ_DIR)/%.o: %.cc
@echo "[cc] $< ..."
@$(CXX) -c $< -o $@ $(CXXFLAGS)
$(OBJ_DIR)/%.d: %.cc Makefile
@mkdir -pv $(dir $@)
@echo "[dep] $< ..."
@$(CXX) $(CXXFLAGS) -MM -MT "$(OBJ_DIR)/$(<:.cc=.o) $(OBJ_DIR)/$(<:.cc=.d)" "$<" > "$@" || rm "$@"
clean:
@rm -rvf $(OBJ_DIR) $(SO)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: common.py
import sysconfig
import tensorflow as tf
import os
from ..utils import logger
# https://github.com/uber/horovod/blob/10835d25eccf4b198a23a0795edddf0896f6563d/horovod/tensorflow/mpi_ops.py#L30-L40
def get_ext_suffix():
"""Determine library extension for various versions of Python."""
ext_suffix = sysconfig.get_config_var('EXT_SUFFIX')
if ext_suffix:
return ext_suffix
ext_suffix = sysconfig.get_config_var('SO')
if ext_suffix:
return ext_suffix
return '.so'
def compile():
cxxflags = ' '.join(tf.sysconfig.get_compile_flags())
ldflags = ' '.join(tf.sysconfig.get_link_flags())
ext_suffix = get_ext_suffix()
file_dir = os.path.dirname(os.path.abspath(__file__))
compile_cmd = 'TF_CXXFLAGS="{}" TF_LDFLAGS="{}" EXT_SUFFIX="{}" make -C "{}"'.format(
cxxflags, ldflags, ext_suffix, file_dir)
logger.info("Compile user_ops by command " + compile_cmd + ' ...')
ret = os.system(compile_cmd)
return ret
if __name__ == '__main__':
compile()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: test-pull-op.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
import zmq
import argparse
import multiprocessing as mp
import time
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf # noqa
from tensorpack.user_ops.zmq_ops import ( # noqa
ZMQPullSocket, dumps_zmq_op)
from tensorpack.utils.concurrency import ( # noqa
start_proc_mask_signal,
ensure_proc_terminate)
ENDPOINT = 'ipc://test-pipe'
def send(iterable, delay=0):
ctx = zmq.Context()
sok = ctx.socket(zmq.PUSH)
sok.connect(ENDPOINT)
for dp in iterable:
if delay > 0:
time.sleep(delay)
print("Sending data to socket..")
sok.send(dumps_zmq_op(dp))
time.sleep(999)
def random_array(num):
ret = []
for k in range(num):
arr1 = np.random.rand(k + 10, k + 10).astype('float32')
# arr1 = 3.0
arr2 = (np.random.rand((k + 10) * 2) * 10).astype('uint8')
ret.append([arr1, arr2])
return ret
def constant_array(num):
arr = np.ones((30, 30)).astype('float32')
arr2 = np.ones((3, 3)).astype('uint8')
return [[arr, arr2]] * num
def hash_dp(dp):
return sum([k.sum() for k in dp])
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--task', default='basic',
choices=['basic', 'tworecv', 'send'])
parser.add_argument('-n', '--num', type=int, default=10)
args = parser.parse_args()
if args.task == 'basic':
DATA = random_array(args.num)
p = mp.Process(target=send, args=(DATA,))
ensure_proc_terminate(p)
start_proc_mask_signal(p)
sess = tf.Session()
recv = ZMQPullSocket(ENDPOINT, [tf.float32, tf.uint8]).pull()
print(recv)
for truth in DATA:
arr = sess.run(recv)
assert (arr[0] == truth[0]).all()
assert (arr[1] == truth[1]).all()
elif args.task == 'send':
DATA = random_array(args.num)
send(DATA)
elif args.task == 'tworecv':
DATA = random_array(args.num)
hashes = [hash_dp(dp) for dp in DATA]
print(hashes)
p = mp.Process(target=send, args=(DATA, 0.00))
ensure_proc_terminate(p)
start_proc_mask_signal(p)
sess = tf.Session()
zmqsock = ZMQPullSocket(ENDPOINT, [tf.float32, tf.uint8], hwm=1)
recv1 = zmqsock.pull()
recv2 = zmqsock.pull()
print(recv1, recv2)
for i in range(args.num // 2):
res1, res2 = sess.run([recv1, recv2])
h1, h2 = hash_dp(res1), hash_dp(res2)
print("Recv ", i, h1, h2)
assert h1 in hashes and h2 in hashes
//File: zmq_conn.h
//Author: Yuxin Wu <ppwwyyxxc@gmail.com>
#pragma once
#include <string>
#include <iostream>
#include <thread>
#include <tensorflow/core/framework/resource_mgr.h>
#include <tensorflow/core/framework/tensor_shape.h>
#include <tensorflow/core/lib/gtl/inlined_vector.h>
#include <tensorflow/core/lib/strings/strcat.h>
#include <tensorflow/core/platform/mutex.h>
#include "zmq.hpp"
namespace {
inline int read_int32(char** p) {
auto pi = reinterpret_cast<const int*>(*p);
*p += 4;
return *pi;
}
inline tensorflow::int64 read_int64(char** p) {
auto pi = reinterpret_cast<const tensorflow::int64*>(*p);
*p += 8;
return *pi;
}
}
namespace tensorpack {
struct ZMQSocketDef {
std::string end_point;
int socket_type, // ZMQ_PULL
hwm;
bool bind; // bind or connect
std::string DebugString() const {
return tensorflow::strings::StrCat("EndPoint=", end_point, ", hwm=", std::to_string(hwm));
}
};
struct RecvTensorList {
zmq::message_t message;
struct TensorConstructor {
tensorflow::DataType dtype;
tensorflow::TensorShape shape;
tensorflow::int64 buf_size;
char* buf;
};
tensorflow::gtl::InlinedVector<TensorConstructor, 4> tensors;
};
class ZMQConnection : public tensorflow::ResourceBase {
public:
explicit ZMQConnection(const ZMQSocketDef& def):
def_{def}, ctx_{1}, sock_{ctx_, def.socket_type} {
int linger = 0;
sock_.setsockopt(ZMQ_LINGER, &linger , sizeof linger);
sock_.setsockopt(ZMQ_RCVHWM, &def.hwm , sizeof def.hwm);
if (def.bind) {
sock_.bind(def.end_point.c_str());
} else {
sock_.connect(def.end_point.c_str());
}
}
std::string DebugString() override { return def_.DebugString(); }
void recv_tensor_list(RecvTensorList* tlist) {
{
// https://www.tensorflow.org/extend/adding_an_op#multi-threaded_cpu_kernels
// zmq socket is not thread safe
tensorflow::mutex_lock lk(mu_);
bool succ = sock_.recv(&tlist->message); // block until some data appears
// TODO this may throw, handle exception?
// Possible error code: http://api.zeromq.org/3-3:zmq-msg-recv
// succ=false only if EAGAIN
CHECK(succ); // no EAGAIN, because we are blocking
}
char* pos = reinterpret_cast<char*>(tlist->message.data());
int num = read_int32(&pos);
auto& tensors = tlist->tensors;
tensors.resize(num);
CHECK_LE(num, 15); // probably a format error
for (int i = 0; i < num; ++i) {
int dt = read_int32(&pos);
tensors[i].dtype = tensorflow::DataType(dt);
int ndim = read_int32(&pos);
CHECK_LE(ndim, 8); // probably an error.
for (int k = 0; k < ndim; ++k) {
int shp = read_int32(&pos);
tensors[i].shape.AddDim(shp);
}
tensorflow::int64 sz = read_int64(&pos);
tensors[i].buf = pos;
tensors[i].buf_size = sz;
pos += sz;
}
}
const ZMQSocketDef& get_socket_def() const { return def_; }
private:
ZMQSocketDef def_;
tensorflow::mutex mu_;
zmq::context_t ctx_;
zmq::socket_t sock_;
};
} // namespace tensorpack
//File: zmq_ops.cc
//Author: Yuxin Wu <ppwwyyxxc@gmail.com>
#include <string>
#include <memory>
#include <tensorflow/core/framework/op.h>
#include <tensorflow/core/framework/op_kernel.h>
#include <tensorflow/core/framework/resource_op_kernel.h>
#include <tensorflow/core/framework/resource_mgr.h>
#include <tensorflow/core/framework/common_shape_fns.h>
#include "zmq_conn.h"
using namespace std;
using namespace tensorflow;
namespace tensorpack {
// An op to create zmq connection as a resource.
// Use ResourceOpKernel to ensure singleton construction.
class ZMQConnectionHandleOp : public ResourceOpKernel<ZMQConnection> {
public:
explicit ZMQConnectionHandleOp(OpKernelConstruction* ctx)
: ResourceOpKernel<ZMQConnection>(ctx) {}
private:
Status CreateResource(ZMQConnection** ret) override EXCLUSIVE_LOCKS_REQUIRED(mu_) {
const NodeDef& ndef = def();
ZMQSocketDef sockdef;
sockdef.socket_type = ZMQ_PULL;
TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "bind", &sockdef.bind));
TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "end_point", &sockdef.end_point));
TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "hwm", &sockdef.hwm));
*ret = new ZMQConnection(sockdef);
return Status::OK();
}
// Can verify, but probably not necessary because python is not going to eval this op twice with
// the same shared name
};
class ZMQPullOp: public AsyncOpKernel {
public:
explicit ZMQPullOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("types", &component_types_));
}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
ZMQConnection* conn = nullptr;
OP_REQUIRES_OK_ASYNC(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &conn), done);
RecvTensorList tlist;
conn->recv_tensor_list(&tlist);
auto& tensors = tlist.tensors;
CHECK(tensors.size() == num_components());
for (int i = 0; i < tensors.size(); ++i) {
Tensor* output = nullptr;
auto recv_dtype = tensors[i].dtype;
OP_REQUIRES_ASYNC(
ctx, component_types_[i] == recv_dtype,
errors::InvalidArgument("Type mismatch at index ", std::to_string(i),
" between received tensor (", DataTypeString(recv_dtype),
") and dtype (", DataTypeString(component_types_[i]), ")"),
done);
TensorShape& shape = tensors[i].shape;
OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(i, shape, &output), done);
// reinterpret cast and then memcpy
auto ptr = output->bit_casted_shaped<char, 1>({shape.num_elements()}).data();
// {shape.num_elements() * DataTypeSize(recv_dtype)}).data();
memcpy(ptr, tensors[i].buf, tensors[i].buf_size);
}
done();
}
private:
DataTypeVector component_types_;
size_t num_components() const { return component_types_.size(); }
};
REGISTER_KERNEL_BUILDER(Name("ZMQPull").Device(DEVICE_CPU), ZMQPullOp);
REGISTER_KERNEL_BUILDER(Name("ZMQConnection").Device(DEVICE_CPU), ZMQConnectionHandleOp);
} // namespace tensorpack
REGISTER_OP("ZMQPull")
.Input("handle: resource")
.Output("output: types")
.Attr("types: list(type) >= 1")
.SetShapeFn(shape_inference::UnknownShape)
.SetIsStateful()
.Doc(R"doc(
Receive a list of Tensors from a ZMQ connection handle.
The serialization format is a tensorpack custom format, defined in 'zmq_recv.py'.
)doc");
REGISTER_OP("ZMQConnection")
.Output("handle: resource")
.Attr("end_point: string")
.Attr("hwm: int >= 1 = 10")
.Attr("bind: bool = true")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
Opens a ZMQ PULL socket and returns a handle to it as a resource.
end_point: the ZMQ end point.
hwm: ZMQ high-water mark.
bind: If false, will connect to the endpoint rather than bind to it.
container: required for a resource op kernel.
shared_name: If non-empty, this connection will be shared under the given name across multiple sessions.
)doc");
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: zmq_pull.py
import tensorflow as tf
import struct
import numpy as np
import os
from tensorflow.core.framework.tensor_pb2 import TensorProto
from tensorflow.core.framework import types_pb2 as DT
# have to import like this: https://github.com/tensorflow/tensorflow/commit/955f038afbeb81302cea43058078e68574000bce
from .common import compile, get_ext_suffix
__all__ = ['dumps_zmq_op', 'ZMQPullSocket']
_zmq_mod = None
def try_build():
file_dir = os.path.dirname(os.path.abspath(__file__))
basename = 'zmq_ops' + get_ext_suffix()
so_file = os.path.join(file_dir, basename)
if not os.path.isfile(so_file):
ret = compile()
if ret != 0:
raise RuntimeError("tensorpack user_ops compilation failed!")
global _zmq_mod
_zmq_mod = tf.load_op_library(so_file)
try_build()
class ZMQPullSocket(object):
def __init__(self, end_point, types, hwm=None, bind=True, name=None):
self._types = types
assert isinstance(bind, bool), bind
if name is None:
self._name = (tf.get_default_graph()
.unique_name(self.__class__.__name__))
else:
self._name = name
self._zmq_handle = _zmq_mod.zmq_connection(
end_point, hwm, bind=bind, shared_name=self._name)
@property
def name(self):
return self._name
def pull(self):
return _zmq_mod.zmq_pull(
self._zmq_handle, self._types)
# copied from tensorflow/python/framework/dtypes.py
_DTYPE_DICT = {
np.float16: DT.DT_HALF,
np.float32: DT.DT_FLOAT,
np.float64: DT.DT_DOUBLE,
np.uint8: DT.DT_UINT8,
np.uint16: DT.DT_UINT16,
np.uint32: DT.DT_UINT32,
np.uint64: DT.DT_UINT64,
np.int64: DT.DT_INT64,
np.int32: DT.DT_INT32,
np.int16: DT.DT_INT16,
np.int8: DT.DT_INT8,
np.complex64: DT.DT_COMPLEX64,
np.complex128: DT.DT_COMPLEX128,
np.bool: DT.DT_BOOL,
}
_DTYPE_DICT = {np.dtype(k): v for k, v in _DTYPE_DICT.items()}
def to_tensor_proto(arr):
"""
Convert a numpy array to TensorProto
Args:
arr: numpy.ndarray. only supports common numerical types
"""
if isinstance(arr, float):
arr = np.asarray(arr).astype('float32')
elif isinstance(arr, int):
arr = np.asarray(arr).astype('int32')
assert isinstance(arr, np.ndarray), type(arr)
try:
dtype = _DTYPE_DICT[arr.dtype]
except KeyError:
raise KeyError("Dtype {} is unsupported by current ZMQ Op!".format(arr.dtype))
ret = TensorProto()
shape = ret.tensor_shape
for s in arr.shape:
d = shape.dim.add()
d.size = s
ret.dtype = dtype
buf = arr.tobytes()
ret.tensor_content = buf
return ret
def dump_tensor_protos(protos):
"""
Serialize a list of :class:`TensorProto`, for communication between custom TensorFlow ops.
Args:
protos (list): list of :class:`TensorProto` instance
Notes:
The format is:
[#tensors(int32)]
[tensor1][tensor2]...
Where each tensor is:
[dtype(int32)][ndims(int32)][shape[0](int32)]...[shape[n](int32)]
[len(buffer)(int64)][buffer]
"""
s = struct.pack('=i', len(protos))
for p in protos:
tensor_content = p.tensor_content
s += struct.pack('=i', int(p.dtype))
dims = p.tensor_shape.dim
s += struct.pack('=i', len(dims))
for k in dims:
s += struct.pack('=i', k.size)
s += struct.pack('=q', len(tensor_content))
s += tensor_content
return s
def dumps_zmq_op(dp):
"""
Dump a datapoint (list of nparray) into a format that the ZMQPull op in tensorpack would accept.
Args:
dp: list of nparray
Returns:
a binary string
"""
assert isinstance(dp, (list, tuple))
protos = [to_tensor_proto(arr) for arr in dp]
return dump_tensor_protos(protos)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment