Commit 01486c39 authored by Yuxin Wu

MultiGPU GAN Trainer

parent acf8e8f3
...@@ -12,7 +12,7 @@ from tensorpack.utils.globvars import globalns as G ...@@ -12,7 +12,7 @@ from tensorpack.utils.globvars import globalns as G
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf import tensorflow as tf
from GAN import GANModelDesc, GANTrainer from GAN import GANModelDesc, GANTrainer, MultiGPUGANTrainer
""" """
Boundary Equilibrium GAN. Boundary Equilibrium GAN.
...@@ -161,4 +161,9 @@ if __name__ == '__main__': ...@@ -161,4 +161,9 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
GANTrainer(config).train() nr_gpu = get_nr_gpu()
config.nr_tower = max(get_nr_gpu(), 1)
if config.nr_tower == 1:
GANTrainer(config).train()
else:
MultiGPUGANTrainer(config).train()
...@@ -6,7 +6,9 @@ ...@@ -6,7 +6,9 @@
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
import time import time
from tensorpack import (FeedfreeTrainerBase, QueueInput, ModelDesc, DataFlow) from tensorpack import (FeedfreeTrainerBase, QueueInput,
ModelDesc, DataFlow, StagingInputWrapper,
MultiGPUTrainerBase, LeastLoadedDeviceSetter)
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
...@@ -17,7 +19,9 @@ class GANModelDesc(ModelDesc): ...@@ -17,7 +19,9 @@ class GANModelDesc(ModelDesc):
and same with self.d_vars. and same with self.d_vars.
""" """
self.g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, g_scope) self.g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, g_scope)
assert self.g_vars
self.d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, d_scope) self.d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, d_scope)
assert self.d_vars
def build_losses(self, logits_real, logits_fake): def build_losses(self, logits_real, logits_fake):
"""D and G play two-player minimax game with value function V(G,D) """D and G play two-player minimax game with value function V(G,D)
...@@ -56,7 +60,6 @@ class GANModelDesc(ModelDesc): ...@@ -56,7 +60,6 @@ class GANModelDesc(ModelDesc):
class GANTrainer(FeedfreeTrainerBase): class GANTrainer(FeedfreeTrainerBase):
def __init__(self, config): def __init__(self, config):
# TODO design better
self._input_source = QueueInput(config.dataflow) self._input_source = QueueInput(config.dataflow)
super(GANTrainer, self).__init__(config) super(GANTrainer, self).__init__(config)
...@@ -105,6 +108,37 @@ class SeparateGANTrainer(FeedfreeTrainerBase): ...@@ -105,6 +108,37 @@ class SeparateGANTrainer(FeedfreeTrainerBase):
self._cnt += 1 self._cnt += 1
class MultiGPUGANTrainer(MultiGPUTrainerBase, FeedfreeTrainerBase):
    """
    GAN trainer that builds the model on several GPU towers and minimizes
    the tower-averaged generator and discriminator losses.
    """

    def __init__(self, config):
        """
        Args:
            config: training config. This trainer reads ``nr_tower``,
                ``tower`` and ``dataflow`` from it; ``nr_tower`` must be
                greater than 1.
        """
        super(MultiGPUGANTrainer, self).__init__(config)

        self._nr_gpu = config.nr_tower
        # This trainer refuses to run with a single tower.
        assert self._nr_gpu > 1
        self._raw_devices = ['/gpu:{}'.format(idx) for idx in self.config.tower]
        queue_input = QueueInput(config.dataflow)
        self._input_source = StagingInputWrapper(queue_input, self._raw_devices)

    def _setup(self):
        """
        Build the towers, average their losses, and define ``train_op`` so
        that one generator step is followed by one discriminator step.
        """
        super(MultiGPUGANTrainer, self)._setup()
        device_setters = [
            LeastLoadedDeviceSetter(dev, self._raw_devices)
            for dev in self._raw_devices
        ]

        def build_tower_losses():
            # Handed to build_on_multi_tower (presumably invoked once per
            # tower -- confirm against MultiGPUTrainerBase).
            self.build_train_tower()
            return [self.model.d_loss, self.model.g_loss]

        tower_losses = MultiGPUTrainerBase.build_on_multi_tower(
            self.config.tower, build_tower_losses, device_setters)

        # Simply average the losses; averaging the gradients might be faster.
        scale = 1.0 / self._nr_gpu
        d_loss = tf.add_n([losses[0] for losses in tower_losses]) * scale
        g_loss = tf.add_n([losses[1] for losses in tower_losses]) * scale

        optimizer = self.model.get_optimizer()

        # The generator update is defined first ...
        self.g_min = optimizer.minimize(
            g_loss,
            var_list=self.model.g_vars,
            colocate_gradients_with_ops=True,
            name='g_op')

        # ... and the discriminator update is forced to run after it.
        with tf.control_dependencies([self.g_min]):
            self.d_min = optimizer.minimize(
                d_loss,
                var_list=self.model.d_vars,
                colocate_gradients_with_ops=True,
                name='d_op')

        # Running d_min also triggers g_min through the control dependency.
        self.train_op = self.d_min
class RandomZData(DataFlow): class RandomZData(DataFlow):
def __init__(self, shape): def __init__(self, shape):
super(RandomZData, self).__init__() super(RandomZData, self).__init__()
......
...@@ -64,7 +64,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase): ...@@ -64,7 +64,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
""" get the cost and gradient""" """ get the cost and gradient"""
self.build_train_tower() self.build_train_tower()
cost = self.model.get_cost() # assume single cost cost = self.model.get_cost() # assume single cost
opt = self.config.optimizer opt = self.config.optimizer # TODO XXX
# GATE_NONE faster? # GATE_NONE faster?
grads = opt.compute_gradients( grads = opt.compute_gradients(
cost, cost,
......
...@@ -393,6 +393,8 @@ class StagingInputWrapper(FeedfreeInput): ...@@ -393,6 +393,8 @@ class StagingInputWrapper(FeedfreeInput):
self._stage_ops.append(stage.put(inputs)) self._stage_ops.append(stage.put(inputs))
self._areas.append(stage) self._areas.append(stage)
outputs = stage.get() outputs = stage.get()
if isinstance(outputs, tf.Tensor): # when size=1, TF doesn't return a list
outputs = [outputs]
for vin, vout in zip(inputs, outputs): for vin, vout in zip(inputs, outputs):
vout.set_shape(vin.get_shape()) vout.set_shape(vin.get_shape())
self._unstage_ops.append(outputs) self._unstage_ops.append(outputs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment