Commit 8868a9e4 authored by Yuxin Wu

add gamma_init in BN to imitate torch

parent 4c82fb50
@@ -71,8 +71,8 @@ class GANTrainer(FeedfreeTrainerBase):
         self.train_op = self.d_min


-class SplitGANTrainer(FeedfreeTrainerBase):
-    """ A new trainer which runs two optimization ops with a certain ratio. """
+class SeparateGANTrainer(FeedfreeTrainerBase):
+    """ A GAN trainer which runs two optimization ops with a certain ratio, one in each step. """
     def __init__(self, config, d_interval=1):
         """
         Args:
@@ -80,10 +80,10 @@ class SplitGANTrainer(FeedfreeTrainerBase):
         """
         self._input_method = QueueInput(config.dataflow)
         self._d_interval = d_interval
-        super(SplitGANTrainer, self).__init__(config)
+        super(SeparateGANTrainer, self).__init__(config)

     def _setup(self):
-        super(SplitGANTrainer, self)._setup()
+        super(SeparateGANTrainer, self)._setup()
         self.build_train_tower()
         opt = self.model.get_optimizer()
@@ -94,11 +94,11 @@ class SplitGANTrainer(FeedfreeTrainerBase):
         self._cnt = 0

     def run_step(self):
-        self._cnt += 1
-        if self._cnt % (self._d_interval) == 0:
+        if self._cnt % (self._d_interval + 1) == 0:
             self.hooked_sess.run(self.d_min)
         else:
             self.hooked_sess.run(self.g_min)
+        self._cnt += 1


 class RandomZData(DataFlow):
...
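The run_step change above is more than cosmetic: the old code increments self._cnt before testing self._cnt % self._d_interval, so with d_interval=1 the test is always true and g_min never runs. The new code tests the counter first, against d_interval + 1, and only then increments. A standalone sketch (plain Python, not tensorpack code; schedule is a hypothetical helper written just for this comparison) simulates both counter variants:

def schedule(n_steps, d_interval, fixed=True):
    """Simulate which op run_step would execute at each step."""
    ops, cnt = [], 0
    for _ in range(n_steps):
        if fixed:
            # new logic: test the counter first, then increment
            op = 'd_min' if cnt % (d_interval + 1) == 0 else 'g_min'
            cnt += 1
        else:
            # old logic: increment first, then test against d_interval
            cnt += 1
            op = 'd_min' if cnt % d_interval == 0 else 'g_min'
        ops.append(op)
    return ops

print(schedule(4, 1, fixed=False))  # ['d_min', 'd_min', 'd_min', 'd_min']
print(schedule(4, 1, fixed=True))   # ['d_min', 'g_min', 'd_min', 'g_min']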
@@ -205,14 +205,13 @@ if __name__ == '__main__':
     parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
     parser.add_argument('--load', help='load model')
     parser.add_argument('--sample', action='store_true', help='run sampling')
-    parser.add_argument('--data', help='Image directory')
+    parser.add_argument('--data', help='Image directory', required=True)
     parser.add_argument('--mode', choices=['AtoB', 'BtoA'], default='AtoB')
     parser.add_argument('-b', '--batch', type=int, default=1)
     global args
     args = parser.parse_args()
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
-    assert args.data
     BATCH = args.batch
...
@@ -9,7 +9,7 @@ import argparse
 from tensorpack import *
 from tensorpack.tfutils.summary import add_moving_summary
 import tensorflow as tf
-from GAN import SplitGANTrainer
+from GAN import SeparateGANTrainer

 """
 Wasserstein-GAN.
@@ -84,4 +84,4 @@ if __name__ == '__main__':
     This is to be consistent with the original code, but I found just
     running them 1:1 (i.e. just using the existing GANTrainer) also works well.
     """
-    SplitGANTrainer(config, d_interval=5).train()
+    SeparateGANTrainer(config, d_interval=5).train()
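For reference, under the fixed run_step logic shown earlier, d_interval=5 runs d_min on one step out of every six and g_min on the other five. Checking this with the hypothetical schedule helper from the earlier sketch:

print(schedule(6, 5))  # ['d_min', 'g_min', 'g_min', 'g_min', 'g_min', 'g_min']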
@@ -96,13 +96,13 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
             x, ema_mean, ema_var, beta, gamma, epsilon, 'output')


-def get_bn_variables(n_out, use_scale, use_bias):
+def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
     if use_bias:
         beta = tf.get_variable('beta', [n_out], initializer=tf.constant_initializer())
     else:
         beta = tf.zeros([n_out], name='beta')
     if use_scale:
-        gamma = tf.get_variable('gamma', [n_out], initializer=tf.constant_initializer(1.0))
+        gamma = tf.get_variable('gamma', [n_out], initializer=gamma_init)
     else:
         gamma = tf.ones([n_out], name='gamma')
     # x * gamma + beta
@@ -132,7 +132,8 @@ def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):

 @layer_register(log_shape=False)
 def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
-              use_scale=True, use_bias=True, data_format='NHWC'):
+              use_scale=True, use_bias=True,
+              gamma_init=tf.constant_initializer(1.0), data_format='NHWC'):
     """
     Batch Normalization layer, as described in the paper:
     `Batch Normalization: Accelerating Deep Network Training by
@@ -145,14 +146,16 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         decay (float): decay rate of moving average.
         epsilon (float): epsilon to avoid divide-by-zero.
         use_scale, use_bias (bool): whether to use the extra affine transformation or not.
+        gamma_init: initializer for gamma (the scale).

     Returns:
         tf.Tensor: a tensor named ``output`` with the same shape of x.

     Variable Names:

-    * ``beta``: the bias term.
-    * ``gamma``: the scale term. Input will be transformed by ``x * gamma + beta``.
+    * ``beta``: the bias term. Will be zero-inited by default.
+    * ``gamma``: the scale term. Will be one-inited by default.
+      Input will be transformed by ``x * gamma + beta``.
     * ``mean/EMA``: the moving average of mean.
     * ``variance/EMA``: the moving average of variance.
@@ -176,7 +179,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         x = tf.reshape(x, [-1, 1, 1, n_out])
     assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"

-    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, use_scale, use_bias)
+    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, use_scale, use_bias, gamma_init)

     ctx = get_current_tower_context()
     if use_local_stat is None:
@@ -245,7 +248,8 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     n_out = shape[-1]
     if len(shape) == 2:
         x = tf.reshape(x, [-1, 1, 1, n_out])
-    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, use_scale, use_bias)
+    beta, gamma, moving_mean, moving_var = get_bn_variables(
+        n_out, use_scale, use_bias, tf.constant_initializer(1.0))
     ctx = get_current_tower_context()
     use_local_stat = ctx.is_training
...
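Per the commit message, gamma_init exists to imitate torch. A plausible call site is sketched below, assuming the torch DCGAN convention of drawing BN scale parameters from N(1.0, 0.02); that convention is an assumption here, since the diff itself only adds the parameter, whose default tf.constant_initializer(1.0) preserves the old behavior:

import tensorflow as tf

# Hypothetical usage inside a tensorpack model; 'l' is some layer output.
# tf.random_normal_initializer(1.0, 0.02) mirrors the N(1.0, 0.02) BN
# weight init commonly used by torch DCGAN code.
l = BatchNorm('bn', l, gamma_init=tf.random_normal_initializer(1.0, 0.02))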