Commit ef1e9611 authored by Yuxin Wu

bugfix in lmdb datapoint loader

parent 8dbb84e2
@@ -8,7 +8,6 @@ import tensorflow as tf
 import argparse
 import numpy as np
 import multiprocessing
-import msgpack
 import os
 import sys
...
@@ -179,7 +179,7 @@ class BatchDataByShape(BatchData):
 class FixedSizeData(ProxyDataFlow):
     """ Generate data from another DataFlow, but with a fixed size.
-        The state of the underlying DataFlow won't be reset when it's exhausted.
+        The iterator of the underlying DataFlow will be kept if not exhausted.
     """
     def __init__(self, ds, size):
         """
...
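For readers outside the diff context, a minimal sketch of the behavior the new docstring promises. The names are hypothetical and this is not tensorpack's actual implementation: the wrapper yields exactly `size` datapoints per epoch, and the underlying iterator survives between epochs instead of being reset.

class FixedSizeSketch(object):
    """Hypothetical stand-in for FixedSizeData: yield exactly `size`
    datapoints per epoch, keeping the underlying iterator across epochs
    and restarting it only when it is actually exhausted."""
    def __init__(self, ds, size):
        self.ds = ds        # any object with a get_data() generator
        self._size = size
        self.itr = None     # persists across calls to get_data()

    def get_data(self):
        if self.itr is None:
            self.itr = self.ds.get_data()
        cnt = 0
        while True:
            try:
                dp = next(self.itr)
            except StopIteration:
                # restart the underlying iterator only on exhaustion
                self.itr = self.ds.get_data()
                dp = next(self.itr)
            cnt += 1
            yield dp
            if cnt == self._size:
                return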
@@ -90,19 +90,17 @@ class LMDBData(RNGDataFlow):
             with timed_operation("Loading LMDB keys ...", log_start=True), \
                     get_tqdm(total=size) as pbar:
                 for k in self._txn.cursor():
-                    assert k[0] != '__keys__'
+                    assert k[0] != b'__keys__'
                     keys.append(k[0])
                     pbar.update()
             return keys
 
-        try:
-            self.keys = loads(self._txn.get('__keys__'))
-        except:
-            self.keys = None
-        else:
+        self.keys = self._txn.get(b'__keys__')
+        if self.keys is not None:
+            self.keys = loads(self.keys)
             self._size -= 1  # delete this item
-        if self._shuffle:
+        if self._shuffle:   # keys are necessary when shuffle is True
             if keys is None:
                 if self.keys is None:
                     self.keys = find_keys(self._txn, self._size)
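Why the b prefixes matter: py-lmdb stores keys and values as bytes, so under Python 3 the old str comparison and txn.get('__keys__') could never match. Also, Transaction.get returns None for a missing key rather than raising, which is why the bare try/except is replaced with an explicit None check (the old except also swallowed unrelated errors). A minimal sketch of the fixed lookup, with a hypothetical path and msgpack standing in for tensorpack.utils.serialize:

import lmdb
import msgpack

# hypothetical path, for illustration only
db = lmdb.open('/path/to/data.lmdb', subdir=False, readonly=True)
with db.begin() as txn:
    raw = txn.get(b'__keys__')   # bytes key; returns None if absent, no exception
    keys = msgpack.loads(raw) if raw is not None else None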
@@ -133,7 +131,7 @@ class LMDBData(RNGDataFlow):
             c = self._txn.cursor()
             while c.next():
                 k, v = c.item()
-                if k != '__keys__':
+                if k != b'__keys__':
                     yield [k, v]
         else:
             self.rng.shuffle(self.keys)
...
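For context on where b'__keys__' comes from: the LMDB dumper stores every datapoint under a bytes key and then stores the serialized key list under b'__keys__', which is why both readers above must skip that entry. A hedged sketch of that writer side (function name, path handling, and the msgpack stand-in are illustrative, not the exact tensorpack code):

import lmdb
import msgpack

def write_lmdb_sketch(datapoints, path):
    db = lmdb.open(path, subdir=False, map_size=1 << 30)
    keys = []
    with db.begin(write=True) as txn:
        for idx, dp in enumerate(datapoints):
            k = '{:08d}'.format(idx).encode('ascii')   # keys must be bytes
            txn.put(k, msgpack.dumps(dp, use_bin_type=True))
            keys.append(k)
        # the metadata entry that every reader must skip / special-case
        txn.put(b'__keys__', msgpack.dumps(keys, use_bin_type=True))
    db.close()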
@@ -17,7 +17,7 @@ else:
 __all__ = ['send_dataflow_zmq', 'RemoteDataZMQ']
 
-def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format='msgpack'):
+def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format=None):
     """
     Run DataFlow and send data to a ZMQ socket addr.
     It will dump and send each datapoint to this addr with a PUSH socket.
@@ -26,11 +26,10 @@ def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format='msgpack'):
         df (DataFlow): Will infinitely loop over the DataFlow.
         addr: a ZMQ socket addr.
         hwm (int): high water mark
-        format (str): The serialization format.
-            'msgpack' is the default format corresponding to RemoteDataZMQ.
-            Otherwise will use the format corresponding to the ZMQRecv TensorFlow Op.
     """
-    dump_fn = dumps if format == 'msgpack' else dumps_for_tfop
+    # format (str): The serialization format. ZMQ Op is still not publicly usable now
+    #   Default format would use :mod:`tensorpack.utils.serialize`.
+    dump_fn = dumps if format is None else dumps_for_tfop
     ctx = zmq.Context()
     socket = ctx.socket(zmq.PUSH)
     socket.set_hwm(hwm)
...
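To make the PUSH side concrete, a self-contained sketch of what send_dataflow_zmq does with format=None, assuming msgpack serialization comparable to tensorpack.utils.serialize.dumps; the function name and address below are hypothetical:

import zmq
import msgpack

def send_sketch(get_data, addr='tcp://127.0.0.1:8877', hwm=50):
    ctx = zmq.Context()
    socket = ctx.socket(zmq.PUSH)
    socket.set_hwm(hwm)     # bound the queue so a slow receiver backpressures the sender
    socket.connect(addr)
    try:
        while True:         # loop over the DataFlow forever, as the docstring says
            for dp in get_data():
                socket.send(msgpack.dumps(dp, use_bin_type=True), copy=False)
    finally:
        socket.close()
        ctx.term()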