Commit d200d632 authored by Yuxin Wu's avatar Yuxin Wu

Print warning when failed to enter the expected name scope (#1257)

parent 8c8de86c
......@@ -19,7 +19,7 @@ This is a minimal implementation that simply contains these files:
Data:
1. It's easy to train on your own data, by calling `DatasetRegistry.register(name, lambda: YourDatasetSplit())`,
and modify `cfg.DATA.*` accordingly.
and modify `cfg.DATA.*` accordingly. Afterwards, "name" can be used in `cfg.DATA.TRAIN`.
`YourDatasetSplit` can be:
......
......@@ -82,6 +82,16 @@ class GeneralizedRCNN(ModelDesc):
rpn_losses + head_losses + [wd_cost], 'total_cost')
add_moving_summary(total_cost, wd_cost)
return total_cost
else:
# Check that the model defines the tensors it declares for inference
# For existing models, they are defined in "fastrcnn_predictions(name_scope='output')"
G = tf.get_default_graph()
ns = G.get_name_scope()
for name in self.get_inference_tensor_names()[1]:
try:
G.get_tensor_by_name('/'.join([ns, name + ':0']))
except KeyError:
raise KeyError("Your model does not define the tensor '{}' in inference context.".format(name))
class ResNetC4Model(GeneralizedRCNN):
......
......@@ -7,6 +7,7 @@ from contextlib import contextmanager
from ..compat import tfv1 as tf
from ..utils.argtools import graph_memoized
from ..utils import logger
from .common import get_tf_version_tuple
__all__ = ['auto_reuse_variable_scope', 'cached_name_scope', 'under_name_scope']
......@@ -66,6 +67,9 @@ def under_name_scope(name_scope=None):
2. The 'name_scope' argument of the decorator.
3. (default) The name of the decorated function itself.
If the name is taken and cannot be used, a warning will be
printed in the first case.
Example:
.. code-block:: python
......@@ -86,10 +90,22 @@ def under_name_scope(name_scope=None):
def _impl(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
warn_incorrect_scope = 'name_scope' in kwargs
scopename = kwargs.pop('name_scope', name_scope)
if scopename is None:
scopename = func.__name__
if warn_incorrect_scope:
# cached_name_scope will try to reenter the existing scope
with cached_name_scope(scopename, top_level=False) as scope:
scope = scope.strip('/')
# but it can still conflict with an existing tensor
if not scope.endswith(scopename):
logger.warn(""" \
Calling function {} with name_scope='{}', but actual name scope becomes '{}'. \
The name '{}' might be taken.""".format(func.__name__, scopename, scope.split('/')[-1], scopename))
return func(*args, **kwargs)
else:
with tf.name_scope(scopename):
return func(*args, **kwargs)
return wrapper
......
#-*- coding: utf-8 -*-
#File:
import unittest
import tensorflow as tf
from ..utils import logger
from .scope_utils import under_name_scope
class ScopeUtilsTest(unittest.TestCase):
@under_name_scope(name_scope='s')
def _f(self, check=True):
if check:
assert tf.get_default_graph().get_name_scope().endswith('s')
return True
def test_under_name_scope(self):
self.assertTrue(self._f())
with self.assertRaises(AssertionError):
self._f() # name conflict
def test_under_name_scope_warning(self):
x = tf.placeholder(tf.float32, [3])
tf.nn.relu(x, name='s')
with self.assertLogs(logger=logger._logger, level='WARNING'):
self._f(check=False, name_scope='s')
if __name__ == '__main__':
unittest.main()
......@@ -15,6 +15,7 @@ python -c "import tensorflow as tf; tf.Operation._add_control_input"
# run tests
python -m tensorpack.callbacks.param_test
python -m tensorpack.tfutils.unit_tests
TENSORPACK_SERIALIZE=pyarrow python test_serializer.py
TENSORPACK_SERIALIZE=msgpack python test_serializer.py
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment