Commit 264b1e4e authored by Yuxin Wu

fix TFRecordSerializer when msgpack is used

parent b17a098d
......@@ -124,10 +124,17 @@ class TFRecordSerializer():
df (DataFlow): the DataFlow to serialize.
path (str): output tfrecord file.
"""
if os.environ.get('TENSORPACK_SERIALIZE', None) == 'msgpack':
def _dumps(dp):
return dumps(dp)
else:
def _dumps(dp):
return dumps(dp).to_pybytes()
size = _reset_df_and_get_size(df)
with tf.python_io.TFRecordWriter(path) as writer, get_tqdm(total=size) as pbar:
for dp in df.get_data():
writer.write(dumps(dp).to_pybytes())
writer.write(_dumps(dp))
pbar.update()
@staticmethod
......
......@@ -61,7 +61,7 @@ except ImportError:
dumps_msgpack = create_dummy_func( # noqa
'dumps_msgpack', ['msgpack', 'msgpack_numpy'])
if pa is None or os.environ.get('TENSORPACK_SERIALIZE', None) == 'msgpack':
if os.environ.get('TENSORPACK_SERIALIZE', None) == 'msgpack':
loads = loads_msgpack
dumps = dumps_msgpack
else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment