Commit 5bd3c395 authored by Yuxin Wu's avatar Yuxin Wu

[WIP] GAN trainers with new API

parent 2d99afea
...@@ -146,8 +146,6 @@ if __name__ == '__main__': ...@@ -146,8 +146,6 @@ if __name__ == '__main__':
logger.auto_set_dir() logger.auto_set_dir()
config = TrainConfig( config = TrainConfig(
model=Model(),
dataflow=DCGAN.get_data(args.data),
callbacks=[ callbacks=[
ModelSaver(), ModelSaver(),
StatMonitorParamSetter( StatMonitorParamSetter(
...@@ -156,9 +154,12 @@ if __name__ == '__main__': ...@@ -156,9 +154,12 @@ if __name__ == '__main__':
steps_per_epoch=500, steps_per_epoch=500,
max_epoch=400, max_epoch=400,
session_init=SaverRestore(args.load) if args.load else None, session_init=SaverRestore(args.load) if args.load else None,
nr_tower=max(get_nr_gpu(), 1)
) )
if config.nr_tower == 1: input = QueueInput(DCGAN.get_data(args.data))
GANTrainer(config).train() model = Model()
nr_tower = max(get_nr_gpu(), 1)
if nr_tower == 1:
trainer = GANTrainer(input, model)
else: else:
MultiGPUGANTrainer(config).train() trainer = MultiGPUGANTrainer(nr_tower, input, model)
trainer.train_with_config(config)
...@@ -136,12 +136,15 @@ class MultiGPUGANTrainer(TowerTrainer): ...@@ -136,12 +136,15 @@ class MultiGPUGANTrainer(TowerTrainer):
input = StagingInput(input, list(range(nr_gpu))) input = StagingInput(input, list(range(nr_gpu)))
cbs = input.setup(model.get_inputs_desc()) cbs = input.setup(model.get_inputs_desc())
def get_cost(): def get_cost(*inputs):
model.build_graph(input.get_input_tensors()) model.build_graph(inputs)
return [model.d_loss, model.g_loss] return [model.d_loss, model.g_loss]
tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc()) tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc())
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices] devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
cost_list = DataParallelBuilder.build_on_towers(list(range(nr_gpu)), tower_func, devices) cost_list = DataParallelBuilder.build_on_towers(
list(range(nr_gpu)),
lambda: tower_func(*input.get_input_tensors()),
devices)
# simply average the cost. It might get faster to average the gradients # simply average the cost. It might get faster to average the gradients
with tf.name_scope('optimize'): with tf.name_scope('optimize'):
d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu) d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment