Commit bf9da6d5 authored by Yuxin Wu's avatar Yuxin Wu

fix build

parent ab2cd7e6
......@@ -247,7 +247,7 @@ if __name__ == '__main__':
logger.auto_set_dir()
GANTrainer(QueueInput(get_data()),
Model()).train_with_defaults(
callbacks=[ModelSaver(keep_freq=0.1)],
callbacks=[ModelSaver(keep_checkpoint_every_n_hours=0.1)],
steps_per_epoch=500,
max_epoch=100,
session_init=SaverRestore(args.load) if args.load else None
......
......@@ -8,6 +8,7 @@ import copy
import six
from six.moves import range
from .base import DataFlow, RNGDataFlow
from ..utils.develop import log_deprecated
__all__ = ['FakeData', 'DataFromQueue', 'DataFromList', 'DataFromGenerator']
......@@ -97,18 +98,21 @@ class DataFromList(RNGDataFlow):
class DataFromGenerator(DataFlow):
"""
Wrap a generator to a DataFlow
Wrap a generator to a DataFlow.
"""
def __init__(self, gen, size=None):
self._gen = gen
self._size = size
def size(self):
if self._size:
return self._size
return super(DataFromGenerator, self).size()
"""
Args:
gen: iterable, or a callable that returns an iterable
"""
if not callable(gen):
self._gen = lambda: gen
else:
self._gen = gen
if size is not None:
log_deprecated("DataFromGenerator(size=)", "It doesn't make much sense.")
def get_data(self):
# yield from
for dp in self._gen:
for dp in self._gen():
yield dp
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment