Commit bab16832 authored by Yuxin Wu's avatar Yuxin Wu

[ZMQ] support more tensor types (#362)

parent d1ba5969
......@@ -8,7 +8,7 @@ import numpy as np
import os
from tensorflow.core.framework.tensor_pb2 import TensorProto
from tensorflow.core.framework import types_pb2 as DataType
from tensorflow.core.framework import types_pb2 as DT
# have to import like this: https://github.com/tensorflow/tensorflow/commit/955f038afbeb81302cea43058078e68574000bce
from .common import compile, get_ext_suffix
......@@ -58,12 +58,26 @@ class ZMQRecv(object):
self._zmq_handle, self._types)
# copied from tensorflow/python/framework/dtypes.py
_DTYPE_DICT = {
np.float32: DataType.DT_FLOAT,
np.float64: DataType.DT_DOUBLE,
np.int32: DataType.DT_INT32,
np.int8: DataType.DT_INT8,
np.uint8: DataType.DT_UINT8,
np.float16: DT.DT_HALF,
np.float32: DT.DT_FLOAT,
np.float64: DT.DT_DOUBLE,
np.uint8: DT.DT_UINT8,
np.uint16: DT.DT_UINT16,
np.uint32: DT.DT_UINT32,
np.uint64: DT.DT_UINT64,
np.int64: DT.DT_INT64,
np.int32: DT.DT_INT32,
np.int16: DT.DT_INT16,
np.int8: DT.DT_INT8,
np.complex64: DT.DT_COMPLEX64,
np.complex128: DT.DT_COMPLEX128,
np.bool: DT.DT_BOOL,
}
_DTYPE_DICT = {np.dtype(k): v for k, v in _DTYPE_DICT.items()}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment