Commit 2d99afea authored by Yuxin Wu

[WIP] Switch GANs to use Trainerv2

parent 17a73a4c
@@ -6,6 +6,7 @@
 import os
 import argparse
+os.environ['TENSORPACK_TRAIN_API'] = 'v2'   # will become default soon
 from tensorpack import *
 from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.utils.gpu import get_nr_gpu
......
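Every example script touched by this commit gains the same one-line change. Note the placement: the environment variable must be set before tensorpack is imported, or the process falls back to the v1 training API. A minimal sketch of the opt-in pattern (surrounding script contents assumed):

import os
os.environ['TENSORPACK_TRAIN_API'] = 'v2'   # must run before the tensorpack import
from tensorpack import *                    # now exposes the v2 trainer API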
@@ -10,6 +10,7 @@ import sys
 import cv2
 import argparse
+os.environ['TENSORPACK_TRAIN_API'] = 'v2'   # will become default soon
 from tensorpack import *
 from tensorpack.utils.viz import *
 import tensorpack.tfutils.symbolic_functions as symbf
......
@@ -9,6 +9,7 @@ import glob
 from six.moves import map, zip, range
 import numpy as np
+os.environ['TENSORPACK_TRAIN_API'] = 'v2'   # will become default soon
 from tensorpack import *
 from tensorpack.utils.viz import *
 import tensorpack.tfutils.symbolic_functions as symbf
......
@@ -8,6 +8,7 @@ import numpy as np
 import os, sys
 import argparse
+os.environ['TENSORPACK_TRAIN_API'] = 'v2'   # will become default soon
 from tensorpack import *
 from tensorpack.utils.viz import *
 from tensorpack.tfutils.summary import add_moving_summary
@@ -156,11 +157,11 @@ if __name__ == '__main__':
     assert args.data
     logger.auto_set_dir()
     config = TrainConfig(
-        model=Model(),
-        dataflow=get_data(args.data),
         callbacks=[ModelSaver()],
         steps_per_epoch=300,
         max_epoch=200,
         session_init=SaverRestore(args.load) if args.load else None
     )
-    GANTrainer(config).train()
+    GANTrainer(
+        input=QueueInput(get_data(args.data)),
+        model=Model()).train_with_config(config)
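After this hunk, the graph inputs (dataflow and model, now constructor arguments of the trainer) are separated from the run schedule, which stays in TrainConfig. The resulting main block, assembled from the new lines above (argument parsing and get_data as in the script):

config = TrainConfig(
    callbacks=[ModelSaver()],
    steps_per_epoch=300,
    max_epoch=200,
    session_init=SaverRestore(args.load) if args.load else None
)
GANTrainer(
    input=QueueInput(get_data(args.data)),
    model=Model()).train_with_config(config)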
@@ -8,6 +8,7 @@ import argparse
 from six.moves import map, zip
 import numpy as np
+os.environ['TENSORPACK_TRAIN_API'] = 'v2'   # will become default soon
 from tensorpack import *
 from tensorpack.utils.viz import *
 import tensorpack.tfutils.symbolic_functions as symbf
......
@@ -6,9 +6,9 @@
 import tensorflow as tf
 import numpy as np
 import time
-from tensorpack import (Trainer, QueueInput,
+from tensorpack import (TowerTrainer, QueueInput,
                         ModelDescBase, DataFlow, StagingInput,
-                        TowerContext)
+                        TowerContext, TowerFuncWrapper)
 from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter
 from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.utils.argtools import memoized
@@ -64,20 +64,15 @@ class GANModelDesc(ModelDescBase):
         return self._get_optimizer()


-class GANTrainer(Trainer):
-    def __init__(self, config):
-        """
-        GANTrainer expects a ModelDesc in config which sets the following attribute
-        after :meth:`_build_graph`: g_loss, d_loss, g_vars, d_vars.
-        """
-        input = QueueInput(config.dataflow)
-        model = config.model
+class GANTrainer(TowerTrainer):
+    def __init__(self, input, model):
+        super(GANTrainer, self).__init__()
+        assert isinstance(model, GANModelDesc), model

         cbs = input.setup(model.get_inputs_desc())
-        config.callbacks.extend(cbs)

+        tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
         with TowerContext('', is_training=True):
-            model.build_graph(input)
+            tower_func(input)
         opt = model.get_optimizer()

         # by default, run one d_min after one g_min
@@ -86,29 +81,29 @@ class GANTrainer(Trainer):
             with tf.control_dependencies([g_min]):
                 d_min = opt.minimize(model.d_loss, var_list=model.d_vars, name='d_op')
         self.train_op = d_min
+        self.set_tower_func(tower_func)

-        super(GANTrainer, self).__init__(config)
+        for cb in cbs:
+            self._register_callback(cb)


-class SeparateGANTrainer(Trainer):
+class SeparateGANTrainer(TowerTrainer):
     """ A GAN trainer which runs two optimization ops with a certain ratio, one in each step. """
-    def __init__(self, config, d_period=1, g_period=1):
+    def __init__(self, input, model, d_period=1, g_period=1):
         """
         Args:
             d_period(int): period of each d_opt run
             g_period(int): period of each g_opt run
         """
+        super(SeparateGANTrainer, self).__init__()
         self._d_period = int(d_period)
         self._g_period = int(g_period)
         assert min(d_period, g_period) == 1

-        input = QueueInput(config.dataflow)
-        model = config.model
         cbs = input.setup(model.get_inputs_desc())
-        config.callbacks.extend(cbs)

+        tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
         with TowerContext('', is_training=True):
-            model.build_graph(input)
+            tower_func(input)
         opt = model.get_optimizer()
         with tf.name_scope('optimize'):
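Both rewritten trainers follow the same pattern in place of the old Trainer.__init__(config) call: wrap the model's graph-building function in a TowerFuncWrapper, invoke it once under a TowerContext, then register it and the input callbacks on the trainer. A condensed sketch of that pattern, using the names from the diff:

tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
with TowerContext('', is_training=True):
    tower_func(input)               # build the training graph once, recording the tower
self.set_tower_func(tower_func)     # lets TowerTrainer features reuse the graph definition
for cb in cbs:                      # input-pipeline callbacks are registered directly,
    self._register_callback(cb)     # instead of being appended to config.callbacks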
@@ -117,7 +112,9 @@ class SeparateGANTrainer(Trainer):
             self.g_min = opt.minimize(
                 model.g_loss, var_list=model.g_vars, name='g_min')

-        super(SeparateGANTrainer, self).__init__(config)
+        self.set_tower_func(tower_func)
+        for cb in cbs:
+            self._register_callback(cb)

     def run_step(self):
         if self.global_step % (self._d_period) == 0:
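Assuming the elided branch mirrors the visible one (d_min runs when global_step % d_period == 0, g_min when global_step % g_period == 0), the alternation is easy to simulate with a hypothetical helper:

def ops_at(step, d_period=3, g_period=1):
    # which optimization ops a given global step would run
    ops = []
    if step % d_period == 0:
        ops.append('d_min')
    if step % g_period == 0:
        ops.append('g_min')
    return ops

print([ops_at(s) for s in range(6)])
# [['d_min', 'g_min'], ['g_min'], ['g_min'], ['d_min', 'g_min'], ['g_min'], ['g_min']]

With d_period=3, as in the call site further below, d_min runs on every third step while g_min runs on every step.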
@@ -126,26 +123,25 @@ class SeparateGANTrainer(Trainer):
             self.hooked_sess.run(self.g_min)


-class MultiGPUGANTrainer(Trainer):
+class MultiGPUGANTrainer(TowerTrainer):
     """
     A replacement of GANTrainer (optimize d and g one by one) with multi-gpu support.
     """
-    def __init__(self, config):
-        nr_gpu = config.nr_tower
+    def __init__(self, nr_gpu, input, model):
+        super(MultiGPUGANTrainer, self).__init__()
         assert nr_gpu > 1
-        raw_devices = ['/gpu:{}'.format(k) for k in config.tower]
+        raw_devices = ['/gpu:{}'.format(k) for k in range(nr_gpu)]

         # setup input
-        input = StagingInput(QueueInput(config.dataflow), config.tower)
-        model = config.model
+        input = StagingInput(input, list(range(nr_gpu)))
         cbs = input.setup(model.get_inputs_desc())
-        config.callbacks.extend(cbs)

         def get_cost():
-            model.build_graph(input)
+            model.build_graph(input.get_input_tensors())
             return [model.d_loss, model.g_loss]
+        tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc())
         devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
-        cost_list = DataParallelBuilder.build_on_towers(config.tower, get_cost, devices)
+        cost_list = DataParallelBuilder.build_on_towers(list(range(nr_gpu)), tower_func, devices)
         # simply average the cost. It might be faster to average the gradients
         with tf.name_scope('optimize'):
             d_loss = tf.add_n([x[0] for x in cost_list]) * (1.0 / nr_gpu)
@@ -159,7 +155,9 @@ class MultiGPUGANTrainer(Trainer):
                 d_min = opt.minimize(d_loss, var_list=model.d_vars,
                                      colocate_gradients_with_ops=True, name='d_op')
         self.train_op = d_min
-        super(MultiGPUGANTrainer, self).__init__(config)
+        self.set_tower_func(tower_func)
+        for cb in cbs:
+            self._register_callback(cb)


 class RandomZData(DataFlow):
......
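Under the new signature, the tower count is an explicit nr_gpu argument instead of being read from config.nr_tower, and the caller supplies the InputSource; the trainer adds the StagingInput wrapper itself. A hypothetical call site (df and config are placeholders):

MultiGPUGANTrainer(
    nr_gpu=2,
    input=QueueInput(df),              # df: any DataFlow yielding the model's inputs
    model=Model()).train_with_config(config)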
@@ -12,6 +12,7 @@ import os
 import sys
 import argparse
+os.environ['TENSORPACK_TRAIN_API'] = 'v2'   # will become default soon
 from tensorpack import *
 from tensorpack.utils.viz import *
 from tensorpack.tfutils.summary import add_moving_summary
......
@@ -6,6 +6,7 @@
 import os
 import argparse
+os.environ['TENSORPACK_TRAIN_API'] = 'v2'   # will become default soon
 from tensorpack import *
 from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.utils.globvars import globalns as G
......
@@ -10,6 +10,7 @@ import os
 import sys
 import argparse
+os.environ['TENSORPACK_TRAIN_API'] = 'v2'   # will become default soon
 from tensorpack import *
 from tensorpack.utils import viz
 from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope, under_name_scope
......
@@ -6,6 +6,7 @@
 import os
 import argparse
+os.environ['TENSORPACK_TRAIN_API'] = 'v2'   # will become default soon
 from tensorpack import *
 from tensorpack.tfutils import optimizer
 from tensorpack.tfutils.summary import add_moving_summary
@@ -76,8 +77,6 @@ if __name__ == '__main__':
     assert args.data
     logger.auto_set_dir()
     config = TrainConfig(
-        model=Model(),
-        dataflow=DCGAN.get_data(args.data),
         callbacks=[ModelSaver(), ClipCallback()],
         steps_per_epoch=500,
         max_epoch=200,
@@ -85,4 +84,6 @@
     )
     # The original code uses a different schedule, but this seems to work well.
     # Train 1 D after 2 G
-    SeparateGANTrainer(config, d_period=3).train()
+    SeparateGANTrainer(
+        input=QueueInput(DCGAN.get_data(args.data)),
+        model=Model(), d_period=3).train_with_config(config)
@@ -218,6 +218,28 @@ class Trainer(object):
         self.initialize(session_creator, session_init)
         self.main_loop(steps_per_epoch, starting_epoch, max_epoch)

+    def train_with_config(self, config):
+        """
+        An alias to simplify the use of `TrainConfig`.
+        It is equivalent to the following:
+
+        .. code-block:: python
+
+            self.train(
+                config.callbacks, config.monitors,
+                config.session_creator, config.session_init,
+                config.steps_per_epoch, config.starting_epoch, config.max_epoch)
+        """
+        if config.data or config.dataflow or config.model:
+            logger.warn(
+                "data/dataflow/model in TrainConfig will not be used "
+                "in `Trainer.train_with_config`")
+            logger.warn("To build the graph from config, use `launch_train_with_config`!")
+        self.train(
+            config.callbacks, config.monitors,
+            config.session_creator, config.session_init,
+            config.steps_per_epoch, config.starting_epoch, config.max_epoch)
+
     # create the old trainer when called with TrainConfig
     def __new__(cls, *args, **kwargs):
         if (len(args) > 0 and isinstance(args[0], TrainConfig)) \
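Per its docstring, train_with_config is purely a convenience wrapper: TrainConfig becomes a container for the run schedule only, and the two calls below are equivalent (trainer and config assumed already constructed):

trainer.train_with_config(config)
# ... is the same as:
trainer.train(config.callbacks, config.monitors,
              config.session_creator, config.session_init,
              config.steps_per_epoch, config.starting_epoch, config.max_epoch)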
@@ -337,20 +359,6 @@ class SingleCostTrainer(TowerTrainer):
     To use a SingleCostTrainer object, call `trainer.setup_graph(...); trainer.train(...)`.
     """

-    def train(self,
-              callbacks, monitors,
-              session_creator, session_init,
-              steps_per_epoch, starting_epoch, max_epoch):
-        """
-        Same as :meth:`Trainer.train()`, except that the callbacks this
-        trainer needs are automatically added.
-        """
-        callbacks = callbacks + self._internal_callbacks
-        super(SingleCostTrainer, self).train(
-            callbacks, monitors,
-            session_creator, session_init,
-            steps_per_epoch, starting_epoch, max_epoch)
-
     @call_only_once
     def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
         """
@@ -375,8 +383,10 @@ class SingleCostTrainer(TowerTrainer):
         input_callbacks = self._setup_input(inputs_desc, input)
         train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
-        self._internal_callbacks = input_callbacks + train_callbacks
-        return self._internal_callbacks
+        internal_callbacks = input_callbacks + train_callbacks
+        for cb in internal_callbacks:
+            self._register_callback(cb)
+        return internal_callbacks

     @abstractmethod
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
......
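Because setup_graph now registers the input and training callbacks on the trainer itself, the train() override deleted above is no longer needed; the generic loop already sees those callbacks. The intended flow, per the class docstring (all arguments assumed):

trainer.setup_graph(inputs_desc, input, get_cost_fn, get_opt_fn)   # registers internal callbacks
trainer.train(callbacks, monitors, session_creator, session_init,
              steps_per_epoch, starting_epoch, max_epoch)          # plain Trainer.train()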