Commit 8ef16f14 authored by Yuxin Wu's avatar Yuxin Wu

DCGAN example family, move TrainConfig to each script

parent 66726a8f
...@@ -118,16 +118,6 @@ def get_data(datadir): ...@@ -118,16 +118,6 @@ def get_data(datadir):
return ds return ds
def get_config():
return TrainConfig(
model=Model(),
dataflow=get_data(opt.data),
callbacks=[ModelSaver()],
steps_per_epoch=300,
max_epoch=200,
)
def sample(model, model_path, output_name='gen/gen'): def sample(model, model_path, output_name='gen/gen'):
pred = PredictConfig( pred = PredictConfig(
session_init=get_model_loader(model_path), session_init=get_model_loader(model_path),
...@@ -165,7 +155,12 @@ if __name__ == '__main__': ...@@ -165,7 +155,12 @@ if __name__ == '__main__':
else: else:
assert args.data assert args.data
logger.auto_set_dir() logger.auto_set_dir()
config = get_config() config = TrainConfig(
if args.load: model=Model(),
config.session_init = SaverRestore(args.load) dataflow=get_data(args.data),
callbacks=[ModelSaver()],
steps_per_epoch=300,
max_epoch=200,
session_init=SaverRestore(args.load) if args.load else None
)
GANTrainer(config).train() GANTrainer(config).train()
...@@ -87,9 +87,6 @@ class Model(DCGAN.Model): ...@@ -87,9 +87,6 @@ class Model(DCGAN.Model):
return opt return opt
DCGAN.Model = Model
if __name__ == '__main__': if __name__ == '__main__':
args = DCGAN.get_args() args = DCGAN.get_args()
if args.sample: if args.sample:
...@@ -97,7 +94,12 @@ if __name__ == '__main__': ...@@ -97,7 +94,12 @@ if __name__ == '__main__':
else: else:
assert args.data assert args.data
logger.auto_set_dir() logger.auto_set_dir()
config = DCGAN.get_config() config = TrainConfig(
if args.load: model=Model(),
config.session_init = SaverRestore(args.load) dataflow=DCGAN.get_data(args.data),
callbacks=[ModelSaver()],
steps_per_epoch=300,
max_epoch=200,
session_init=SaverRestore(args.load) if args.load else None
)
SeparateGANTrainer(config, g_period=6).train() SeparateGANTrainer(config, g_period=6).train()
...@@ -47,9 +47,6 @@ class Model(DCGAN.Model): ...@@ -47,9 +47,6 @@ class Model(DCGAN.Model):
return optimizer.VariableAssignmentOptimizer(opt, clip) return optimizer.VariableAssignmentOptimizer(opt, clip)
DCGAN.Model = Model
if __name__ == '__main__': if __name__ == '__main__':
args = DCGAN.get_args() args = DCGAN.get_args()
...@@ -58,11 +55,14 @@ if __name__ == '__main__': ...@@ -58,11 +55,14 @@ if __name__ == '__main__':
else: else:
assert args.data assert args.data
logger.auto_set_dir() logger.auto_set_dir()
config = DCGAN.get_config() config = TrainConfig(
config.steps_per_epoch = 500 model=Model(),
dataflow=DCGAN.get_data(args.data),
if args.load: callbacks=[ModelSaver()],
config.session_init = SaverRestore(args.load) steps_per_epoch=500,
max_epoch=200,
session_init=SaverRestore(args.load) if args.load else None
)
""" """
The original code uses a different schedule, but this seems to work well. The original code uses a different schedule, but this seems to work well.
""" """
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import six import six
import argparse import argparse
from . import logger
__all__ = ['globalns'] __all__ = ['globalns']
...@@ -26,6 +27,8 @@ class MyNS(NS): ...@@ -26,6 +27,8 @@ class MyNS(NS):
""" """
assert isinstance(args, argparse.Namespace), type(args) assert isinstance(args, argparse.Namespace), type(args)
for k, v in six.iteritems(vars(args)): for k, v in six.iteritems(vars(args)):
if hasattr(self, k):
logger.warn("Attribute {} in globalns will be overwritten!")
setattr(self, k, v) setattr(self, k, v)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment