Commit 7f2c708e authored by Yuxin Wu

add instancenorm

parent bafe8337
@@ -47,14 +47,14 @@ TrainConfig(
        # run `tf.summary.merge_all` every epoch and send results to monitors
        MergeAllSummaries(),
    ],
    monitors=[   # monitors are a special kind of callback; these are also enabled by default
        # write all monitor data to tensorboard
        TFSummaryWriter(),
        # write all scalar data to a json file, for easy parsing
        JSONWriter(),
        # print all scalar data every epoch (can be configured differently)
        ScalarPrinter(),
    ]
)
```
@@ -36,9 +36,9 @@ BATCH = 64
NF = 64  # channel size

-def BNLReLU(x, name):
+def BNLReLU(x, name=None):
    x = BatchNorm('bn', x)
-    return LeakyReLU(x)
+    return LeakyReLU(x, name=name)


class Model(GANModelDesc):
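For context, a hedged usage sketch (illustration only, not part of the diff): in this GAN example `BNLReLU` is typically passed as the `nl` (nonlinearity) argument of tensorpack's conv layers, so defaulting `name` to `None` lets the layer call it without a name while still allowing a caller-supplied name to be forwarded to `LeakyReLU`. The `argscope` values and the `image`/`NF` names below are assumptions from the surrounding example, not part of the commit:

```python
# Hypothetical usage, assuming the old tensorpack API of this commit's era
# (argscope + `nl=` nonlinearity argument); `image` stands for a 4D input tensor.
from tensorpack import argscope, Conv2D

with argscope(Conv2D, kernel_shape=4, stride=2, nl=BNLReLU):
    l = Conv2D('conv0', image, NF)  # conv output passes through BatchNorm + LeakyReLU
```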
@@ -6,6 +6,8 @@
import tensorflow as tf

from .common import layer_register

+__all__ = ['LayerNorm', 'InstanceNorm']
+

@layer_register(log_shape=False)
def LayerNorm(x, epsilon=1e-5, use_bias=True, use_scale=True, data_format='NHWC'):
@@ -45,3 +47,40 @@ def LayerNorm(x, epsilon=1e-5, use_bias=True, use_scale=True, data_format='NHWC'
    gamma = tf.ones([1] * ndims, name='gamma')
    return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')
+@layer_register(log_shape=False)
+def InstanceNorm(x, epsilon=1e-5, data_format='NHWC', use_affine=True):
+    """
+    Instance Normalization, as in the paper:
+    `Instance Normalization: The Missing Ingredient for Fast Stylization
+    <https://arxiv.org/abs/1607.08022>`_.
+
+    Args:
+        x (tf.Tensor): a 4D tensor.
+        epsilon (float): avoid divide-by-zero
+        use_affine (bool): whether to apply learnable affine transformation
+    """
+    shape = x.get_shape().as_list()
+    assert len(shape) == 4, "Input of InstanceNorm has to be 4D!"
+
+    if data_format == 'NHWC':
+        axis = [1, 2]
+        ch = shape[3]
+        new_shape = [1, 1, 1, ch]
+    else:
+        axis = [2, 3]
+        ch = shape[1]
+        new_shape = [1, ch, 1, 1]
+    assert ch is not None, "Input of InstanceNorm requires a known channel dimension!"
+
+    mean, var = tf.nn.moments(x, axis, keep_dims=True)
+
+    if not use_affine:
+        return tf.divide(x - mean, tf.sqrt(var + epsilon), name='output')
+
+    beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
+    beta = tf.reshape(beta, new_shape)
+    gamma = tf.get_variable('gamma', [ch], initializer=tf.constant_initializer(1.0))
+    gamma = tf.reshape(gamma, new_shape)
+    return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')
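As a sanity check on what the new layer computes, here is a minimal NumPy reference sketch (illustration only, not part of the commit): instance normalization normalizes each (sample, channel) slice over its spatial dimensions only, unlike batch norm, which also reduces over the batch dimension.

```python
# Minimal NumPy reference for InstanceNorm on NHWC input (no affine transform).
import numpy as np

def instance_norm_ref(x, epsilon=1e-5):
    # x: [N, H, W, C]; reduce over H and W only, keeping dims for broadcasting
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)

x = np.random.randn(2, 4, 4, 3).astype('float32')
y = instance_norm_ref(x)
# Each (sample, channel) slice now has ~zero mean and ~unit variance:
print(y[0, :, :, 0].mean(), y[0, :, :, 0].std())
```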