Commit 696c7830 authored by Yuxin Wu's avatar Yuxin Wu

use tf.get_name_scope to replace get_name_scope_name

parent 44f603c0
import tensorflow as tf import tensorflow as tf
from functools import wraps from functools import wraps
import numpy as np import numpy as np
from .scope_utils import get_name_scope_name
from .common import get_tf_version_number
__all__ = ['Distribution', __all__ = ['Distribution',
'CategoricalDistribution', 'GaussianDistribution', 'CategoricalDistribution', 'GaussianDistribution',
...@@ -18,6 +19,16 @@ def class_scope(func): ...@@ -18,6 +19,16 @@ def class_scope(func):
``tf.name_scope(...)`` in each method. ``tf.name_scope(...)`` in each method.
""" """
def get_name_scope_name():
if get_tf_version_number() > 1.2:
return tf.get_name_scope().name
else:
g = tf.get_default_graph()
s = "RANDOM_STR_ABCDEFG"
unique = g.unique_name(s)
scope = unique[:-len(s)].rstrip('/')
return scope
@wraps(func) @wraps(func)
def _impl(self, *args, **kwargs): def _impl(self, *args, **kwargs):
# is there a specific name? # is there a specific name?
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
import tensorflow as tf import tensorflow as tf
import six import six
from .common import get_tf_version_number
from ..utils.develop import deprecated
if six.PY2: if six.PY2:
import functools32 as functools import functools32 as functools
else: else:
...@@ -13,16 +15,20 @@ else: ...@@ -13,16 +15,20 @@ else:
__all__ = ['get_name_scope_name', 'auto_reuse_variable_scope'] __all__ = ['get_name_scope_name', 'auto_reuse_variable_scope']
@deprecated("Use tf.get_name_scope() (available since 1.2.1).")
def get_name_scope_name(): def get_name_scope_name():
""" """
Returns: Returns:
str: the name of the current name scope, without the ending '/'. str: the name of the current name scope, without the ending '/'.
""" """
g = tf.get_default_graph() if get_tf_version_number() > 1.2:
s = "RANDOM_STR_ABCDEFG" return tf.get_name_scope().name
unique = g.unique_name(s) else:
scope = unique[:-len(s)].rstrip('/') g = tf.get_default_graph()
return scope s = "RANDOM_STR_ABCDEFG"
unique = g.unique_name(s)
scope = unique[:-len(s)].rstrip('/')
return scope
def auto_reuse_variable_scope(func): def auto_reuse_variable_scope(func):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment