Commit d200d632 authored by Yuxin Wu

Print a warning when failing to enter the expected name scope (#1257)

parent 8c8de86c
...@@ -19,7 +19,7 @@ This is a minimal implementation that simply contains these files:
Data:
1. It's easy to train on your own data, by calling `DatasetRegistry.register(name, lambda: YourDatasetSplit())`,
   and modifying `cfg.DATA.*` accordingly. Afterwards, "name" can be used in `cfg.DATA.TRAIN`.
   `YourDatasetSplit` can be:
...
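To make the registration flow described above concrete, here is a minimal sketch. The `DatasetSplit` base class, the `training_roidbs()` method, and the roidb keys follow the FasterRCNN example's dataset interface as documented in its README; the module paths (`dataset`, `config`) and the exact field layout are assumptions and should be checked against the version you use.

```python
# Minimal sketch of registering a custom dataset split (paths and labels are hypothetical).
import numpy as np

from dataset import DatasetRegistry, DatasetSplit   # assumed module layout of the example
from config import config as cfg


class MyOwnSplit(DatasetSplit):
    """A toy split with a single image; replace with real data loading."""

    def training_roidbs(self):
        # Each roidb is a dict describing one image and its ground-truth boxes.
        return [{
            "file_name": "/path/to/image.jpg",
            "boxes": np.array([[10, 10, 200, 200]], dtype=np.float32),  # XYXY, absolute pixels
            "class": np.array([1], dtype=np.int32),                     # 1-based category ids
            "is_crowd": np.array([0], dtype=np.int8),
        }]


DatasetRegistry.register("my_train", lambda: MyOwnSplit())
cfg.DATA.TRAIN = ("my_train",)   # "my_train" is now a valid name in cfg.DATA.TRAIN
```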
...@@ -82,6 +82,16 @@ class GeneralizedRCNN(ModelDesc):
                rpn_losses + head_losses + [wd_cost], 'total_cost')
            add_moving_summary(total_cost, wd_cost)
            return total_cost
        else:
            # Check that the model defines the tensors it declares for inference
            # For existing models, they are defined in "fastrcnn_predictions(name_scope='output')"
            G = tf.get_default_graph()
            ns = G.get_name_scope()
            for name in self.get_inference_tensor_names()[1]:
                try:
                    name = '/'.join([ns, name]) if ns else name
                    G.get_tensor_by_name(name + ':0')
                except KeyError:
                    raise KeyError("Your model does not define the tensor '{}' in inference context.".format(name))
class ResNetC4Model(GeneralizedRCNN):
...
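The check added above boils down to a plain graph lookup: `Graph.get_tensor_by_name` raises `KeyError` for a name that was never created, and that is the error re-raised with a friendlier message. A small standalone sketch of the mechanism (generic TF1-style code, not tensorpack-specific; the scope and tensor names are made up):

```python
import tensorflow.compat.v1 as tf

g = tf.Graph()
with g.as_default(), tf.name_scope("tower0"):
    tf.identity(tf.zeros([1]), name="boxes")     # pretend the model created this output
    ns = g.get_name_scope()                      # "tower0"

    g.get_tensor_by_name("/".join([ns, "boxes:0"]))       # exists: returns the tensor
    try:
        g.get_tensor_by_name("/".join([ns, "labels:0"]))  # never created
    except KeyError as e:
        print("missing inference tensor:", e)
```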
...@@ -7,6 +7,7 @@ from contextlib import contextmanager
from ..compat import tfv1 as tf
from ..utils.argtools import graph_memoized
from ..utils import logger
from .common import get_tf_version_tuple
__all__ = ['auto_reuse_variable_scope', 'cached_name_scope', 'under_name_scope']
...@@ -66,6 +67,9 @@ def under_name_scope(name_scope=None):
    2. The 'name_scope' argument of the decorator.
    3. (default) The name of the decorated function itself.

    If the name is taken and cannot be used, a warning will be
    printed in the first case.

    Example:

    .. code-block:: python
...@@ -86,10 +90,22 @@ def under_name_scope(name_scope=None):
    def _impl(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warn_incorrect_scope = 'name_scope' in kwargs
            scopename = kwargs.pop('name_scope', name_scope)
            if scopename is None:
                scopename = func.__name__
            if warn_incorrect_scope:
                # cached_name_scope will try to reenter the existing scope
                with cached_name_scope(scopename, top_level=False) as scope:
                    scope = scope.strip('/')
                    # but it can still conflict with an existing tensor
                    if not scope.endswith(scopename):
                        logger.warn(""" \
Calling function {} with name_scope='{}', but actual name scope becomes '{}'. \
The name '{}' might be taken.""".format(func.__name__, scopename, scope.split('/')[-1], scopename))
                    return func(*args, **kwargs)
            else:
                with tf.name_scope(scopename):
                    return func(*args, **kwargs)
        return wrapper
...
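For reference, a small sketch of the scope-name precedence described in the docstring above (call-site keyword, then decorator argument, then the function name). It assumes tensorpack is installed and TF1-style graph mode; the function names, scope names, and the tensor names shown in the comments are illustrative expectations, not part of the library:

```python
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

from tensorpack.tfutils.scope_utils import under_name_scope


@under_name_scope()                    # case 3: scope defaults to the function name
def make_feature(x):
    return tf.identity(x, name='feat')


@under_name_scope(name_scope='head')   # case 2: scope fixed by the decorator argument
def make_logits(x):
    return tf.identity(x, name='logits')


x = tf.zeros([4])
print(make_feature(x).name)                   # make_feature/feat:0
print(make_logits(x).name)                    # head/logits:0
# case 1: a per-call keyword overrides both of the above
print(make_logits(x, name_scope='rpn').name)  # rpn/logits:0
```

If 'rpn' were already taken by an existing op or scope, the call would still run under a uniquified scope such as 'rpn_1', but with this commit it now logs a warning instead of silently ending up in the wrong scope; the unit test below exercises exactly that case.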
# -*- coding: utf-8 -*-
# File:

import unittest

import tensorflow as tf

from ..utils import logger
from .scope_utils import under_name_scope


class ScopeUtilsTest(unittest.TestCase):

    @under_name_scope(name_scope='s')
    def _f(self, check=True):
        if check:
            assert tf.get_default_graph().get_name_scope().endswith('s')
        return True

    def test_under_name_scope(self):
        self.assertTrue(self._f())      # first call enters scope 's' as expected
        with self.assertRaises(AssertionError):
            self._f()  # name conflict: the second call ends up in 's_1'

    def test_under_name_scope_warning(self):
        x = tf.placeholder(tf.float32, [3])
        tf.nn.relu(x, name='s')         # occupy the name 's' so the decorator cannot cleanly enter scope 's'
        with self.assertLogs(logger=logger._logger, level='WARNING'):
            self._f(check=False, name_scope='s')


if __name__ == '__main__':
    unittest.main()
...@@ -15,6 +15,7 @@ python -c "import tensorflow as tf; tf.Operation._add_control_input"
# run tests
python -m tensorpack.callbacks.param_test
python -m tensorpack.tfutils.unit_tests
TENSORPACK_SERIALIZE=pyarrow python test_serializer.py
TENSORPACK_SERIALIZE=msgpack python test_serializer.py
...