Commit f59d1377 authored by Yuxin Wu's avatar Yuxin Wu

fix LMDB write memory bug

parent 60200fc1
...@@ -34,7 +34,6 @@ To visualize on test set: ...@@ -34,7 +34,6 @@ To visualize on test set:
""" """
SHAPE = 256
BATCH = 1 BATCH = 1
IN_CH = 3 IN_CH = 3
OUT_CH = 3 OUT_CH = 3
...@@ -44,6 +43,7 @@ NF = 64 # number of filter ...@@ -44,6 +43,7 @@ NF = 64 # number of filter
class Model(GANModelDesc): class Model(GANModelDesc):
def _get_inputs(self): def _get_inputs(self):
SHAPE = 256
return [InputDesc(tf.float32, (None, SHAPE, SHAPE, IN_CH), 'input'), return [InputDesc(tf.float32, (None, SHAPE, SHAPE, IN_CH), 'input'),
InputDesc(tf.float32, (None, SHAPE, SHAPE, OUT_CH), 'output')] InputDesc(tf.float32, (None, SHAPE, SHAPE, OUT_CH), 'output')]
...@@ -159,8 +159,7 @@ def get_data(): ...@@ -159,8 +159,7 @@ def get_data():
ds = ImageFromFile(imgs, channel=3, shuffle=True) ds = ImageFromFile(imgs, channel=3, shuffle=True)
ds = MapData(ds, lambda dp: split_input(dp[0])) ds = MapData(ds, lambda dp: split_input(dp[0]))
assert SHAPE < 286 # this is the parameter used in the paper augs = [imgaug.Resize(286), imgaug.RandomCrop(256)]
augs = [imgaug.Resize(286), imgaug.RandomCrop(SHAPE)]
ds = AugmentImageComponents(ds, augs, (0, 1)) ds = AugmentImageComponents(ds, augs, (0, 1))
ds = BatchData(ds, BATCH) ds = BatchData(ds, BATCH)
ds = PrefetchData(ds, 100, 1) ds = PrefetchData(ds, 100, 1)
......
...@@ -6,8 +6,7 @@ import sys ...@@ -6,8 +6,7 @@ import sys
import os import os
import cv2 import cv2
import multiprocessing as mp import multiprocessing as mp
import six from six.moves import range
from six.moves import range, map
from .base import DataFlow from .base import DataFlow
from ..utils import get_tqdm, logger from ..utils import get_tqdm, logger
...@@ -68,15 +67,17 @@ def dump_dataflow_to_lmdb(ds, lmdb_path): ...@@ -68,15 +67,17 @@ def dump_dataflow_to_lmdb(ds, lmdb_path):
except NotImplementedError: except NotImplementedError:
sz = 0 sz = 0
with get_tqdm(total=sz) as pbar: with get_tqdm(total=sz) as pbar:
with db.begin(write=True) as txn: for idx, dp in enumerate(ds.get_data()):
for idx, dp in enumerate(ds.get_data()): with db.begin(write=True) as txn:
txn.put(six.binary_type(idx), dumps(dp)) txn.put(u'{}'.format(idx).encode('ascii'), dumps(dp))
pbar.update() pbar.update()
keys = list(map(six.binary_type, range(idx + 1))) keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
txn.put('__keys__', dumps(keys)) with db.begin(write=True) as txn:
txn.put(b'__keys__', dumps(keys))
logger.info("Flushing database ...") logger.info("Flushing database ...")
db.sync() db.sync()
db.close()
try: try:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment