Commit 264b1e4e authored by Yuxin Wu's avatar Yuxin Wu

fix TFRecordSerializer when msgpack is used

parent b17a098d
...@@ -124,10 +124,17 @@ class TFRecordSerializer(): ...@@ -124,10 +124,17 @@ class TFRecordSerializer():
df (DataFlow): the DataFlow to serialize. df (DataFlow): the DataFlow to serialize.
path (str): output tfrecord file. path (str): output tfrecord file.
""" """
if os.environ.get('TENSORPACK_SERIALIZE', None) == 'msgpack':
def _dumps(dp):
return dumps(dp)
else:
def _dumps(dp):
return dumps(dp).to_pybytes()
size = _reset_df_and_get_size(df) size = _reset_df_and_get_size(df)
with tf.python_io.TFRecordWriter(path) as writer, get_tqdm(total=size) as pbar: with tf.python_io.TFRecordWriter(path) as writer, get_tqdm(total=size) as pbar:
for dp in df.get_data(): for dp in df.get_data():
writer.write(dumps(dp).to_pybytes()) writer.write(_dumps(dp))
pbar.update() pbar.update()
@staticmethod @staticmethod
......
...@@ -61,7 +61,7 @@ except ImportError: ...@@ -61,7 +61,7 @@ except ImportError:
dumps_msgpack = create_dummy_func( # noqa dumps_msgpack = create_dummy_func( # noqa
'dumps_msgpack', ['msgpack', 'msgpack_numpy']) 'dumps_msgpack', ['msgpack', 'msgpack_numpy'])
if pa is None or os.environ.get('TENSORPACK_SERIALIZE', None) == 'msgpack': if os.environ.get('TENSORPACK_SERIALIZE', None) == 'msgpack':
loads = loads_msgpack loads = loads_msgpack
dumps = dumps_msgpack dumps = dumps_msgpack
else: else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment