Commit 125df71e authored by Yuxin Wu

allow GAN schedule both G and D

parent 6d9d89a1
@@ -226,4 +226,4 @@ if __name__ == '__main__':
     )
     # train 1 D after 2 G
-    SeparateGANTrainer(config, 2).train()
+    SeparateGANTrainer(config, d_period=3).train()
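Aside (not part of the commit): with the old positional argument (d_interval=2), every third step ran d_min and the other steps ran g_min, never both in one step; with the new d_period=3 and the default g_period=1, g_min runs on every step and d_min additionally runs on every third step, as illustrated by the schedule sketch after the run_step() changes below.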
@@ -73,13 +73,16 @@ class GANTrainer(FeedfreeTrainerBase):
 class SeparateGANTrainer(FeedfreeTrainerBase):
     """ A GAN trainer which runs two optimization ops with a certain ratio, one in each step. """
-    def __init__(self, config, d_interval=1):
+    def __init__(self, config, d_period=1, g_period=1):
         """
         Args:
-            d_interval: will run d_opt only after this many of g_opt.
+            d_period(int): period of each d_opt run
+            g_period(int): period of each g_opt run
         """
         self._input_method = QueueInput(config.dataflow)
-        self._d_interval = d_interval
+        self._d_period = int(d_period)
+        self._g_period = int(g_period)
+        assert min(d_period, g_period) == 1
         super(SeparateGANTrainer, self).__init__(config)

     def _setup(self):
@@ -91,12 +94,12 @@ class SeparateGANTrainer(FeedfreeTrainerBase):
             self.model.d_loss, var_list=self.model.d_vars, name='d_min')
         self.g_min = opt.minimize(
             self.model.g_loss, var_list=self.model.g_vars, name='g_min')
-        self._cnt = 0
+        self._cnt = 1

     def run_step(self):
-        if self._cnt % (self._d_interval + 1) == 0:
+        if self._cnt % (self._d_period) == 0:
             self.hooked_sess.run(self.d_min)
-        else:
+        if self._cnt % (self._g_period) == 0:
             self.hooked_sess.run(self.g_min)
         self._cnt += 1
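For illustration, a minimal standalone sketch of the new scheduling rule (not code from this commit): the step counter starts at 1 and each op fires on steps where the counter is divisible by its period, so with the d_period=3 example above, g_min runs on every step and d_min joins in on every third step.

    # Illustrative sketch of the run_step() schedule above (hypothetical helper,
    # not part of the repository). Each op fires when the 1-based step counter
    # is divisible by its period.
    def schedule(num_steps, d_period=3, g_period=1):
        assert min(d_period, g_period) == 1   # same constraint as the constructor
        for cnt in range(1, num_steps + 1):
            ops = []
            if cnt % d_period == 0:
                ops.append('d_min')
            if cnt % g_period == 0:
                ops.append('g_min')
            yield cnt, ops

    for step, ops in schedule(6):
        print(step, ops)
    # 1 ['g_min']
    # 2 ['g_min']
    # 3 ['d_min', 'g_min']
    # 4 ['g_min']
    # 5 ['g_min']
    # 6 ['d_min', 'g_min']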
...
@@ -81,7 +81,6 @@ if __name__ == '__main__':
     if args.load:
         config.session_init = SaverRestore(args.load)
     """
-    This is to be consistent with the original code, but I found just
-    running them 1:1 (i.e. just using the existing GANTrainer) also works well.
+    The original code uses a different schedule.
     """
-    SeparateGANTrainer(config, d_interval=5).train()
+    SeparateGANTrainer(config, d_period=3).train()
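Since the point of the commit is to allow scheduling either network, the same constructor should also support thinning out the generator instead; a hypothetical call (not in this commit), mirroring the ones above:

    # Hypothetical usage: run d_min on every step and g_min only on every second step.
    SeparateGANTrainer(config, g_period=2).train()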
@@ -26,6 +26,10 @@ def get_name_scope_name():

 def auto_reuse_variable_scope(func):
+    """
+    A decorator which automatically reuse the current variable scope if the
+    function has been called with the same variable scope before.
+    """
     used_scope = set()

     @functools.wraps(func)
...
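The last hunk only adds a docstring to auto_reuse_variable_scope; the decorator body is collapsed in this view. As a point of reference, here is a minimal sketch of how such a decorator can be written against TF1-style variable scopes; this is an assumption-laden illustration, not the implementation from this repository:

    import functools
    import tensorflow as tf

    def auto_reuse_variable_scope(func):
        """
        A decorator which automatically reuses the current variable scope if the
        function has been called with the same variable scope before.
        """
        used_scope = set()   # names of scopes the function was already called under

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            scope = tf.get_variable_scope()
            if scope.name in used_scope:
                # Seen this scope before: re-enter it with reuse=True so that
                # tf.get_variable() returns the existing variables.
                with tf.variable_scope(scope, reuse=True):
                    return func(*args, **kwargs)
            used_scope.add(scope.name)
            return func(*args, **kwargs)
        return wrapper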