Commit 5bd3c395 authored by Yuxin Wu's avatar Yuxin Wu

[WIP] GAN trainers with new API

parent 2d99afea
......@@ -146,8 +146,6 @@ if __name__ == '__main__':
logger.auto_set_dir()
config = TrainConfig(
model=Model(),
dataflow=DCGAN.get_data(args.data),
callbacks=[
ModelSaver(),
StatMonitorParamSetter(
......@@ -156,9 +154,12 @@ if __name__ == '__main__':
steps_per_epoch=500,
max_epoch=400,
session_init=SaverRestore(args.load) if args.load else None,
nr_tower=max(get_nr_gpu(), 1)
)
if config.nr_tower == 1:
GANTrainer(config).train()
input = QueueInput(DCGAN.get_data(args.data))
model = Model()
nr_tower = max(get_nr_gpu(), 1)
if nr_tower == 1:
trainer = GANTrainer(input, model)
else:
MultiGPUGANTrainer(config).train()
trainer = MultiGPUGANTrainer(nr_tower, input, model)
trainer.train_with_config(config)
......@@ -136,12 +136,15 @@ class MultiGPUGANTrainer(TowerTrainer):
input = StagingInput(input, list(range(nr_gpu)))
cbs = input.setup(model.get_inputs_desc())
def get_cost():
model.build_graph(input.get_input_tensors())
def get_cost(*inputs):
model.build_graph(inputs)
return [model.d_loss, model.g_loss]
tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc())
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
cost_list = DataParallelBuilder.build_on_towers(list(range(nr_gpu)), tower_func, devices)
cost_list = DataParallelBuilder.build_on_towers(
list(range(nr_gpu)),
lambda: tower_func(*input.get_input_tensors()),
devices)
# simply average the cost. It might get faster to average the gradients
with tf.name_scope('optimize'):
d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment