Commit 196a17f3 authored by Yuxin Wu's avatar Yuxin Wu

Fix some more scope issues

parent 2770ede8
...@@ -6,6 +6,7 @@ from contextlib import contextmanager ...@@ -6,6 +6,7 @@ from contextlib import contextmanager
import tensorflow as tf import tensorflow as tf
from tensorpack.tfutils.argscope import argscope, get_arg_scope from tensorpack.tfutils.argscope import argscope, get_arg_scope
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.tfutils.varreplace import custom_getter_scope
from tensorpack.models import ( from tensorpack.models import (
Conv2D, MaxPooling, BatchNorm, BNReLU) Conv2D, MaxPooling, BatchNorm, BNReLU)
...@@ -26,8 +27,7 @@ def resnet_argscope(): ...@@ -26,8 +27,7 @@ def resnet_argscope():
with argscope([Conv2D, MaxPooling, BatchNorm], data_format='NCHW'), \ with argscope([Conv2D, MaxPooling, BatchNorm], data_format='NCHW'), \
argscope(Conv2D, use_bias=False), \ argscope(Conv2D, use_bias=False), \
argscope(BatchNorm, use_local_stat=False), \ argscope(BatchNorm, use_local_stat=False), \
tf.variable_scope(tf.get_variable_scope(), custom_getter_scope(maybe_freeze_affine):
custom_getter=maybe_freeze_affine):
yield yield
......
...@@ -503,7 +503,8 @@ def maskrcnn_head(feature, num_class): ...@@ -503,7 +503,8 @@ def maskrcnn_head(feature, num_class):
""" """
with argscope([Conv2D, Deconv2D], data_format='NCHW', with argscope([Conv2D, Deconv2D], data_format='NCHW',
W_init=tf.variance_scaling_initializer( W_init=tf.variance_scaling_initializer(
scale=2.0, mode='fan_in', distribution='normal')): scale=2.0, mode='fan_out', distribution='normal')):
# c2's MSRAFill is fan_out
l = Deconv2D('deconv', feature, 256, 2, stride=2, nl=tf.nn.relu) l = Deconv2D('deconv', feature, 256, 2, stride=2, nl=tf.nn.relu)
l = Conv2D('conv', l, num_class - 1, 1) l = Conv2D('conv', l, num_class - 1, 1)
return l return l
......
...@@ -8,6 +8,7 @@ import functools ...@@ -8,6 +8,7 @@ import functools
from contextlib import contextmanager from contextlib import contextmanager
from ..utils.argtools import graph_memoized from ..utils.argtools import graph_memoized
from .common import get_tf_version_number
__all__ = ['auto_reuse_variable_scope', 'cached_name_scope', 'under_name_scope'] __all__ = ['auto_reuse_variable_scope', 'cached_name_scope', 'under_name_scope']
...@@ -39,8 +40,14 @@ def auto_reuse_variable_scope(func): ...@@ -39,8 +40,14 @@ def auto_reuse_variable_scope(func):
h = hash((tf.get_default_graph(), scope.name)) h = hash((tf.get_default_graph(), scope.name))
# print("Entering " + scope.name + " reuse: " + str(h in used_scope)) # print("Entering " + scope.name + " reuse: " + str(h in used_scope))
if h in used_scope: if h in used_scope:
with tf.variable_scope(scope, reuse=True): if get_tf_version_number() >= 1.5:
return func(*args, **kwargs) with tf.variable_scope(scope, reuse=True, auxiliary_name_scope=False):
return func(*args, **kwargs)
else:
ns = tf.get_default_graph().get_name_scope()
with tf.variable_scope(scope, reuse=True), \
tf.name_scope(ns + '/' if ns else ''):
return func(*args, **kwargs)
else: else:
used_scope.add(h) used_scope.add(h)
return func(*args, **kwargs) return func(*args, **kwargs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment