Commit 4c82fb50 authored by Yuxin Wu

move WGANTrainer to GAN

parent 0763b16d
@@ -71,6 +71,36 @@ class GANTrainer(FeedfreeTrainerBase):
        self.train_op = self.d_min


class SplitGANTrainer(FeedfreeTrainerBase):
    """ A new trainer which runs two optimization ops with a certain ratio. """
    def __init__(self, config, d_interval=1):
        """
        Args:
            d_interval: run ``d_min`` once every ``d_interval`` steps; ``g_min`` runs on the remaining steps.
        """
        self._input_method = QueueInput(config.dataflow)
        self._d_interval = d_interval
        super(SplitGANTrainer, self).__init__(config)

    def _setup(self):
        super(SplitGANTrainer, self)._setup()
        self.build_train_tower()
        opt = self.model.get_optimizer()
        # Build separate minimization ops for the discriminator and the generator.
        self.d_min = opt.minimize(
            self.model.d_loss, var_list=self.model.d_vars, name='d_min')
        self.g_min = opt.minimize(
            self.model.g_loss, var_list=self.model.g_vars, name='g_min')
        self._cnt = 0

    def run_step(self):
        # Alternate the two ops: d_min once every d_interval steps, g_min otherwise.
        self._cnt += 1
        if self._cnt % self._d_interval == 0:
            self.hooked_sess.run(self.d_min)
        else:
            self.hooked_sess.run(self.g_min)


class RandomZData(DataFlow):
    def __init__(self, shape):
        super(RandomZData, self).__init__()
......
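To make the scheduling above concrete: run_step keeps a step counter and runs the discriminator op only when the counter is a multiple of d_interval, and the generator op on every other step. A small standalone sketch (not part of the commit) that reproduces this schedule:

# Sketch only: simulate which op SplitGANTrainer.run_step would pick at each step.
def schedule(d_interval, num_steps):
    picks, cnt = [], 0
    for _ in range(num_steps):
        cnt += 1
        picks.append('d_min' if cnt % d_interval == 0 else 'g_min')
    return picks

print(schedule(d_interval=5, num_steps=10))
# ['g_min', 'g_min', 'g_min', 'g_min', 'd_min',
#  'g_min', 'g_min', 'g_min', 'g_min', 'd_min']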
@@ -9,7 +9,7 @@ import argparse
from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary
import tensorflow as tf
from GAN import GANTrainer
from GAN import SplitGANTrainer
"""
Wasserstein-GAN.
@@ -61,36 +61,11 @@ def get_config():
        # use the same data in the DCGAN example
        dataflow=DCGAN.get_data(args.data),
        callbacks=[ModelSaver()],
        steps_per_epoch=300,
        steps_per_epoch=500,
        max_epoch=200,
    )
class WGANTrainer(FeedfreeTrainerBase):
    """ A new trainer which runs two optimization ops with 5:1 ratio.
    This is to be consistent with the original code, but I found just
    running them 1:1 (i.e. just using the existing GANTrainer) also works well.
    """
    def __init__(self, config):
        self._input_method = QueueInput(config.dataflow)
        super(WGANTrainer, self).__init__(config)

    def _setup(self):
        super(WGANTrainer, self)._setup()
        self.build_train_tower()
        opt = self.model.get_optimizer()
        self.d_min = opt.minimize(
            self.model.d_loss, var_list=self.model.d_vars, name='d_min')
        self.g_min = opt.minimize(
            self.model.g_loss, var_list=self.model.g_vars, name='g_op')

    def run_step(self):
        for k in range(5):
            self.hooked_sess.run(self.d_min)
        self.hooked_sess.run(self.g_min)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--load', help='load model')
@@ -105,4 +80,8 @@ if __name__ == '__main__':
    config = get_config()
    if args.load:
        config.session_init = SaverRestore(args.load)
    WGANTrainer(config).train()
    """
    This is to be consistent with the original code, but I found just
    running them 1:1 (i.e. just using the existing GANTrainer) also works well.
    """
    SplitGANTrainer(config, d_interval=5).train()
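The comment kept in the example notes that a plain 1:1 update schedule also works well; under that reading, a minimal alternative to the last line (a sketch only, reusing the same config object built by get_config()) would be:

# Sketch of the 1:1 alternative mentioned in the comment above,
# using the existing GANTrainer instead of SplitGANTrainer.
from GAN import GANTrainer
GANTrainer(config).train()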
@@ -130,9 +130,10 @@ def add_moving_summary(v, *args, **kwargs):
    averager = tf.train.ExponentialMovingAverage(
        decay, num_updates=get_global_step_var(), name='EMA')
    avg_maintain_op = averager.apply(v)
    for c in v:
        # TODO do this in the EMA callback?
        name = re.sub('tower[p0-9]+/', '', c.op.name)
        tf.summary.scalar(name + '-summary', averager.average(c))
    for c in v:
        # TODO do this in the EMA callback?
        name = re.sub('tower[p0-9]+/', '', c.op.name)
        tf.summary.scalar(name + '-summary', averager.average(c))
    tf.add_to_collection(coll, avg_maintain_op)
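The re.sub in the loop strips a training-tower scope such as tower0/ from the op name, so the moving-average summary gets a clean name. A quick standalone check of that substitution (not part of the commit):

import re
# Remove a 'towerN/' (or 'towerpN/') scope from an op name, as the loop above does.
print(re.sub('tower[p0-9]+/', '', 'tower0/d_loss'))   # -> d_loss
print(re.sub('tower[p0-9]+/', '', 'towerp2/g_loss'))  # -> g_loss
print(re.sub('tower[p0-9]+/', '', 'd_loss'))          # -> d_loss (unchanged)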