Commit 6b0a1a70 authored by Yuxin Wu's avatar Yuxin Wu

add noise_shape in Dropout layer. (fix #244)

parent 7f2c708e
...@@ -53,7 +53,7 @@ def regularize_cost(regex, func, name='regularize_cost'): ...@@ -53,7 +53,7 @@ def regularize_cost(regex, func, name='regularize_cost'):
@layer_register(log_shape=False, use_scope=False) @layer_register(log_shape=False, use_scope=False)
def Dropout(x, keep_prob=0.5, is_training=None): def Dropout(x, keep_prob=0.5, is_training=None, noise_shape=None):
""" """
Dropout layer as in the paper `Dropout: a Simple Way to Prevent Dropout layer as in the paper `Dropout: a Simple Way to Prevent
Neural Networks from Overfitting <http://dl.acm.org/citation.cfm?id=2670313>`_. Neural Networks from Overfitting <http://dl.acm.org/citation.cfm?id=2670313>`_.
...@@ -63,8 +63,9 @@ def Dropout(x, keep_prob=0.5, is_training=None): ...@@ -63,8 +63,9 @@ def Dropout(x, keep_prob=0.5, is_training=None):
when is_training=True. when is_training=True.
is_training (bool): If None, will use the current :class:`tensorpack.tfutils.TowerContext` is_training (bool): If None, will use the current :class:`tensorpack.tfutils.TowerContext`
to figure out. to figure out.
noise_shape: same as `tf.nn.dropout`.
""" """
if is_training is None: if is_training is None:
is_training = get_current_tower_context().is_training is_training = get_current_tower_context().is_training
keep_prob = tf.constant(keep_prob if is_training else 1.0) keep_prob = tf.constant(keep_prob if is_training else 1.0)
return tf.nn.dropout(x, keep_prob) return tf.nn.dropout(x, keep_prob, noise_shape=noise_shape)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment