Commit b03e70cb authored by Yuxin Wu's avatar Yuxin Wu

fix BEGAN main script API

parent 795f016a
......@@ -145,16 +145,6 @@ if __name__ == '__main__':
assert args.data
logger.auto_set_dir()
config = TrainConfig(
callbacks=[
ModelSaver(),
StatMonitorParamSetter(
'learning_rate', 'measure', lambda x: x * 0.5, 0, 10)
],
steps_per_epoch=500,
max_epoch=400,
session_init=SaverRestore(args.load) if args.load else None,
)
input = QueueInput(DCGAN.get_data(args.data))
model = Model()
nr_tower = max(get_nr_gpu(), 1)
......@@ -162,4 +152,12 @@ if __name__ == '__main__':
trainer = GANTrainer(input, model)
else:
trainer = MultiGPUGANTrainer(nr_tower, input, model)
trainer.train_with_config(config)
trainer.train_with_defaults(
callbacks=[
ModelSaver(),
StatMonitorParamSetter(
'learning_rate', 'measure', lambda x: x * 0.5, 0, 10)
],
session_init=SaverRestore(args.load) if args.load else None,
steps_per_epoch=500, max_epoch=400)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment