Commit 1ccb94f6 authored by Yuxin Wu's avatar Yuxin Wu

fix import in a nicer way

parent 9f64aa4c
......@@ -9,8 +9,8 @@ import msgpack_numpy
import struct
import numpy as np
from tensorflow.core.framework.tensor_pb2 import TensorProto
# import tensorflow.core.framework.types_pb2 as DataType
from tensorflow.core.framework.types_pb2 import * # noqa
from tensorflow.core.framework import types_pb2 as DataType
# have to import like this: https://github.com/tensorflow/tensorflow/commit/955f038afbeb81302cea43058078e68574000bce
msgpack_numpy.patch()
......@@ -37,11 +37,11 @@ def loads(buf):
_DTYPE_DICT = {
np.float32: DT_FLOAT, # noqa
np.float64: DT_DOUBLE, # noqa
np.int32: DT_INT32, # noqa
np.int8: DT_INT8, # noqa
np.uint8: DT_UINT8, # noqa
np.float32: DataType.DT_FLOAT,
np.float64: DataType.DT_DOUBLE,
np.int32: DataType.DT_INT32,
np.int8: DataType.DT_INT8,
np.uint8: DataType.DT_UINT8,
}
_DTYPE_DICT = {np.dtype(k): v for k, v in _DTYPE_DICT.items()}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment