Commit ef1e9611 authored by Yuxin Wu

bugfix in lmdb datapoint loader

parent 8dbb84e2
@@ -8,7 +8,6 @@ import tensorflow as tf
 import argparse
 import numpy as np
 import multiprocessing
-import msgpack
 import os
 import sys
@@ -179,7 +179,7 @@ class BatchDataByShape(BatchData):
 class FixedSizeData(ProxyDataFlow):
     """ Generate data from another DataFlow, but with a fixed size.
-        The state of the underlying DataFlow won't be reset when it's exhausted.
+        The iterator of the underlying DataFlow will be kept if not exhausted.
     """
     def __init__(self, ds, size):
         """
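
A minimal sketch of what the reworded docstring means in practice: FixedSizeData keeps the underlying iterator alive between calls instead of restarting it. This assumes the get_data()-era DataFlow API and that DataFromList and FixedSizeData are importable from tensorpack.dataflow in this version; the data below is only illustrative.

    # Sketch only: a second pass continues from where the first one stopped.
    from tensorpack.dataflow import DataFromList, FixedSizeData

    ds = FixedSizeData(DataFromList([[i] for i in range(5)], shuffle=False), 3)
    ds.reset_state()
    print(list(ds.get_data()))   # [[0], [1], [2]]
    print(list(ds.get_data()))   # [[3], [4], ...]  (continues from the kept iterator)
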
......@@ -90,19 +90,17 @@ class LMDBData(RNGDataFlow):
with timed_operation("Loading LMDB keys ...", log_start=True), \
get_tqdm(total=size) as pbar:
for k in self._txn.cursor():
assert k[0] != '__keys__'
assert k[0] != b'__keys__'
keys.append(k[0])
pbar.update()
return keys
try:
self.keys = loads(self._txn.get('__keys__'))
except:
self.keys = None
else:
self.keys = self._txn.get(b'__keys__')
if self.keys is not None:
self.keys = loads(self.keys)
self._size -= 1 # delete this item
if self._shuffle:
if self._shuffle: # keys are necessary when shuffle is True
if keys is None:
if self.keys is None:
self.keys = find_keys(self._txn, self._size)
@@ -133,7 +131,7 @@ class LMDBData(RNGDataFlow):
             c = self._txn.cursor()
             while c.next():
                 k, v = c.item()
-                if k != '__keys__':
+                if k != b'__keys__':
                     yield [k, v]
         else:
             self.rng.shuffle(self.keys)
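
The b'__keys__' changes in the two hunks above come from the fact that, under Python 3, the lmdb binding returns keys and values as bytes: a str comparison like k != '__keys__' is always True, and txn.get() must be given a bytes key to find the special entry. A standalone sketch of the same pattern using the lmdb package directly (the database path is only an example):

    import lmdb

    env = lmdb.open('/path/to/data.lmdb', readonly=True, lock=False)
    with env.begin() as txn:
        keys = txn.get(b'__keys__')      # bytes key; returns None if the entry is absent
        for k, v in txn.cursor():        # k and v are bytes under Python 3
            if k != b'__keys__':
                print(k)
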
@@ -17,7 +17,7 @@ else:
 __all__ = ['send_dataflow_zmq', 'RemoteDataZMQ']
 
 
-def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format='msgpack'):
+def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format=None):
     """
     Run DataFlow and send data to a ZMQ socket addr.
     It will dump and send each datapoint to this addr with a PUSH socket.
@@ -26,11 +26,10 @@ def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format='msgpack'):
         df (DataFlow): Will infinitely loop over the DataFlow.
         addr: a ZMQ socket addr.
         hwm (int): high water mark
-        format (str): The serialization format.
-            'msgpack' is the default format corresponding to RemoteDataZMQ.
-            Otherwise will use the format corresponding to the ZMQRecv TensorFlow Op.
     """
-    dump_fn = dumps if format == 'msgpack' else dumps_for_tfop
+    # format (str): The serialization format. ZMQ Op is still not publicly usable now
+    #     Default format would use :mod:`tensorpack.utils.serialize`.
+    dump_fn = dumps if format is None else dumps_for_tfop
     ctx = zmq.Context()
     socket = ctx.socket(zmq.PUSH)
     socket.set_hwm(hwm)
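
For context, a hedged usage sketch of the new default: with format=None, datapoints are serialized by tensorpack.utils.serialize.dumps and can be pulled back by RemoteDataZMQ. The address and the sample DataFlow below are only illustrative, and the sender and receiver would normally run in separate processes.

    from tensorpack.dataflow import DataFromList
    from tensorpack.dataflow.remote import send_dataflow_zmq, RemoteDataZMQ

    addr = 'ipc:///tmp/some-dataflow-pipe'     # example address only

    # sender process: loops over df forever, PUSHing serialized datapoints
    df = DataFromList([[i, i * i] for i in range(100)], shuffle=False)
    df.reset_state()
    # send_dataflow_zmq(df, addr)              # default format -> dumps()

    # receiver process: PULLs datapoints and deserializes them
    # receiver = RemoteDataZMQ(addr)
    # receiver.reset_state()
    # for dp in receiver.get_data():
    #     pass
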