Commit 1cd536a9 authored by Yuxin Wu

disable autograph in prelu as well

parent 0c53df74
@@ -15,6 +15,7 @@ from ..utils.argtools import get_data_format
 from ..utils.develop import log_deprecated
 from .common import VariableHolder, layer_register
 from .tflayer import convert_to_tflayer_args, rename_get_variable
+from .utils import disable_autograph
 __all__ = ['BatchNorm', 'BatchRenorm']
@@ -60,15 +61,6 @@ def internal_update_bn_ema(xn, batch_mean, batch_var,
         return tf.identity(xn, name='output')
-try:
-    # When BN is used as an activation, keras layers try to autograph.convert it
-    # This leads to massive warnings so we disable it.
-    from tensorflow.python.autograph.impl.api import do_not_convert as disable_autograph
-except ImportError:
-    def disable_autograph():
-        return lambda x: x
 @layer_register()
 @convert_to_tflayer_args(
     args_names=[],
...
@@ -8,6 +8,7 @@ from ..utils.develop import log_deprecated
 from ..compat import tfv1
 from .batch_norm import BatchNorm
 from .common import VariableHolder, layer_register
+from .utils import disable_autograph
 __all__ = ['Maxout', 'PReLU', 'BNReLU']
@@ -37,6 +38,7 @@ def Maxout(x, num_unit):
 @layer_register()
+@disable_autograph()
 def PReLU(x, init=0.001, name=None):
     """
     Parameterized ReLU as in the paper `Delving Deep into Rectifiers: Surpassing
...
@@ -35,3 +35,12 @@ class VariableHolder(object):
            list of all variables
        """
        return list(six.itervalues(self._vars))
+try:
+    # When BN is used as an activation, keras layers try to autograph.convert it
+    # This leads to massive warnings so we disable it.
+    from tensorflow.python.autograph.impl.api import do_not_convert as disable_autograph
+except ImportError:
+    def disable_autograph():
+        return lambda x: x
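
For readers skimming the diff, here is a minimal, self-contained sketch of the pattern this commit factors out into the shared utils module and then applies to PReLU. `my_layer` is a hypothetical function used only for illustration; the actual commit wraps tensorpack's PReLU (BatchNorm already used the same decorator).

try:
    # Use TF's decorator when available: it tells AutoGraph not to convert
    # the wrapped function, avoiding the warning spam that appears when
    # keras tries to autograph.convert a layer (e.g. BN used as an activation).
    from tensorflow.python.autograph.impl.api import do_not_convert as disable_autograph
except ImportError:
    # Fallback for TF versions without this private API: a decorator factory
    # that returns the function unchanged.
    def disable_autograph():
        return lambda x: x

@disable_autograph()
def my_layer(x):
    # AutoGraph leaves this body alone when tracing.
    return x * 2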