Commit e119ef75 authored by Patrick Wieschollek, committed by Yuxin Wu

simplify GAN usage (#917)

* simplify GAN usage

* move implementation to class

* small rename
parent e9363dfd
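
In short, multi-GPU support moves from the separate MultiGPUGANTrainer class into GANTrainer itself, selected by a new num_gpu argument. A sketch of the resulting entry point, assuming the examples' own get_data() and Model() helpers (not part of this diff):

    from tensorpack import QueueInput
    from tensorpack.utils.gpu import get_num_gpu
    from GAN import GANTrainer

    # Pass a plain input; GANTrainer now wraps it in StagingInput itself
    # when num_gpu > 1 and picks the single- or multi-GPU graph builder.
    data = QueueInput(get_data())
    GANTrainer(data, Model(), get_num_gpu()).train_with_defaults(callbacks=[])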
@@ -212,7 +212,7 @@ if __name__ == '__main__':
     df = get_data(args.data)
     df = PrintData(df)
-    data = StagingInput(QueueInput(df))
+    data = QueueInput(df)
     GANTrainer(data, Model()).train_with_defaults(
         callbacks=[
@@ -4,12 +4,13 @@
 import tensorflow as tf
 import numpy as np
-from tensorpack import (TowerTrainer,
-                        ModelDescBase, DataFlow, StagingInput)
+from tensorpack import (TowerTrainer, StagingInput,
+                        ModelDescBase, DataFlow)
 from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper
 from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter
 from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.utils.argtools import memoized
+from tensorpack.utils.develop import deprecated


 class GANModelDesc(ModelDescBase):
@@ -73,7 +74,8 @@ class GANModelDesc(ModelDescBase):

 class GANTrainer(TowerTrainer):
-    def __init__(self, input, model):
+    def __init__(self, input, model, num_gpu=1):
         """
         Args:
             input (InputSource):
@@ -81,11 +83,20 @@ class GANTrainer(TowerTrainer):
         """
         super(GANTrainer, self).__init__()
         assert isinstance(model, GANModelDesc), model
-        inputs_desc = model.get_inputs_desc()
+        if num_gpu > 1:
+            input = StagingInput(input)
+
         # Setup input
-        cbs = input.setup(inputs_desc)
+        cbs = input.setup(model.get_inputs_desc())
         self.register_callback(cbs)

+        if num_gpu <= 1:
+            self._build_gan_trainer(input, model)
+        else:
+            self._build_multigpu_gan_trainer(input, model, num_gpu)
+
+    def _build_gan_trainer(self, input, model):
         """
         We need to set tower_func because it's a TowerTrainer,
         and only TowerTrainer supports automatic graph creation for inference during training.
@@ -94,7 +105,7 @@ class GANTrainer(TowerTrainer):
         not needed. Just calling model.build_graph directly is OK.
         """
         # Build the graph
-        self.tower_func = TowerFuncWrapper(model.build_graph, inputs_desc)
+        self.tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
         with TowerContext('', is_training=True):
             self.tower_func(*input.get_input_tensors())
         opt = model.get_optimizer()
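
The docstring above is why tower_func is still registered in the single-GPU path: TowerTrainer can rebuild the registered tower function to create an inference graph during training. A hedged sketch of what that enables (the tensor names 'z' and 'gen/gen' are assumptions taken from the DCGAN example, not part of this diff):

    # TowerTrainer.get_predictor() rebuilds tower_func in inference mode.
    pred = trainer.get_predictor(['z'], ['gen/gen'])
    samples = pred(np.random.uniform(-1, 1, size=(32, 100)))  # z-dim assumed 100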
@@ -107,6 +118,46 @@
             d_min = opt.minimize(model.d_loss, var_list=model.d_vars, name='d_op')
         self.train_op = d_min

+    def _build_multigpu_gan_trainer(self, input, model, num_gpu):
+        assert num_gpu > 1
+        raw_devices = ['/gpu:{}'.format(k) for k in range(num_gpu)]
+
+        # Build the graph with multi-gpu replication
+        def get_cost(*inputs):
+            model.build_graph(*inputs)
+            return [model.d_loss, model.g_loss]
+
+        self.tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc())
+        devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
+        cost_list = DataParallelBuilder.build_on_towers(
+            list(range(num_gpu)),
+            lambda: self.tower_func(*input.get_input_tensors()),
+            devices)
+        # For simplicity, average the cost here. It might be faster to average the gradients
+        with tf.name_scope('optimize'):
+            d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / num_gpu)
+            g_loss = tf.add_n([x[1] for x in cost_list]) * (1.0 / num_gpu)
+
+            opt = model.get_optimizer()
+            # run one d_min after one g_min
+            g_min = opt.minimize(g_loss, var_list=model.g_vars,
+                                 colocate_gradients_with_ops=True, name='g_op')
+            with tf.control_dependencies([g_min]):
+                d_min = opt.minimize(d_loss, var_list=model.d_vars,
+                                     colocate_gradients_with_ops=True, name='d_op')
+        # Define the training iteration
+        self.train_op = d_min
+
+
+class MultiGPUGANTrainer(GANTrainer):
+    """
+    A replacement of GANTrainer (optimize d and g one by one) with multi-gpu support.
+    """
+    @deprecated("Please use GANTrainer and set num_gpu", "2019-01-31")
+    def __init__(self, num_gpu, input, model):
+        super(MultiGPUGANTrainer, self).__init__(input, model, num_gpu)
+
+
 class SeparateGANTrainer(TowerTrainer):
     """ A GAN trainer which runs two optimization ops with a certain ratio."""
@@ -145,47 +196,6 @@ class SeparateGANTrainer(TowerTrainer):
         self.hooked_sess.run(self.g_min)


-class MultiGPUGANTrainer(TowerTrainer):
-    """
-    A replacement of GANTrainer (optimize d and g one by one) with multi-gpu support.
-    """
-    def __init__(self, num_gpu, input, model):
-        super(MultiGPUGANTrainer, self).__init__()
-        assert num_gpu > 1
-        raw_devices = ['/gpu:{}'.format(k) for k in range(num_gpu)]
-
-        # Setup input
-        input = StagingInput(input)
-        cbs = input.setup(model.get_inputs_desc())
-        self.register_callback(cbs)
-
-        # Build the graph with multi-gpu replication
-        def get_cost(*inputs):
-            model.build_graph(*inputs)
-            return [model.d_loss, model.g_loss]
-
-        self.tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc())
-        devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
-        cost_list = DataParallelBuilder.build_on_towers(
-            list(range(num_gpu)),
-            lambda: self.tower_func(*input.get_input_tensors()),
-            devices)
-        # For simplicity, average the cost here. It might be faster to average the gradients
-        with tf.name_scope('optimize'):
-            d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / num_gpu)
-            g_loss = tf.add_n([x[1] for x in cost_list]) * (1.0 / num_gpu)
-
-            opt = model.get_optimizer()
-            # run one d_min after one g_min
-            g_min = opt.minimize(g_loss, var_list=model.g_vars,
-                                 colocate_gradients_with_ops=True, name='g_op')
-            with tf.control_dependencies([g_min]):
-                d_min = opt.minimize(d_loss, var_list=model.d_vars,
-                                     colocate_gradients_with_ops=True, name='d_op')
-        # Define the training iteration
-        self.train_op = d_min
-
-
 class RandomZData(DataFlow):
     def __init__(self, shape):
         super(RandomZData, self).__init__()
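
The "average the cost" comment in _build_multigpu_gan_trainer above flags a known trade-off: the commit keeps the simpler per-tower cost average and relies on colocate_gradients_with_ops=True to keep the backward pass on each tower's device; averaging per-tower gradients is the common alternative. A hedged sketch of that alternative (hypothetical helper, not part of this commit):

    def average_grads(tower_grads):
        # tower_grads: one [(grad, var), ...] list per tower, same variable order
        averaged = []
        for pairs in zip(*tower_grads):
            grads = [g for g, _ in pairs]
            averaged.append((tf.add_n(grads) / float(len(grads)), pairs[0][1]))
        return averaged

    # Per tower k: grads_k = opt.compute_gradients(loss_k)
    # Then:        train_op = opt.apply_gradients(average_grads([grads_0, grads_1]))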
@@ -16,7 +16,7 @@ from tensorpack.utils.gpu import get_num_gpu
 from tensorpack.utils.viz import stack_patches
 from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
-from GAN import GANTrainer, MultiGPUGANTrainer, GANModelDesc
+from GAN import GANTrainer, GANModelDesc

 """
 To train Image-to-Image translation model with image pairs:
@@ -217,12 +217,7 @@ if __name__ == '__main__':
     logger.auto_set_dir()

     data = QueueInput(get_data())

-    nr_tower = max(get_num_gpu(), 1)
-    if nr_tower == 1:
-        trainer = GANTrainer(data, Model())
-    else:
-        trainer = MultiGPUGANTrainer(nr_tower, data, Model())
+    trainer = GANTrainer(data, Model(), get_num_gpu())
     trainer.train_with_defaults(
         callbacks=[
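
Any script that imported the removed class migrates the same way (a sketch; data and Model() as above):

    # Deprecated shim: still works, warns, removal scheduled for 2019-01-31.
    trainer = MultiGPUGANTrainer(get_num_gpu(), data, Model())
    # Equivalent call going forward:
    trainer = GANTrainer(data, Model(), num_gpu=get_num_gpu())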