Commit 696c7830 authored by Yuxin Wu's avatar Yuxin Wu

use tf.get_name_scope to replace get_name_scope_name

parent 44f603c0
import tensorflow as tf
from functools import wraps
import numpy as np
from .scope_utils import get_name_scope_name
from .common import get_tf_version_number
__all__ = ['Distribution',
'CategoricalDistribution', 'GaussianDistribution',
......@@ -18,6 +19,16 @@ def class_scope(func):
``tf.name_scope(...)`` in each method.
"""
def get_name_scope_name():
if get_tf_version_number() > 1.2:
return tf.get_name_scope().name
else:
g = tf.get_default_graph()
s = "RANDOM_STR_ABCDEFG"
unique = g.unique_name(s)
scope = unique[:-len(s)].rstrip('/')
return scope
@wraps(func)
def _impl(self, *args, **kwargs):
# is there a specific name?
......
......@@ -5,6 +5,8 @@
import tensorflow as tf
import six
from .common import get_tf_version_number
from ..utils.develop import deprecated
if six.PY2:
import functools32 as functools
else:
......@@ -13,16 +15,20 @@ else:
__all__ = ['get_name_scope_name', 'auto_reuse_variable_scope']
@deprecated("Use tf.get_name_scope() (available since 1.2.1).")
def get_name_scope_name():
"""
Returns:
str: the name of the current name scope, without the ending '/'.
"""
g = tf.get_default_graph()
s = "RANDOM_STR_ABCDEFG"
unique = g.unique_name(s)
scope = unique[:-len(s)].rstrip('/')
return scope
if get_tf_version_number() > 1.2:
return tf.get_name_scope().name
else:
g = tf.get_default_graph()
s = "RANDOM_STR_ABCDEFG"
unique = g.unique_name(s)
scope = unique[:-len(s)].rstrip('/')
return scope
def auto_reuse_variable_scope(func):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment