Commit 8868a9e4 authored by Yuxin Wu

add gamma_init in BN to imitate torch

parent 4c82fb50
......@@ -71,8 +71,8 @@ class GANTrainer(FeedfreeTrainerBase):
         self.train_op = self.d_min


-class SplitGANTrainer(FeedfreeTrainerBase):
-    """ A new trainer which runs two optimization ops with a certain ratio. """
+class SeparateGANTrainer(FeedfreeTrainerBase):
+    """ A GAN trainer which runs two optimization ops with a certain ratio, one in each step. """
     def __init__(self, config, d_interval=1):
         """
         Args:
......@@ -80,10 +80,10 @@ class SplitGANTrainer(FeedfreeTrainerBase):
         """
         self._input_method = QueueInput(config.dataflow)
         self._d_interval = d_interval
-        super(SplitGANTrainer, self).__init__(config)
+        super(SeparateGANTrainer, self).__init__(config)

     def _setup(self):
-        super(SplitGANTrainer, self)._setup()
+        super(SeparateGANTrainer, self)._setup()
         self.build_train_tower()
         opt = self.model.get_optimizer()
......@@ -94,11 +94,11 @@ class SplitGANTrainer(FeedfreeTrainerBase):
         self._cnt = 0

     def run_step(self):
-        if self._cnt % (self._d_interval) == 0:
+        self._cnt += 1
+        if self._cnt % (self._d_interval + 1) == 0:
             self.hooked_sess.run(self.d_min)
         else:
             self.hooked_sess.run(self.g_min)
-        self._cnt += 1


 class RandomZData(DataFlow):
......
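To see what changed in the scheduling, here is a standalone sketch (illustrative, not tensorpack code) of the new `run_step` logic. With the old zero-started counter and plain `cnt % d_interval` test, `d_min` ran on every step when `d_interval=1`; the new pre-increment plus `% (d_interval + 1)` runs `d_min` once every `d_interval + 1` steps and `g_min` on the rest:

    # Standalone sketch: simulate which op the new run_step picks per step.
    def schedule(d_interval, steps):
        cnt, ops = 0, []
        for _ in range(steps):
            cnt += 1  # incremented before the test, as in the new code
            ops.append('d' if cnt % (d_interval + 1) == 0 else 'g')
        return ops

    print(schedule(1, 6))   # ['g', 'd', 'g', 'd', 'g', 'd'] -> 1:1 alternation
    print(schedule(5, 12))  # 'd' fires only on steps 6 and 12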
......@@ -205,14 +205,13 @@ if __name__ == '__main__':
     parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
     parser.add_argument('--load', help='load model')
     parser.add_argument('--sample', action='store_true', help='run sampling')
-    parser.add_argument('--data', help='Image directory')
+    parser.add_argument('--data', help='Image directory', required=True)
     parser.add_argument('--mode', choices=['AtoB', 'BtoA'], default='AtoB')
     parser.add_argument('-b', '--batch', type=int, default=1)
     global args
     args = parser.parse_args()

     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
-    assert args.data
     BATCH = args.batch
......
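The `required=True` change lets argparse itself reject a missing `--data` before any setup code runs, which is why the manual `assert args.data` can be dropped. A quick illustration (the path is a placeholder):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--data', help='Image directory', required=True)

    # parse_args() now exits with an "argument --data is required"-style
    # error when the flag is missing, instead of tripping a bare
    # AssertionError later in the script.
    args = parser.parse_args(['--data', '/path/to/images'])
    print(args.data)  # /path/to/images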
......@@ -9,7 +9,7 @@ import argparse

 from tensorpack import *
 from tensorpack.tfutils.summary import add_moving_summary
 import tensorflow as tf
-from GAN import SplitGANTrainer
+from GAN import SeparateGANTrainer

 """
 Wasserstein-GAN.
......@@ -84,4 +84,4 @@ if __name__ == '__main__':
         This is to be consistent with the original code, but I found just
         running them 1:1 (i.e. just using the existing GANTrainer) also works well.
     """
-    SplitGANTrainer(config, d_interval=5).train()
+    SeparateGANTrainer(config, d_interval=5).train()
......@@ -96,13 +96,13 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
             x, ema_mean, ema_var, beta, gamma, epsilon, 'output')


-def get_bn_variables(n_out, use_scale, use_bias):
+def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
     if use_bias:
         beta = tf.get_variable('beta', [n_out], initializer=tf.constant_initializer())
     else:
         beta = tf.zeros([n_out], name='beta')
     if use_scale:
-        gamma = tf.get_variable('gamma', [n_out], initializer=tf.constant_initializer(1.0))
+        gamma = tf.get_variable('gamma', [n_out], initializer=gamma_init)
     else:
         gamma = tf.ones([n_out], name='gamma')
     # x * gamma + beta
......@@ -132,7 +132,8 @@ def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):

 @layer_register(log_shape=False)
 def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
-              use_scale=True, use_bias=True, data_format='NHWC'):
+              use_scale=True, use_bias=True,
+              gamma_init=tf.constant_initializer(1.0), data_format='NHWC'):
     """
     Batch Normalization layer, as described in the paper:
     `Batch Normalization: Accelerating Deep Network Training by
......@@ -145,14 +146,16 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         decay (float): decay rate of moving average.
         epsilon (float): epsilon to avoid divide-by-zero.
         use_scale, use_bias (bool): whether to use the extra affine transformation or not.
+        gamma_init: initializer for gamma (the scale).

     Returns:
         tf.Tensor: a tensor named ``output`` with the same shape of x.

     Variable Names:

-    * ``beta``: the bias term.
-    * ``gamma``: the scale term. Input will be transformed by ``x * gamma + beta``.
+    * ``beta``: the bias term. Will be zero-inited by default.
+    * ``gamma``: the scale term. Will be one-inited by default.
+      Input will be transformed by ``x * gamma + beta``.
     * ``mean/EMA``: the moving average of mean.
     * ``variance/EMA``: the moving average of variance.
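To make the new knob concrete, a hedged usage sketch (the input shape and layer names are illustrative, not from the repository). The default keeps the old one-init behavior; passing a uniform initializer imitates Lua Torch's BN layers, which draw their scale from U(0, 1) by default, and appears to be what "imitate torch" in the commit message refers to:

    import tensorflow as tf
    from tensorpack import *  # provides the BatchNorm layer

    x = tf.placeholder(tf.float32, [None, 32, 32, 64], name='input')

    # Default: gamma is one-inited, beta zero-inited, as before.
    out = BatchNorm('bn', x)

    # Torch-style init: draw gamma from a uniform distribution instead.
    out_torch = BatchNorm('bn_torch', x,
                          gamma_init=tf.random_uniform_initializer(0.0, 1.0))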
......@@ -176,7 +179,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         x = tf.reshape(x, [-1, 1, 1, n_out])
     assert n_out is not None, "Input to BatchNorm cannot have unknown channels!"

-    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, use_scale, use_bias)
+    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, use_scale, use_bias, gamma_init)

     ctx = get_current_tower_context()
     if use_local_stat is None:
......@@ -245,7 +248,8 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
     n_out = shape[-1]
     if len(shape) == 2:
         x = tf.reshape(x, [-1, 1, 1, n_out])
-    beta, gamma, moving_mean, moving_var = get_bn_variables(n_out, use_scale, use_bias)
+    beta, gamma, moving_mean, moving_var = get_bn_variables(
+        n_out, use_scale, use_bias, tf.constant_initializer(1.0))

     ctx = get_current_tower_context()
     use_local_stat = ctx.is_training
......
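One design note from the last hunk: `BatchRenorm` passes `tf.constant_initializer(1.0)` explicitly rather than exposing the new parameter, so the `gamma_init` knob is available on `BatchNorm` only and `BatchRenorm`'s initialization is unchanged.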