Commit e119ef75 authored by Patrick Wieschollek's avatar Patrick Wieschollek Committed by Yuxin Wu

simplify GAN usage (#917)

* simplify GAN usage

* move implementation to class

* small rename
parent e9363dfd
...@@ -212,7 +212,7 @@ if __name__ == '__main__': ...@@ -212,7 +212,7 @@ if __name__ == '__main__':
df = get_data(args.data) df = get_data(args.data)
df = PrintData(df) df = PrintData(df)
data = StagingInput(QueueInput(df)) data = QueueInput(df)
GANTrainer(data, Model()).train_with_defaults( GANTrainer(data, Model()).train_with_defaults(
callbacks=[ callbacks=[
......
...@@ -4,12 +4,13 @@ ...@@ -4,12 +4,13 @@
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
from tensorpack import (TowerTrainer, from tensorpack import (TowerTrainer, StagingInput,
ModelDescBase, DataFlow, StagingInput) ModelDescBase, DataFlow)
from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper
from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.argtools import memoized from tensorpack.utils.argtools import memoized
from tensorpack.utils.develop import deprecated
class GANModelDesc(ModelDescBase): class GANModelDesc(ModelDescBase):
...@@ -73,7 +74,8 @@ class GANModelDesc(ModelDescBase): ...@@ -73,7 +74,8 @@ class GANModelDesc(ModelDescBase):
class GANTrainer(TowerTrainer): class GANTrainer(TowerTrainer):
def __init__(self, input, model):
def __init__(self, input, model, num_gpu=1):
""" """
Args: Args:
input (InputSource): input (InputSource):
...@@ -81,11 +83,20 @@ class GANTrainer(TowerTrainer): ...@@ -81,11 +83,20 @@ class GANTrainer(TowerTrainer):
""" """
super(GANTrainer, self).__init__() super(GANTrainer, self).__init__()
assert isinstance(model, GANModelDesc), model assert isinstance(model, GANModelDesc), model
inputs_desc = model.get_inputs_desc()
if num_gpu > 1:
input = StagingInput(input)
# Setup input # Setup input
cbs = input.setup(inputs_desc) cbs = input.setup(model.get_inputs_desc())
self.register_callback(cbs) self.register_callback(cbs)
if num_gpu <= 1:
self._build_gan_trainer(input, model)
else:
self._build_multigpu_gan_trainer(input, model, num_gpu)
def _build_gan_trainer(self, input, model):
""" """
We need to set tower_func because it's a TowerTrainer, We need to set tower_func because it's a TowerTrainer,
and only TowerTrainer supports automatic graph creation for inference during training. and only TowerTrainer supports automatic graph creation for inference during training.
...@@ -94,7 +105,7 @@ class GANTrainer(TowerTrainer): ...@@ -94,7 +105,7 @@ class GANTrainer(TowerTrainer):
not needed. Just calling model.build_graph directly is OK. not needed. Just calling model.build_graph directly is OK.
""" """
# Build the graph # Build the graph
self.tower_func = TowerFuncWrapper(model.build_graph, inputs_desc) self.tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
with TowerContext('', is_training=True): with TowerContext('', is_training=True):
self.tower_func(*input.get_input_tensors()) self.tower_func(*input.get_input_tensors())
opt = model.get_optimizer() opt = model.get_optimizer()
...@@ -107,6 +118,46 @@ class GANTrainer(TowerTrainer): ...@@ -107,6 +118,46 @@ class GANTrainer(TowerTrainer):
d_min = opt.minimize(model.d_loss, var_list=model.d_vars, name='d_op') d_min = opt.minimize(model.d_loss, var_list=model.d_vars, name='d_op')
self.train_op = d_min self.train_op = d_min
def _build_multigpu_gan_trainer(self, input, model, num_gpu):
assert num_gpu > 1
raw_devices = ['/gpu:{}'.format(k) for k in range(num_gpu)]
# Build the graph with multi-gpu replication
def get_cost(*inputs):
model.build_graph(*inputs)
return [model.d_loss, model.g_loss]
self.tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc())
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
cost_list = DataParallelBuilder.build_on_towers(
list(range(num_gpu)),
lambda: self.tower_func(*input.get_input_tensors()),
devices)
# For simplicity, average the cost here. It might be faster to average the gradients
with tf.name_scope('optimize'):
d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / num_gpu)
g_loss = tf.add_n([x[1] for x in cost_list]) * (1.0 / num_gpu)
opt = model.get_optimizer()
# run one d_min after one g_min
g_min = opt.minimize(g_loss, var_list=model.g_vars,
colocate_gradients_with_ops=True, name='g_op')
with tf.control_dependencies([g_min]):
d_min = opt.minimize(d_loss, var_list=model.d_vars,
colocate_gradients_with_ops=True, name='d_op')
# Define the training iteration
self.train_op = d_min
class MultiGPUGANTrainer(GANTrainer):
"""
A replacement of GANTrainer (optimize d and g one by one) with multi-gpu support.
"""
@deprecated("Please use GANTrainer and set num_gpu", "2019-01-31")
def __init__(self, num_gpu, input, model):
super(MultiGPUGANTrainer, self).__init__(input, model, 1)
class SeparateGANTrainer(TowerTrainer): class SeparateGANTrainer(TowerTrainer):
""" A GAN trainer which runs two optimization ops with a certain ratio.""" """ A GAN trainer which runs two optimization ops with a certain ratio."""
...@@ -145,47 +196,6 @@ class SeparateGANTrainer(TowerTrainer): ...@@ -145,47 +196,6 @@ class SeparateGANTrainer(TowerTrainer):
self.hooked_sess.run(self.g_min) self.hooked_sess.run(self.g_min)
class MultiGPUGANTrainer(TowerTrainer):
"""
A replacement of GANTrainer (optimize d and g one by one) with multi-gpu support.
"""
def __init__(self, num_gpu, input, model):
super(MultiGPUGANTrainer, self).__init__()
assert num_gpu > 1
raw_devices = ['/gpu:{}'.format(k) for k in range(num_gpu)]
# Setup input
input = StagingInput(input)
cbs = input.setup(model.get_inputs_desc())
self.register_callback(cbs)
# Build the graph with multi-gpu replication
def get_cost(*inputs):
model.build_graph(*inputs)
return [model.d_loss, model.g_loss]
self.tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc())
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
cost_list = DataParallelBuilder.build_on_towers(
list(range(num_gpu)),
lambda: self.tower_func(*input.get_input_tensors()),
devices)
# For simplicity, average the cost here. It might be faster to average the gradients
with tf.name_scope('optimize'):
d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / num_gpu)
g_loss = tf.add_n([x[1] for x in cost_list]) * (1.0 / num_gpu)
opt = model.get_optimizer()
# run one d_min after one g_min
g_min = opt.minimize(g_loss, var_list=model.g_vars,
colocate_gradients_with_ops=True, name='g_op')
with tf.control_dependencies([g_min]):
d_min = opt.minimize(d_loss, var_list=model.d_vars,
colocate_gradients_with_ops=True, name='d_op')
# Define the training iteration
self.train_op = d_min
class RandomZData(DataFlow): class RandomZData(DataFlow):
def __init__(self, shape): def __init__(self, shape):
super(RandomZData, self).__init__() super(RandomZData, self).__init__()
......
...@@ -16,7 +16,7 @@ from tensorpack.utils.gpu import get_num_gpu ...@@ -16,7 +16,7 @@ from tensorpack.utils.gpu import get_num_gpu
from tensorpack.utils.viz import stack_patches from tensorpack.utils.viz import stack_patches
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from GAN import GANTrainer, MultiGPUGANTrainer, GANModelDesc from GAN import GANTrainer, GANModelDesc
""" """
To train Image-to-Image translation model with image pairs: To train Image-to-Image translation model with image pairs:
...@@ -217,12 +217,7 @@ if __name__ == '__main__': ...@@ -217,12 +217,7 @@ if __name__ == '__main__':
logger.auto_set_dir() logger.auto_set_dir()
data = QueueInput(get_data()) data = QueueInput(get_data())
trainer = GANTrainer(data, Model(), get_num_gpu())
nr_tower = max(get_num_gpu(), 1)
if nr_tower == 1:
trainer = GANTrainer(data, Model())
else:
trainer = MultiGPUGANTrainer(nr_tower, data, Model())
trainer.train_with_defaults( trainer.train_with_defaults(
callbacks=[ callbacks=[
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment