Commit f59d1377 authored by Yuxin Wu's avatar Yuxin Wu

fix LMDB write memory bug

parent 60200fc1
......@@ -34,7 +34,6 @@ To visualize on test set:
"""
SHAPE = 256
BATCH = 1
IN_CH = 3
OUT_CH = 3
......@@ -44,6 +43,7 @@ NF = 64 # number of filter
class Model(GANModelDesc):
def _get_inputs(self):
SHAPE = 256
return [InputDesc(tf.float32, (None, SHAPE, SHAPE, IN_CH), 'input'),
InputDesc(tf.float32, (None, SHAPE, SHAPE, OUT_CH), 'output')]
......@@ -159,8 +159,7 @@ def get_data():
ds = ImageFromFile(imgs, channel=3, shuffle=True)
ds = MapData(ds, lambda dp: split_input(dp[0]))
assert SHAPE < 286 # this is the parameter used in the paper
augs = [imgaug.Resize(286), imgaug.RandomCrop(SHAPE)]
augs = [imgaug.Resize(286), imgaug.RandomCrop(256)]
ds = AugmentImageComponents(ds, augs, (0, 1))
ds = BatchData(ds, BATCH)
ds = PrefetchData(ds, 100, 1)
......
......@@ -6,8 +6,7 @@ import sys
import os
import cv2
import multiprocessing as mp
import six
from six.moves import range, map
from six.moves import range
from .base import DataFlow
from ..utils import get_tqdm, logger
......@@ -68,15 +67,17 @@ def dump_dataflow_to_lmdb(ds, lmdb_path):
except NotImplementedError:
sz = 0
with get_tqdm(total=sz) as pbar:
with db.begin(write=True) as txn:
for idx, dp in enumerate(ds.get_data()):
txn.put(six.binary_type(idx), dumps(dp))
for idx, dp in enumerate(ds.get_data()):
with db.begin(write=True) as txn:
txn.put(u'{}'.format(idx).encode('ascii'), dumps(dp))
pbar.update()
keys = list(map(six.binary_type, range(idx + 1)))
txn.put('__keys__', dumps(keys))
keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
with db.begin(write=True) as txn:
txn.put(b'__keys__', dumps(keys))
logger.info("Flushing database ...")
db.sync()
logger.info("Flushing database ...")
db.sync()
db.close()
try:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment