Commit 273e6f91 authored by Yuxin Wu's avatar Yuxin Wu

add CacheData

parent 229f2dac
...@@ -48,7 +48,7 @@ class Model(GANModelDesc): ...@@ -48,7 +48,7 @@ class Model(GANModelDesc):
def generator(self, img): def generator(self, img):
with argscope([Conv2D, Deconv2D], with argscope([Conv2D, Deconv2D],
nl=BNLReLU, kernel_shape=4, stride=2, use_bias=False), \ nl=BNLReLU, kernel_shape=4, stride=2), \
argscope(Deconv2D, nl=BNReLU): argscope(Deconv2D, nl=BNReLU):
l = (LinearWrap(img) l = (LinearWrap(img)
.Conv2D('conv0', NF, nl=LeakyReLU) .Conv2D('conv0', NF, nl=LeakyReLU)
...@@ -107,8 +107,8 @@ class Model(GANModelDesc): ...@@ -107,8 +107,8 @@ class Model(GANModelDesc):
viz_A_recon = tf.concat([A, AB, ABA], axis=3, name='viz_A_recon') viz_A_recon = tf.concat([A, AB, ABA], axis=3, name='viz_A_recon')
viz_B_recon = tf.concat([B, BA, BAB], axis=3, name='viz_B_recon') viz_B_recon = tf.concat([B, BA, BAB], axis=3, name='viz_B_recon')
tf.summary.image('Arecon', tf.transpose(viz_A_recon, [0, 2, 3, 1]), max_outputs=30) tf.summary.image('Arecon', tf.transpose(viz_A_recon, [0, 2, 3, 1]), max_outputs=50)
tf.summary.image('Brecon', tf.transpose(viz_B_recon, [0, 2, 3, 1]), max_outputs=30) tf.summary.image('Brecon', tf.transpose(viz_B_recon, [0, 2, 3, 1]), max_outputs=50)
with tf.variable_scope('discrim'): with tf.variable_scope('discrim'):
with tf.variable_scope('A'): with tf.variable_scope('A'):
...@@ -220,7 +220,7 @@ if __name__ == '__main__': ...@@ -220,7 +220,7 @@ if __name__ == '__main__':
dataflow=data, dataflow=data,
callbacks=[ModelSaver()], callbacks=[ModelSaver()],
steps_per_epoch=300, steps_per_epoch=300,
max_epoch=200, max_epoch=1000,
session_init=SaverRestore(args.load) if args.load else None session_init=SaverRestore(args.load) if args.load else None
) )
......
...@@ -8,12 +8,12 @@ from termcolor import colored ...@@ -8,12 +8,12 @@ from termcolor import colored
from collections import deque, defaultdict from collections import deque, defaultdict
from six.moves import range, map from six.moves import range, map
from .base import DataFlow, ProxyDataFlow, RNGDataFlow from .base import DataFlow, ProxyDataFlow, RNGDataFlow
from ..utils import logger, get_tqdm from ..utils import logger, get_tqdm, get_rng
__all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData', __all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData',
'MapDataComponent', 'RepeatedData', 'RandomChooseData', 'MapDataComponent', 'RepeatedData', 'RandomChooseData',
'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent', 'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
'LocallyShuffleData'] 'LocallyShuffleData', 'CacheData']
class TestDataSpeed(ProxyDataFlow): class TestDataSpeed(ProxyDataFlow):
...@@ -499,6 +499,36 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow): ...@@ -499,6 +499,36 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
self._add_data() self._add_data()
class CacheData(ProxyDataFlow):
"""
Cache a dataflow completely in memory.
"""
def __init__(self, ds, shuffle=False):
"""
Args:
ds (DataFlow): input DataFlow.
shuffle (bool): whether to shuffle the datapoints before producing them.
"""
self.shuffle = shuffle
super(CacheData, self).__init__(ds)
def reset_state(self):
super(CacheData, self).reset_state()
if self.shuffle:
self.rng = get_rng(self)
self.buffer = []
def get_data(self):
if len(self.buffer):
self.rng.shuffle(self.buffer)
for dp in self.buffer:
yield dp
else:
for dp in self.ds.get_data():
yield dp
self.buffer.append(dp)
class PrintData(ProxyDataFlow): class PrintData(ProxyDataFlow):
""" """
Behave like an identity mapping but print shapes of produced datapoints once during construction. Behave like an identity mapping but print shapes of produced datapoints once during construction.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment