Commit 273e6f91 authored by Yuxin Wu's avatar Yuxin Wu

add CacheData

parent 229f2dac
......@@ -48,7 +48,7 @@ class Model(GANModelDesc):
def generator(self, img):
with argscope([Conv2D, Deconv2D],
nl=BNLReLU, kernel_shape=4, stride=2, use_bias=False), \
nl=BNLReLU, kernel_shape=4, stride=2), \
argscope(Deconv2D, nl=BNReLU):
l = (LinearWrap(img)
.Conv2D('conv0', NF, nl=LeakyReLU)
......@@ -107,8 +107,8 @@ class Model(GANModelDesc):
viz_A_recon = tf.concat([A, AB, ABA], axis=3, name='viz_A_recon')
viz_B_recon = tf.concat([B, BA, BAB], axis=3, name='viz_B_recon')
tf.summary.image('Arecon', tf.transpose(viz_A_recon, [0, 2, 3, 1]), max_outputs=30)
tf.summary.image('Brecon', tf.transpose(viz_B_recon, [0, 2, 3, 1]), max_outputs=30)
tf.summary.image('Arecon', tf.transpose(viz_A_recon, [0, 2, 3, 1]), max_outputs=50)
tf.summary.image('Brecon', tf.transpose(viz_B_recon, [0, 2, 3, 1]), max_outputs=50)
with tf.variable_scope('discrim'):
with tf.variable_scope('A'):
......@@ -220,7 +220,7 @@ if __name__ == '__main__':
dataflow=data,
callbacks=[ModelSaver()],
steps_per_epoch=300,
max_epoch=200,
max_epoch=1000,
session_init=SaverRestore(args.load) if args.load else None
)
......
......@@ -8,12 +8,12 @@ from termcolor import colored
from collections import deque, defaultdict
from six.moves import range, map
from .base import DataFlow, ProxyDataFlow, RNGDataFlow
from ..utils import logger, get_tqdm
from ..utils import logger, get_tqdm, get_rng
__all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData',
'MapDataComponent', 'RepeatedData', 'RandomChooseData',
'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
'LocallyShuffleData']
'LocallyShuffleData', 'CacheData']
class TestDataSpeed(ProxyDataFlow):
......@@ -499,6 +499,36 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
self._add_data()
class CacheData(ProxyDataFlow):
"""
Cache a dataflow completely in memory.
"""
def __init__(self, ds, shuffle=False):
"""
Args:
ds (DataFlow): input DataFlow.
shuffle (bool): whether to shuffle the datapoints before producing them.
"""
self.shuffle = shuffle
super(CacheData, self).__init__(ds)
def reset_state(self):
super(CacheData, self).reset_state()
if self.shuffle:
self.rng = get_rng(self)
self.buffer = []
def get_data(self):
if len(self.buffer):
self.rng.shuffle(self.buffer)
for dp in self.buffer:
yield dp
else:
for dp in self.ds.get_data():
yield dp
self.buffer.append(dp)
class PrintData(ProxyDataFlow):
"""
Behave like an identity mapping but print shapes of produced datapoints once during construction.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment