Commit 552ec9e9 authored by Yuxin Wu's avatar Yuxin Wu

rename zmq ops

parent 04b52d20
...@@ -390,7 +390,7 @@ class ZMQInput(TensorInput): ...@@ -390,7 +390,7 @@ class ZMQInput(TensorInput):
self._hwm = int(hwm) self._hwm = int(hwm)
def fn(): def fn():
ret = self._zmq_recv_socket.recv() ret = self._zmq_pull_socket.pull()
assert len(ret) == len(self._desc) assert len(ret) == len(self._desc)
for qv, v in zip(ret, self._desc): for qv, v in zip(ret, self._desc):
qv.set_shape(v.shape) qv.set_shape(v.shape)
...@@ -402,8 +402,8 @@ class ZMQInput(TensorInput): ...@@ -402,8 +402,8 @@ class ZMQInput(TensorInput):
"ZMQInput has to be used with InputDesc!" "ZMQInput has to be used with InputDesc!"
self._desc = inputs_desc self._desc = inputs_desc
from ..user_ops import zmq_recv from ..user_ops import zmq_ops
self._zmq_recv_socket = zmq_recv.ZMQSocket( self._zmq_pull_socket = zmq_ops.ZMQPullSocket(
self._end_point, self._end_point,
[x.type for x in inputs_desc], [x.type for x in inputs_desc],
self._hwm) self._hwm)
......
# $File: Makefile # $File: Makefile
# $Date: Thu Dec 14 18:03:41 2017 -0800 # $Date: Thu Dec 21 14:12:30 2017 -0800
OBJ_DIR = obj OBJ_DIR = obj
PYTHON = python PYTHON = python
...@@ -45,7 +45,7 @@ OBJS = $(addprefix $(OBJ_DIR)/,$(ccSOURCES:.cc=.o)) ...@@ -45,7 +45,7 @@ OBJS = $(addprefix $(OBJ_DIR)/,$(ccSOURCES:.cc=.o))
DEPFILES = $(OBJS:.o=.d) DEPFILES = $(OBJS:.o=.d)
EXT_SUFFIX ?= $(shell $(PYTHON) -c 'import sysconfig; print(sysconfig.get_config_var("EXT_SUFFIX"))') EXT_SUFFIX ?= $(shell $(PYTHON) -c 'import sysconfig; print(sysconfig.get_config_var("EXT_SUFFIX"))')
SO = zmq_recv_op$(EXT_SUFFIX) SO = zmq_ops$(EXT_SUFFIX)
.PHONY: all clean .PHONY: all clean
......
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: test-recv-op.py # File: test-pull-op.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os import os
...@@ -11,8 +11,8 @@ import time ...@@ -11,8 +11,8 @@ import time
import numpy as np import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf # noqa import tensorflow as tf # noqa
from tensorpack.user_ops.zmq_recv import ( # noqa from tensorpack.user_ops.zmq_ops import ( # noqa
ZMQSocket, dumps_zmq_op) ZMQPullSocket, dumps_zmq_op)
from tensorpack.utils.concurrency import ( # noqa from tensorpack.utils.concurrency import ( # noqa
start_proc_mask_signal, start_proc_mask_signal,
ensure_proc_terminate) ensure_proc_terminate)
...@@ -68,7 +68,7 @@ if __name__ == '__main__': ...@@ -68,7 +68,7 @@ if __name__ == '__main__':
start_proc_mask_signal(p) start_proc_mask_signal(p)
sess = tf.Session() sess = tf.Session()
recv = ZMQSocket(ENDPOINT, [tf.float32, tf.uint8]).recv() recv = ZMQPullSocket(ENDPOINT, [tf.float32, tf.uint8]).pull()
print(recv) print(recv)
for truth in DATA: for truth in DATA:
...@@ -87,9 +87,9 @@ if __name__ == '__main__': ...@@ -87,9 +87,9 @@ if __name__ == '__main__':
start_proc_mask_signal(p) start_proc_mask_signal(p)
sess = tf.Session() sess = tf.Session()
zmqsock = ZMQSocket(ENDPOINT, [tf.float32, tf.uint8], hwm=1) zmqsock = ZMQPullSocket(ENDPOINT, [tf.float32, tf.uint8], hwm=1)
recv1 = zmqsock.recv() recv1 = zmqsock.pull()
recv2 = zmqsock.recv() recv2 = zmqsock.pull()
print(recv1, recv2) print(recv1, recv2)
for i in range(args.num // 2): for i in range(args.num // 2):
......
//File: zmq_recv_op.cc //File: zmq_ops.cc
//Author: Yuxin Wu <ppwwyyxxc@gmail.com> //Author: Yuxin Wu <ppwwyyxxc@gmail.com>
#include <string> #include <string>
...@@ -41,9 +41,9 @@ class ZMQConnectionHandleOp : public ResourceOpKernel<ZMQConnection> { ...@@ -41,9 +41,9 @@ class ZMQConnectionHandleOp : public ResourceOpKernel<ZMQConnection> {
}; };
class ZMQRecvOp: public AsyncOpKernel { class ZMQPullOp: public AsyncOpKernel {
public: public:
explicit ZMQRecvOp(OpKernelConstruction* context) : AsyncOpKernel(context) { explicit ZMQPullOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("types", &component_types_)); OP_REQUIRES_OK(context, context->GetAttr("types", &component_types_));
} }
...@@ -71,8 +71,8 @@ class ZMQRecvOp: public AsyncOpKernel { ...@@ -71,8 +71,8 @@ class ZMQRecvOp: public AsyncOpKernel {
TensorShape& shape = tensors[i].shape; TensorShape& shape = tensors[i].shape;
OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(i, shape, &output), done); OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(i, shape, &output), done);
// reinterpret cast and then memcpy // reinterpret cast and then memcpy
auto ptr = output->bit_casted_shaped<char, 1>( auto ptr = output->bit_casted_shaped<char, 1>({shape.num_elements()}).data();
{shape.num_elements() * DataTypeSize(recv_dtype)}).data(); // {shape.num_elements() * DataTypeSize(recv_dtype)}).data();
memcpy(ptr, tensors[i].buf, tensors[i].buf_size); memcpy(ptr, tensors[i].buf, tensors[i].buf_size);
} }
done(); done();
...@@ -84,12 +84,12 @@ class ZMQRecvOp: public AsyncOpKernel { ...@@ -84,12 +84,12 @@ class ZMQRecvOp: public AsyncOpKernel {
}; };
REGISTER_KERNEL_BUILDER(Name("ZMQRecv").Device(DEVICE_CPU), ZMQRecvOp); REGISTER_KERNEL_BUILDER(Name("ZMQPull").Device(DEVICE_CPU), ZMQPullOp);
REGISTER_KERNEL_BUILDER(Name("ZMQConnection").Device(DEVICE_CPU), ZMQConnectionHandleOp); REGISTER_KERNEL_BUILDER(Name("ZMQConnection").Device(DEVICE_CPU), ZMQConnectionHandleOp);
} // namespace tensorpack } // namespace tensorpack
REGISTER_OP("ZMQRecv") REGISTER_OP("ZMQPull")
.Input("handle: resource") .Input("handle: resource")
.Output("output: types") .Output("output: types")
.Attr("types: list(type) >= 1") .Attr("types: list(type) >= 1")
......
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: zmq_recv.py # File: zmq_pull.py
import tensorflow as tf import tensorflow as tf
import struct import struct
...@@ -13,29 +13,29 @@ from tensorflow.core.framework import types_pb2 as DT ...@@ -13,29 +13,29 @@ from tensorflow.core.framework import types_pb2 as DT
from .common import compile, get_ext_suffix from .common import compile, get_ext_suffix
__all__ = ['dumps_zmq_op', 'ZMQSocket'] __all__ = ['dumps_zmq_op', 'ZMQPullSocket']
_zmq_recv_mod = None _zmq_mod = None
def try_build(): def try_build():
file_dir = os.path.dirname(os.path.abspath(__file__)) file_dir = os.path.dirname(os.path.abspath(__file__))
basename = 'zmq_recv_op' + get_ext_suffix() basename = 'zmq_ops' + get_ext_suffix()
so_file = os.path.join(file_dir, basename) so_file = os.path.join(file_dir, basename)
if not os.path.isfile(so_file): if not os.path.isfile(so_file):
ret = compile() ret = compile()
if ret != 0: if ret != 0:
raise RuntimeError("tensorpack user_ops compilation failed!") raise RuntimeError("tensorpack user_ops compilation failed!")
global _zmq_recv_mod global _zmq_mod
_zmq_recv_mod = tf.load_op_library(so_file) _zmq_mod = tf.load_op_library(so_file)
try_build() try_build()
class ZMQSocket(object): class ZMQPullSocket(object):
def __init__(self, end_point, types, hwm=None, bind=True, name=None): def __init__(self, end_point, types, hwm=None, bind=True, name=None):
self._types = types self._types = types
assert isinstance(bind, bool), bind assert isinstance(bind, bool), bind
...@@ -46,15 +46,15 @@ class ZMQSocket(object): ...@@ -46,15 +46,15 @@ class ZMQSocket(object):
else: else:
self._name = name self._name = name
self._zmq_handle = _zmq_recv_mod.zmq_connection( self._zmq_handle = _zmq_mod.zmq_connection(
end_point, hwm, bind=bind, shared_name=self._name) end_point, hwm, bind=bind, shared_name=self._name)
@property @property
def name(self): def name(self):
return self._name return self._name
def recv(self): def pull(self):
return _zmq_recv_mod.zmq_recv( return _zmq_mod.zmq_pull(
self._zmq_handle, self._types) self._zmq_handle, self._types)
...@@ -147,7 +147,7 @@ def dump_tensor_protos(protos): ...@@ -147,7 +147,7 @@ def dump_tensor_protos(protos):
def dumps_zmq_op(dp): def dumps_zmq_op(dp):
""" """
Dump a datapoint (list of nparray) into a format that the ZMQRecv op in tensorpack would accept. Dump a datapoint (list of nparray) into a format that the ZMQPull op in tensorpack would accept.
Args: Args:
dp: list of nparray dp: list of nparray
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment