Commit 9f056711 authored by Yuxin Wu's avatar Yuxin Wu

strict LinearWrap

parent c24697ee
......@@ -5,7 +5,6 @@
import six
from types import ModuleType
from ..utils import logger
from .common import get_registered_layer
__all__ = ['LinearWrap']
......@@ -58,12 +57,11 @@ class LinearWrap(object):
return LinearWrap(ret)
return f
else:
if layer_name != 'tf':
logger.warn(
"You're calling LinearWrap.__getattr__ with {}:"
" neither a layer nor 'tf'!".format(layer_name))
import tensorflow as tf # noqa
layer = eval(layer_name)
assert layer_name == 'tf', \
"Calling LinearWrap.{}:" \
" neither a layer nor 'tf'! " \
"Did you forget to extract tensor from LinearWrap?".format(layer_name)
import tensorflow as layer # noqa
assert isinstance(layer, ModuleType), layer
return LinearWrap._TFModuleFunc(layer, self._t)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment