Commit 2770ede8 authored by Yuxin Wu's avatar Yuxin Wu

Fix name scope after using tflayers (#627)

parent 4adbaa94
......@@ -7,7 +7,7 @@ from contextlib import contextmanager
import operator
import tensorflow as tf
from ..tfutils.common import get_tf_version_number
from ..tfutils.varreplace import custom_getter_scope
__all__ = ['LeastLoadedDeviceSetter',
......@@ -42,22 +42,8 @@ def override_to_local_variable(enable=True):
_replace_global_by_local(kwargs)
return getter(name, *args, **kwargs)
orig_vs = tf.get_variable_scope()
if get_tf_version_number() >= 1.5:
with tf.variable_scope(
orig_vs,
custom_getter=custom_getter,
auxiliary_name_scope=False):
yield
else:
if get_tf_version_number() >= 1.2:
ns = tf.get_default_graph().get_name_scope()
else:
ns = orig_vs.original_name_scope
with tf.variable_scope(
orig_vs, custom_getter=custom_getter):
with tf.name_scope(ns + '/' if ns else ''):
yield
with custom_getter_scope(custom_getter):
yield
else:
yield
......
......@@ -6,14 +6,25 @@
import tensorflow as tf
from contextlib import contextmanager
from .common import get_tf_version_number
__all__ = ['freeze_variables', 'remap_variables']
@contextmanager
def custom_getter_scope(custom_getter):
scope = tf.get_variable_scope()
with tf.variable_scope(scope, custom_getter=custom_getter):
yield
if get_tf_version_number() >= 1.5:
with tf.variable_scope(
scope, custom_getter=custom_getter,
auxiliary_name_scope=False):
yield
else:
ns = tf.get_default_graph().get_name_scope()
with tf.variable_scope(
scope, custom_getter=custom_getter):
with tf.name_scope(ns + '/' if ns else ''):
yield
def remap_variables(fn):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment