Commit 1ccb94f6 authored by Yuxin Wu's avatar Yuxin Wu

fix import in a nicer way

parent 9f64aa4c
...@@ -9,8 +9,8 @@ import msgpack_numpy ...@@ -9,8 +9,8 @@ import msgpack_numpy
import struct import struct
import numpy as np import numpy as np
from tensorflow.core.framework.tensor_pb2 import TensorProto from tensorflow.core.framework.tensor_pb2 import TensorProto
# import tensorflow.core.framework.types_pb2 as DataType from tensorflow.core.framework import types_pb2 as DataType
from tensorflow.core.framework.types_pb2 import * # noqa # have to import like this: https://github.com/tensorflow/tensorflow/commit/955f038afbeb81302cea43058078e68574000bce
msgpack_numpy.patch() msgpack_numpy.patch()
...@@ -37,11 +37,11 @@ def loads(buf): ...@@ -37,11 +37,11 @@ def loads(buf):
_DTYPE_DICT = { _DTYPE_DICT = {
np.float32: DT_FLOAT, # noqa np.float32: DataType.DT_FLOAT,
np.float64: DT_DOUBLE, # noqa np.float64: DataType.DT_DOUBLE,
np.int32: DT_INT32, # noqa np.int32: DataType.DT_INT32,
np.int8: DT_INT8, # noqa np.int8: DataType.DT_INT8,
np.uint8: DT_UINT8, # noqa np.uint8: DataType.DT_UINT8,
} }
_DTYPE_DICT = {np.dtype(k): v for k, v in _DTYPE_DICT.items()} _DTYPE_DICT = {np.dtype(k): v for k, v in _DTYPE_DICT.items()}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment