Commit 2770ede8 authored by Yuxin Wu's avatar Yuxin Wu

Fix name scope after using tflayers (#627)

parent 4adbaa94
...@@ -7,7 +7,7 @@ from contextlib import contextmanager ...@@ -7,7 +7,7 @@ from contextlib import contextmanager
import operator import operator
import tensorflow as tf import tensorflow as tf
from ..tfutils.common import get_tf_version_number from ..tfutils.varreplace import custom_getter_scope
__all__ = ['LeastLoadedDeviceSetter', __all__ = ['LeastLoadedDeviceSetter',
...@@ -42,22 +42,8 @@ def override_to_local_variable(enable=True): ...@@ -42,22 +42,8 @@ def override_to_local_variable(enable=True):
_replace_global_by_local(kwargs) _replace_global_by_local(kwargs)
return getter(name, *args, **kwargs) return getter(name, *args, **kwargs)
orig_vs = tf.get_variable_scope() with custom_getter_scope(custom_getter):
if get_tf_version_number() >= 1.5: yield
with tf.variable_scope(
orig_vs,
custom_getter=custom_getter,
auxiliary_name_scope=False):
yield
else:
if get_tf_version_number() >= 1.2:
ns = tf.get_default_graph().get_name_scope()
else:
ns = orig_vs.original_name_scope
with tf.variable_scope(
orig_vs, custom_getter=custom_getter):
with tf.name_scope(ns + '/' if ns else ''):
yield
else: else:
yield yield
......
...@@ -6,14 +6,25 @@ ...@@ -6,14 +6,25 @@
import tensorflow as tf import tensorflow as tf
from contextlib import contextmanager from contextlib import contextmanager
from .common import get_tf_version_number
__all__ = ['freeze_variables', 'remap_variables'] __all__ = ['freeze_variables', 'remap_variables']
@contextmanager @contextmanager
def custom_getter_scope(custom_getter): def custom_getter_scope(custom_getter):
scope = tf.get_variable_scope() scope = tf.get_variable_scope()
with tf.variable_scope(scope, custom_getter=custom_getter): if get_tf_version_number() >= 1.5:
yield with tf.variable_scope(
scope, custom_getter=custom_getter,
auxiliary_name_scope=False):
yield
else:
ns = tf.get_default_graph().get_name_scope()
with tf.variable_scope(
scope, custom_getter=custom_getter):
with tf.name_scope(ns + '/' if ns else ''):
yield
def remap_variables(fn): def remap_variables(fn):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment