Commit 6f6914af authored by Yuxin Wu's avatar Yuxin Wu

added zmq user op

parent 1ccb94f6
......@@ -23,7 +23,6 @@ DEPTH = None
class Model(ModelDesc):
def __init__(self, data_format='NCHW'):
self.data_format = data_format
......
# $File: Makefile
# $Date: Wed Jun 17 20:52:38 2015 +0800
OBJ_DIR = obj
CXX ?= g++
OPTFLAGS ?= -O3 -march=native
#OPTFLAGS ?= -g3 -fsanitize=address,undefined -O2 -lasan
#OPTFLAGS ?= -g3 -fsanitize=leak -O2 -lubsan
# optional extra packages
#LIBS = opencv
#INCLUDE_DIR += $(shell pkg-config --cflags $(LIBS))
#LDFLAGS += $(shell pkg-config $(LIBS) --libs)
CXXFLAGS ?=
CXXFLAGS += $(INCLUDE_DIR)
CXXFLAGS += -Wall -Wextra -Wno-unused-parameter -Wno-ignored-qualifiers
CXXFLAGS += $(DEFINES) -std=c++11 $(OPTFLAGS) -fPIC
CXXFLAGS += -D_GLIBCXX_USE_CXX11_ABI=0 # https://github.com/tensorflow/tensorflow/issues/1569
LDFLAGS += $(OPTFLAGS)
LDFLAGS += -lzmq -lprotobuf $(OMP_FLAG)
LDFLAGS += -shared -fPIC
SHELL = bash
# sources to include
ccSOURCES = $(shell find $(SRCDIRS) -name "*.cc" | sed 's/^\.\///g')
OBJS = $(addprefix $(OBJ_DIR)/,$(ccSOURCES:.cc=.o))
DEPFILES = $(OBJS:.o=.d)
SO = $(ccSOURCES:.cc=.so)
.PHONY: all clean
all: $(SO)
ifneq ($(MAKECMDGOALS), clean)
sinclude $(DEPFILES)
endif
%.so: $(OBJ_DIR)/%.o
@echo "Linking $@ ..."
@$(CXX) $^ -o $@ $(LDFLAGS)
@echo "done."
$(OBJ_DIR)/%.o: %.cc
@echo "[cc] $< ..."
@$(CXX) -c $< -o $@ $(CXXFLAGS)
$(OBJ_DIR)/%.d: %.cc Makefile
@mkdir -pv $(dir $@)
@echo "[dep] $< ..."
@$(CXX) $(CXXFLAGS) -MM -MT "$(OBJ_DIR)/$(<:.cc=.o) $(OBJ_DIR)/$(<:.cc=.d)" "$<" > "$@" || rm "$@"
clean:
@rm -rvf $(OBJ_DIR)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from __future__ import print_function
import tensorflow as tf
import os
__all__ = ['zmq_recv']
include_dir = tf.sysconfig.get_include()
file_dir = os.path.dirname(os.path.abspath(__file__))
compile_cmd = 'make INCLUDE_DIR=-I{} -C {}'.format(include_dir, file_dir)
print("Compiling user ops ...")
ret = os.system(compile_cmd)
if ret != 0:
print("Failed to compile user ops!")
recv_mod = tf.load_op_library(os.path.join(file_dir, 'zmq_recv_op.so'))
zmq_recv = recv_mod.zmq_recv
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: test-recv-op.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import sys
import os
import zmq
import multiprocessing as mp
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf # noqa
from tensorpack.utils.serialize import dump_tensor_protos, to_tensor_proto # noqa
from tensorpack.user_ops import zmq_recv # noqa
try:
num = int(sys.argv[1])
except:
num = 2
ENDPOINT = 'ipc://test-pipe'
DATA = []
for k in range(num):
arr1 = np.random.rand(k).astype('float32')
arr2 = (np.random.rand(k * 2) * 10).astype('uint8')
DATA.append([arr1, arr2])
def send():
ctx = zmq.Context()
sok = ctx.socket(zmq.PUSH)
sok.connect(ENDPOINT)
for arr1, arr2 in DATA:
t1 = to_tensor_proto(arr1)
t2 = to_tensor_proto(arr2)
t = dump_tensor_protos([t1, t2])
sok.send(t)
def recv():
sess = tf.InteractiveSession()
recv = zmq_recv(ENDPOINT, [tf.float32, tf.uint8])
print(recv)
for truth in DATA:
arr = sess.run(recv)
assert (arr[0] == truth[0]).all()
assert (arr[1] == truth[1]).all()
p = mp.Process(target=send)
p.start()
recv()
p.join()
//File: zmq_conn.h
//Author: Yuxin Wu <ppwwyyxxc@gmail.com>
#pragma once
#include <string>
#include <iostream>
#include <zmq.hpp>
#include <tensorflow/core/framework/tensor.pb.h>
namespace {
inline int read_int32(const char* p) {
auto pi = reinterpret_cast<const int*>(p);
return *pi;
}
}
class ZMQConnection {
public:
ZMQConnection(std::string endpoint, int zmq_socket_type):
ctx_(1), sock_(ctx_, zmq_socket_type) {
sock_.bind(endpoint.c_str());
}
tensorflow::TensorProto recv_tensor() {
zmq::message_t message;
bool succ = sock_.recv(&message);
CHECK(succ); // no EAGAIN, because we are blocking
tensorflow::TensorProto ret{};
CHECK(ret.ParseFromArray(message.data(), message.size()));
return ret;
}
std::vector<tensorflow::TensorProto> recv_tensor_list() {
zmq::message_t message;
// TODO critical section
bool succ = sock_.recv(&message);
CHECK(succ); // no EAGAIN, because we are blocking
char* pos = reinterpret_cast<char*>(message.data());
int num = read_int32(pos);
std::vector<tensorflow::TensorProto> ret(num);
pos += sizeof(int);
for (int i = 0; i < num; ++i) {
int size = read_int32(pos);
pos += sizeof(int);
//std::cout << "Message size:" << size << std::endl;
CHECK(ret[i].ParseFromArray(pos, size));
pos += size;
}
return ret;
}
private:
zmq::context_t ctx_;
zmq::socket_t sock_;
};
//File: recv_op.cc
//Author: Yuxin Wu <ppwwyyxxc@gmail.com>
#include <string>
#include <memory>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "zmq_conn.h"
using namespace std;
using namespace tensorflow;
REGISTER_OP("ZMQRecv")
.Output("output: types")
.Attr("end_point: string")
.Attr("types: list(type) >= 1")
.SetShapeFn(shape_inference::UnknownShape)
.SetIsStateful()
.Doc(R"doc(
Receive and return a serialized list of TensorProto from a ZMQ socket.
)doc");
class ZMQRecvOp: public OpKernel {
public:
explicit ZMQRecvOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("types", &component_types_));
CHECK(conn_.get() == nullptr);
string endpoint;
OP_REQUIRES_OK(context, context->GetAttr("end_point", &endpoint));
conn_.reset(new ZMQConnection(endpoint, ZMQ_PULL));
}
void Compute(OpKernelContext* ctx) override {
int start, stop;
TF_CHECK_OK(this->OutputRange("output", &start, &stop));
//cout << "COMPUTE" << endl;
auto protos = conn_->recv_tensor_list();
OpOutputList outputs;
OP_REQUIRES_OK(ctx, ctx->output_list("output", &outputs));
CHECK(protos.size() == num_components());
for (int i = start; i < stop; ++i) {
Tensor output;
int j = i - start;
OP_REQUIRES_OK(ctx, ctx->device()->MakeTensorFromProto(
protos[j], ctx->output_alloc_attr(i), &output));
OP_REQUIRES(
ctx, component_types_[j] == output.dtype(),
errors::InvalidArgument("Type mismatch between parsed tensor (",
DataTypeString(output.dtype()), ") and dtype (",
DataTypeString(component_types_[j]), ")"));
outputs.set(j, output);
}
}
private:
DataTypeVector component_types_;
unique_ptr<ZMQConnection> conn_;
size_t num_components() const { return component_types_.size(); }
};
REGISTER_KERNEL_BUILDER(Name("ZMQRecv").Device(DEVICE_CPU), ZMQRecvOp);
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment