Commit 10113750 authored by Yuxin Wu

add TensorProto serialization support

parent 59d82d13
@@ -6,7 +6,7 @@ All tutorials are drafts for now. You can get an idea from them but the details
might not be correct.
.. toctree::
-   :maxdepth: 2
+   :maxdepth: 3
tutorial/index
casestudies/index
...
@@ -7,20 +7,20 @@ A High Level Glance
* :doc:`dataflow` is a set of extensible tools to help you define your input data with ease and speed.
-  It provides a uniformed interface so data processing modules can be chained together.
-  It allows you to load and process your data in pure Python and accelerate it by multiprocess prefetch.
+  It provides a uniform interface, so data processing modules can be chained together.
+  It allows you to load and process your data in pure Python and accelerate it by prefetching.
See also :doc:`tf-queue` and :doc:`efficient-dataflow` for more details about the efficiency of data
processing.
* You can use any TF-based symbolic function library to define a model in tensorpack.
:doc:`model` introduces where and how you define the model for tensorpack trainers to use,
-  and how you can benefit from the symbolic function library in tensorpack.
+  and how you can benefit from the small symbolic function library in tensorpack.
Both DataFlow and models can be used outside tensorpack, as just a data processing library and a symbolic
function library. Tensorpack trainers integrate these two components and add more convenient features.
* tensorpack :doc:`trainer` manages the training loops for you so you won't have to worry about
-  details such as multi-GPU training. At the same time it keeps the power of customization to you
+  details such as multi-GPU training. At the same time it keeps the power of customization
through callbacks.
* Callbacks are like ``tf.train.SessionRunHook``, or plugins, or extensions. During training,
...
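The chaining described in the dataflow bullet above can be made concrete with a short sketch. It is not part of this commit and assumes the dataset.Mnist, BatchData and PrefetchData classes from tensorpack.dataflow as they existed around this time, so names may differ in other versions:

    from tensorpack.dataflow import dataset, BatchData, PrefetchData

    df = dataset.Mnist('train')    # a DataFlow yielding [image, label] datapoints
    df = BatchData(df, 128)        # chained: groups datapoints into batches of 128
    df = PrefetchData(df, 256, 2)  # chained: prefetches batches in 2 extra processes

    df.reset_state()               # must be called once before iterating
    for dp in df.get_data():      # plain Python iteration over the processed data
        pass                       # each dp is a batched [images, labels] pair

Each wrapper takes a DataFlow and is itself a DataFlow, which is what makes these modules composable.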
@@ -5,8 +5,15 @@
import msgpack
import msgpack_numpy
import struct
+import numpy as np
+from tensorflow.core.framework.tensor_pb2 import TensorProto
+import tensorflow.core.framework.types_pb2 as DataType
msgpack_numpy.patch()
__all__ = ['loads', 'dumps']
@@ -26,3 +33,55 @@ def loads(buf):
buf (str): serialized object.
"""
return msgpack.loads(buf)
+
+# map numpy dtypes to the corresponding TensorFlow DataType enum values
+_DTYPE_DICT = {
+    np.float32: DataType.DT_FLOAT,
+    np.float64: DataType.DT_DOUBLE,
+    np.int32: DataType.DT_INT32,
+    np.int8: DataType.DT_INT8,
+    np.uint8: DataType.DT_UINT8,
+}
+# normalize keys to np.dtype instances so that lookup by arr.dtype works
+_DTYPE_DICT = {np.dtype(k): v for k, v in _DTYPE_DICT.items()}
+
+
+# TODO support string tensor
+def to_tensor_proto(arr):
+    """
+    Convert a numpy array to a TensorProto.
+
+    Args:
+        arr: numpy.ndarray. Only common numerical types are supported.
+    """
+    dtype = _DTYPE_DICT[arr.dtype]  # raises KeyError on unsupported dtypes
+    ret = TensorProto()
+    shape = ret.tensor_shape
+    for s in arr.shape:
+        d = shape.dim.add()
+        d.size = s
+    ret.dtype = dtype
+    buf = arr.tobytes()  # raw row-major bytes of the array
+    ret.tensor_content = buf
+    return ret
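A quick round-trip check (not part of the commit) shows what the proto carries. It decodes tensor_content directly with numpy instead of relying on any particular TensorFlow helper:

    arr = np.arange(6, dtype=np.float32).reshape(2, 3)
    proto = to_tensor_proto(arr)

    # rebuild the array from the proto's shape and raw bytes
    shape = tuple(d.size for d in proto.tensor_shape.dim)
    recovered = np.frombuffer(proto.tensor_content, dtype=np.float32).reshape(shape)
    assert np.array_equal(arr, recovered)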
+
+
+def dump_tensor_protos(protos):
+    """
+    Serialize a list of :class:`TensorProto`, for communication between custom TensorFlow ops.
+
+    Args:
+        protos (list): a list of :class:`TensorProto` instances.
+
+    Notes:
+        The format is: <#protos(int32)>|<size 1>|<serialized proto 1>|<size 2>|<serialized proto 2>| ...
+    """
+    s = struct.pack('=i', len(protos))  # number of protos, as a native-order int32
+    for p in protos:
+        buf = p.SerializeToString()
+        s += struct.pack('=i', len(buf))  # int32 size prefix; won't send anything over 2G
+        s += buf
+    return s
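The commit only adds the writer side of this framing. A reader reverses it by walking the length prefixes; the load_tensor_protos helper below is a hypothetical sketch for illustration, not part of the commit:

    def load_tensor_protos(s):
        # parse <#protos(int32)>|<size 1>|<proto 1>|... back into TensorProto objects
        num, = struct.unpack('=i', s[:4])
        offset = 4
        protos = []
        for _ in range(num):
            size, = struct.unpack('=i', s[offset:offset + 4])
            offset += 4
            p = TensorProto()
            p.ParseFromString(s[offset:offset + size])
            offset += size
            protos.append(p)
        return protos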