Commit 8ef16f14 authored by Yuxin Wu

DCGAN example family, move TrainConfig to each script

parent 66726a8f
@@ -118,16 +118,6 @@ def get_data(datadir):
     return ds
-def get_config():
-    return TrainConfig(
-        model=Model(),
-        dataflow=get_data(opt.data),
-        callbacks=[ModelSaver()],
-        steps_per_epoch=300,
-        max_epoch=200,
-    )
 def sample(model, model_path, output_name='gen/gen'):
     pred = PredictConfig(
         session_init=get_model_loader(model_path),
@@ -165,7 +155,12 @@ if __name__ == '__main__':
     else:
         assert args.data
         logger.auto_set_dir()
-        config = get_config()
-        if args.load:
-            config.session_init = SaverRestore(args.load)
+        config = TrainConfig(
+            model=Model(),
+            dataflow=get_data(args.data),
+            callbacks=[ModelSaver()],
+            steps_per_epoch=300,
+            max_epoch=200,
+            session_init=SaverRestore(args.load) if args.load else None
+        )
         GANTrainer(config).train()
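For reference, the pattern that each script now repeats could be factored into a small helper like the sketch below. This is only an illustration of the commit's idea, not code from the repository: make_config and its defaults are hypothetical, while TrainConfig, ModelSaver and SaverRestore are the tensorpack symbols already used in the hunk above.

# Hypothetical helper illustrating the per-script TrainConfig pattern above;
# not part of the commit.
from tensorpack import TrainConfig, ModelSaver, SaverRestore

def make_config(model, dataflow, load_path=None,
                steps_per_epoch=300, max_epoch=200):
    # The restore path is decided up front, replacing the removed
    # `config.session_init = SaverRestore(args.load)` mutation.
    return TrainConfig(
        model=model,
        dataflow=dataflow,
        callbacks=[ModelSaver()],
        steps_per_epoch=steps_per_epoch,
        max_epoch=max_epoch,
        session_init=SaverRestore(load_path) if load_path else None,
    )

A script would then call, for example, GANTrainer(make_config(Model(), get_data(args.data), args.load)).train().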
@@ -87,9 +87,6 @@ class Model(DCGAN.Model):
         return opt
 DCGAN.Model = Model
 if __name__ == '__main__':
     args = DCGAN.get_args()
     if args.sample:
@@ -97,7 +94,12 @@ if __name__ == '__main__':
     else:
         assert args.data
         logger.auto_set_dir()
-        config = DCGAN.get_config()
-        if args.load:
-            config.session_init = SaverRestore(args.load)
+        config = TrainConfig(
+            model=Model(),
+            dataflow=DCGAN.get_data(args.data),
+            callbacks=[ModelSaver()],
+            steps_per_epoch=300,
+            max_epoch=200,
+            session_init=SaverRestore(args.load) if args.load else None
+        )
         SeparateGANTrainer(config, g_period=6).train()
@@ -47,9 +47,6 @@ class Model(DCGAN.Model):
         return optimizer.VariableAssignmentOptimizer(opt, clip)
 DCGAN.Model = Model
 if __name__ == '__main__':
     args = DCGAN.get_args()
@@ -58,11 +55,14 @@ if __name__ == '__main__':
     else:
         assert args.data
         logger.auto_set_dir()
-        config = DCGAN.get_config()
-        config.steps_per_epoch = 500
-        if args.load:
-            config.session_init = SaverRestore(args.load)
+        config = TrainConfig(
+            model=Model(),
+            dataflow=DCGAN.get_data(args.data),
+            callbacks=[ModelSaver()],
+            steps_per_epoch=500,
+            max_epoch=200,
+            session_init=SaverRestore(args.load) if args.load else None
+        )
         """
         The original code uses a different schedule, but this seems to work well.
         """
@@ -5,6 +5,7 @@
 import six
 import argparse
+from . import logger
 __all__ = ['globalns']
@@ -26,6 +27,8 @@ class MyNS(NS):
         """
         assert isinstance(args, argparse.Namespace), type(args)
         for k, v in six.iteritems(vars(args)):
+            if hasattr(self, k):
+                logger.warn("Attribute {} in globalns will be overwritten!".format(k))
             setattr(self, k, v)
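The globalns hunk adds a warning whenever an argparse attribute clobbers a name that already exists on the namespace. A minimal usage sketch follows, assuming the method shown above is named use_argument and that the module is importable as tensorpack.utils.globalns (both assumptions inferred from the hunk rather than shown in it):

# Usage sketch for the clash warning added above; the method name
# `use_argument` and the import path are assumptions.
import argparse
from tensorpack.utils.globalns import globalns

parser = argparse.ArgumentParser()
parser.add_argument('--data', help='dataset directory')
args = parser.parse_args(['--data', '/tmp/celebA'])

globalns.data = 'preset'      # attribute already present on globalns
globalns.use_argument(args)   # logs: "Attribute data in globalns will be overwritten!"
assert globalns.data == '/tmp/celebA'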