Commit 70c9ba8f authored by Yuxin Wu

VariableHolder for layer_norm; Do not assert replicated variables.

parent bf4d8938
@@ -256,8 +256,10 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
             logger.error("[SyncMultiGPUReplicatedBuilder] variable "
                          "{} has its prefix {} appears multiple times in its name!".format(v.name, prefix))
         copy_from = var_by_name.get(realname)
-        assert copy_from is not None, var_by_name.keys()
-        post_init_ops.append(v.assign(copy_from.read_value()))
+        if copy_from is not None:
+            post_init_ops.append(v.assign(copy_from.read_value()))
+        else:
+            logger.warn("[ReplicatedTrainer] Cannot find {} in the graph!".format(realname))
     logger.info(
         "'sync_variables_from_main_tower' includes {} operations.".format(len(post_init_ops)))
     return tf.group(*post_init_ops, name='sync_variables_from_main_tower')
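This hunk implements the "do not assert replicated variables" part of the commit message: when a tower variable has no counterpart under the main tower's name, the builder now logs a warning and skips it instead of failing an assertion. Below is a minimal sketch of the resulting pattern; `get_sync_ops`, `all_vars`, and `strip_tower_prefix` are hypothetical stand-ins for the builder's surrounding code, not tensorpack's actual API.

    import logging

    logger = logging.getLogger(__name__)

    def get_sync_ops(all_vars, strip_tower_prefix):
        # all_vars: variables from every tower.
        # strip_tower_prefix: maps a tower variable's name to its name
        # under the main (tower0) scope.
        var_by_name = {v.name: v for v in all_vars}
        post_init_ops = []
        for v in all_vars:
            realname = strip_tower_prefix(v.name)
            if realname == v.name:
                continue  # already a main-tower variable, nothing to copy
            copy_from = var_by_name.get(realname)
            if copy_from is not None:
                # copy the main tower's value into this replica
                post_init_ops.append(v.assign(copy_from.read_value()))
            else:
                # previously: assert copy_from is not None, var_by_name.keys()
                logger.warning("Cannot find %s in the graph!", realname)
        return post_init_ops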
@@ -4,7 +4,7 @@
 import tensorflow as tf

-from .common import layer_register
+from .common import layer_register, VariableHolder

 __all__ = ['LayerNorm', 'InstanceNorm']
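For context, tensorpack's VariableHolder (imported above) is essentially an attribute container that gets attached to a layer's output so callers can retrieve the layer's variables later. A rough sketch of the idea, not the library's actual implementation:

    class VariableHolder(object):
        """Attribute container: holds the variables created by one layer."""

        def __init__(self, **kwargs):
            object.__setattr__(self, '_vars', {})
            for k, v in kwargs.items():
                self._vars[k] = v

        def __setattr__(self, name, value):
            # each variable name can only be registered once
            assert name not in self._vars, name
            self._vars[name] = value

        def __getattr__(self, name):
            try:
                return self._vars[name]
            except KeyError:
                raise AttributeError(name)

        def all(self):
            """Return all variables known to this holder."""
            return list(self._vars.values())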
@@ -51,7 +51,14 @@ def LayerNorm(
     else:
         gamma = tf.ones([1] * ndims, name='gamma')

-    return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')
+    ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')
+
+    vh = ret.variables = VariableHolder()
+    if use_scale:
+        vh.gamma = gamma
+    if use_bias:
+        vh.beta = beta
+    return ret
 @layer_register()
@@ -90,4 +97,10 @@ def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format=
         gamma_init = tf.constant_initializer(1.0)
     gamma = tf.get_variable('gamma', [ch], initializer=gamma_init)
     gamma = tf.reshape(gamma, new_shape)
-    return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')
+    ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')
+
+    vh = ret.variables = VariableHolder()
+    if use_affine:
+        vh.gamma = gamma
+        vh.beta = beta
+    return ret
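With the holder attached, a layer's variables can be pulled off its output tensor instead of being looked up by variable scope. A hedged usage sketch: the placeholder shape and the 'ln'/'inorm' scope names are made up for illustration, and it assumes the defaults use_scale=use_bias=True for LayerNorm and use_affine=True for InstanceNorm, so that gamma and beta were actually created.

    import tensorflow as tf
    from tensorpack.models import LayerNorm, InstanceNorm

    x = tf.placeholder(tf.float32, [None, 32, 32, 64])

    ln = LayerNorm('ln', x)
    # gamma/beta are present only when the layer created them
    print(ln.variables.gamma, ln.variables.beta)

    inorm = InstanceNorm('inorm', x)
    print(inorm.variables.gamma, inorm.variables.beta)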