Commit 01486c39 authored by Yuxin Wu's avatar Yuxin Wu

MultiGPU GAN Trainer

parent acf8e8f3
......@@ -12,7 +12,7 @@ from tensorpack.utils.globvars import globalns as G
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf
from GAN import GANModelDesc, GANTrainer
from GAN import GANModelDesc, GANTrainer, MultiGPUGANTrainer
"""
Boundary Equilibrium GAN.
......@@ -161,4 +161,9 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
GANTrainer(config).train()
nr_gpu = get_nr_gpu()
# Use every visible GPU as a training tower; fall back to 1 so the
# single-GPU / CPU case still runs. Reuse the value computed above
# instead of calling get_nr_gpu() a second time.
config.nr_tower = max(nr_gpu, 1)
if config.nr_tower == 1:
    GANTrainer(config).train()
else:
    # Data-parallel trainer; requires config.nr_tower > 1.
    MultiGPUGANTrainer(config).train()
......@@ -6,7 +6,9 @@
import tensorflow as tf
import numpy as np
import time
from tensorpack import (FeedfreeTrainerBase, QueueInput, ModelDesc, DataFlow)
from tensorpack import (FeedfreeTrainerBase, QueueInput,
ModelDesc, DataFlow, StagingInputWrapper,
MultiGPUTrainerBase, LeastLoadedDeviceSetter)
from tensorpack.tfutils.summary import add_moving_summary
......@@ -17,7 +19,9 @@ class GANModelDesc(ModelDesc):
and same with self.d_vars.
"""
self.g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, g_scope)
assert self.g_vars
self.d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, d_scope)
assert self.d_vars
def build_losses(self, logits_real, logits_fake):
"""D and G play two-player minimax game with value function V(G,D)
......@@ -56,7 +60,6 @@ class GANModelDesc(ModelDesc):
class GANTrainer(FeedfreeTrainerBase):
def __init__(self, config):
    # TODO design better
    # Feed the graph from a queue built over the configured dataflow,
    # then let the feedfree base class finish initialization.
    queue_input = QueueInput(config.dataflow)
    self._input_source = queue_input
    super(GANTrainer, self).__init__(config)
......@@ -105,6 +108,37 @@ class SeparateGANTrainer(FeedfreeTrainerBase):
self._cnt += 1
class MultiGPUGANTrainer(MultiGPUTrainerBase, FeedfreeTrainerBase):
    """Data-parallel GAN trainer.

    Replicates the model on each configured GPU tower, averages the
    discriminator and generator losses across towers, and builds a single
    train_op that runs one generator step followed by one discriminator step.
    """

    def __init__(self, config):
        """
        Args:
            config: a tensorpack TrainConfig; ``config.nr_tower`` must be > 1
                and ``config.dataflow`` supplies the training data.
        """
        super(MultiGPUGANTrainer, self).__init__(config)
        self._nr_gpu = config.nr_tower
        # This trainer is only meaningful with multiple towers;
        # use GANTrainer for the single-GPU case.
        assert self._nr_gpu > 1
        self._raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
        # Stage inputs onto each GPU so host->device copies overlap compute.
        self._input_source = StagingInputWrapper(QueueInput(config.dataflow), self._raw_devices)

    def _setup(self):
        super(MultiGPUGANTrainer, self)._setup()
        # One device setter per tower; variables go to the least-loaded device
        # to spread parameter traffic across GPUs.
        devices = [LeastLoadedDeviceSetter(d, self._raw_devices) for d in self._raw_devices]

        def get_cost():
            # Build the model graph for the current tower and return its
            # two losses: [d_loss, g_loss].
            self.build_train_tower()
            return [self.model.d_loss, self.model.g_loss]
        cost_list = MultiGPUTrainerBase.build_on_multi_tower(
            self.config.tower, get_cost, devices)
        # simply average the cost. might be faster to average the gradients
        d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / self._nr_gpu)
        g_loss = tf.add_n([x[1] for x in cost_list]) * (1.0 / self._nr_gpu)

        opt = self.model.get_optimizer()
        # run one d_min after one g_min
        self.g_min = opt.minimize(g_loss, var_list=self.model.g_vars,
                                  colocate_gradients_with_ops=True, name='g_op')
        with tf.control_dependencies([self.g_min]):
            self.d_min = opt.minimize(d_loss, var_list=self.model.d_vars,
                                      colocate_gradients_with_ops=True, name='d_op')
        # d_min depends on g_min, so running train_op executes G then D.
        self.train_op = self.d_min
class RandomZData(DataFlow):
def __init__(self, shape):
super(RandomZData, self).__init__()
......
......@@ -64,7 +64,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
""" get the cost and gradient"""
self.build_train_tower()
cost = self.model.get_cost() # assume single cost
opt = self.config.optimizer
opt = self.config.optimizer # TODO XXX
# GATE_NONE faster?
grads = opt.compute_gradients(
cost,
......
......@@ -393,6 +393,8 @@ class StagingInputWrapper(FeedfreeInput):
self._stage_ops.append(stage.put(inputs))
self._areas.append(stage)
outputs = stage.get()
if isinstance(outputs, tf.Tensor): # when size=1, TF doesn't return a list
outputs = [outputs]
for vin, vout in zip(inputs, outputs):
vout.set_shape(vin.get_shape())
self._unstage_ops.append(outputs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment