Commit 552ec9e9 authored by Yuxin Wu's avatar Yuxin Wu

rename zmq ops

parent 04b52d20
......@@ -390,7 +390,7 @@ class ZMQInput(TensorInput):
self._hwm = int(hwm)
def fn():
ret = self._zmq_recv_socket.recv()
ret = self._zmq_pull_socket.pull()
assert len(ret) == len(self._desc)
for qv, v in zip(ret, self._desc):
qv.set_shape(v.shape)
......@@ -402,8 +402,8 @@ class ZMQInput(TensorInput):
"ZMQInput has to be used with InputDesc!"
self._desc = inputs_desc
from ..user_ops import zmq_recv
self._zmq_recv_socket = zmq_recv.ZMQSocket(
from ..user_ops import zmq_ops
self._zmq_pull_socket = zmq_ops.ZMQPullSocket(
self._end_point,
[x.type for x in inputs_desc],
self._hwm)
......
# $File: Makefile
# $Date: Thu Dec 14 18:03:41 2017 -0800
# $Date: Thu Dec 21 14:12:30 2017 -0800
OBJ_DIR = obj
PYTHON = python
......@@ -45,7 +45,7 @@ OBJS = $(addprefix $(OBJ_DIR)/,$(ccSOURCES:.cc=.o))
DEPFILES = $(OBJS:.o=.d)
EXT_SUFFIX ?= $(shell $(PYTHON) -c 'import sysconfig; print(sysconfig.get_config_var("EXT_SUFFIX"))')
SO = zmq_recv_op$(EXT_SUFFIX)
SO = zmq_ops$(EXT_SUFFIX)
.PHONY: all clean
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: test-recv-op.py
# File: test-pull-op.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
......@@ -11,8 +11,8 @@ import time
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf # noqa
from tensorpack.user_ops.zmq_recv import ( # noqa
ZMQSocket, dumps_zmq_op)
from tensorpack.user_ops.zmq_ops import ( # noqa
ZMQPullSocket, dumps_zmq_op)
from tensorpack.utils.concurrency import ( # noqa
start_proc_mask_signal,
ensure_proc_terminate)
......@@ -68,7 +68,7 @@ if __name__ == '__main__':
start_proc_mask_signal(p)
sess = tf.Session()
recv = ZMQSocket(ENDPOINT, [tf.float32, tf.uint8]).recv()
recv = ZMQPullSocket(ENDPOINT, [tf.float32, tf.uint8]).pull()
print(recv)
for truth in DATA:
......@@ -87,9 +87,9 @@ if __name__ == '__main__':
start_proc_mask_signal(p)
sess = tf.Session()
zmqsock = ZMQSocket(ENDPOINT, [tf.float32, tf.uint8], hwm=1)
recv1 = zmqsock.recv()
recv2 = zmqsock.recv()
zmqsock = ZMQPullSocket(ENDPOINT, [tf.float32, tf.uint8], hwm=1)
recv1 = zmqsock.pull()
recv2 = zmqsock.pull()
print(recv1, recv2)
for i in range(args.num // 2):
......
//File: zmq_recv_op.cc
//File: zmq_ops.cc
//Author: Yuxin Wu <ppwwyyxxc@gmail.com>
#include <string>
......@@ -41,9 +41,9 @@ class ZMQConnectionHandleOp : public ResourceOpKernel<ZMQConnection> {
};
class ZMQRecvOp: public AsyncOpKernel {
class ZMQPullOp: public AsyncOpKernel {
public:
explicit ZMQRecvOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
explicit ZMQPullOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("types", &component_types_));
}
......@@ -71,8 +71,8 @@ class ZMQRecvOp: public AsyncOpKernel {
TensorShape& shape = tensors[i].shape;
OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(i, shape, &output), done);
// reinterpret cast and then memcpy
auto ptr = output->bit_casted_shaped<char, 1>(
{shape.num_elements() * DataTypeSize(recv_dtype)}).data();
auto ptr = output->bit_casted_shaped<char, 1>({shape.num_elements()}).data();
// {shape.num_elements() * DataTypeSize(recv_dtype)}).data();
memcpy(ptr, tensors[i].buf, tensors[i].buf_size);
}
done();
......@@ -84,12 +84,12 @@ class ZMQRecvOp: public AsyncOpKernel {
};
REGISTER_KERNEL_BUILDER(Name("ZMQRecv").Device(DEVICE_CPU), ZMQRecvOp);
REGISTER_KERNEL_BUILDER(Name("ZMQPull").Device(DEVICE_CPU), ZMQPullOp);
REGISTER_KERNEL_BUILDER(Name("ZMQConnection").Device(DEVICE_CPU), ZMQConnectionHandleOp);
} // namespace tensorpack
REGISTER_OP("ZMQRecv")
REGISTER_OP("ZMQPull")
.Input("handle: resource")
.Output("output: types")
.Attr("types: list(type) >= 1")
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: zmq_recv.py
# File: zmq_pull.py
import tensorflow as tf
import struct
......@@ -13,29 +13,29 @@ from tensorflow.core.framework import types_pb2 as DT
from .common import compile, get_ext_suffix
__all__ = ['dumps_zmq_op', 'ZMQSocket']
__all__ = ['dumps_zmq_op', 'ZMQPullSocket']
_zmq_recv_mod = None
_zmq_mod = None
def try_build():
file_dir = os.path.dirname(os.path.abspath(__file__))
basename = 'zmq_recv_op' + get_ext_suffix()
basename = 'zmq_ops' + get_ext_suffix()
so_file = os.path.join(file_dir, basename)
if not os.path.isfile(so_file):
ret = compile()
if ret != 0:
raise RuntimeError("tensorpack user_ops compilation failed!")
global _zmq_recv_mod
_zmq_recv_mod = tf.load_op_library(so_file)
global _zmq_mod
_zmq_mod = tf.load_op_library(so_file)
try_build()
class ZMQSocket(object):
class ZMQPullSocket(object):
def __init__(self, end_point, types, hwm=None, bind=True, name=None):
self._types = types
assert isinstance(bind, bool), bind
......@@ -46,15 +46,15 @@ class ZMQSocket(object):
else:
self._name = name
self._zmq_handle = _zmq_recv_mod.zmq_connection(
self._zmq_handle = _zmq_mod.zmq_connection(
end_point, hwm, bind=bind, shared_name=self._name)
@property
def name(self):
return self._name
def recv(self):
return _zmq_recv_mod.zmq_recv(
def pull(self):
return _zmq_mod.zmq_pull(
self._zmq_handle, self._types)
......@@ -147,7 +147,7 @@ def dump_tensor_protos(protos):
def dumps_zmq_op(dp):
"""
Dump a datapoint (list of nparray) into a format that the ZMQRecv op in tensorpack would accept.
Dump a datapoint (list of nparray) into a format that the ZMQPull op in tensorpack would accept.
Args:
dp: list of nparray
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment