Commit f870a49c authored by Yuxin Wu's avatar Yuxin Wu

[FasterRCNN] allow FREEZE_AFFINE

parent 43d2bffb
...@@ -9,12 +9,25 @@ from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope ...@@ -9,12 +9,25 @@ from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.models import ( from tensorpack.models import (
Conv2D, MaxPooling, BatchNorm, BNReLU) Conv2D, MaxPooling, BatchNorm, BNReLU)
import config
def maybe_freeze_affine(getter, *args, **kwargs):
    """Custom variable getter that can freeze BatchNorm affine parameters.

    Intended to be passed as ``custom_getter`` to ``tf.variable_scope``.
    When ``config.FREEZE_AFFINE`` is True, any variable whose name ends in
    ``/gamma`` or ``/beta`` (the BN scale/shift parameters) is created with
    ``trainable=False``; all other variables are forwarded unchanged.

    Args:
        getter: the underlying variable getter supplied by TensorFlow.
        *args: forwarded to ``getter``; the variable name is expected as the
            first positional argument when present.
        **kwargs: forwarded to ``getter``; may carry the name as ``name=``.

    Returns:
        Whatever ``getter`` returns (the created or reused variable).
    """
    # Name may arrive positionally or as a keyword; fall back to None.
    name = args[0] if args else kwargs.get('name')
    # Guard against a missing name: the previous code would raise
    # AttributeError on None.endswith().  Tuple form checks both suffixes
    # in one call.
    if name is not None and name.endswith(('/gamma', '/beta')):
        if config.FREEZE_AFFINE:
            kwargs['trainable'] = False
    return getter(*args, **kwargs)
# NOTE(review): the lines below are a garbled side-by-side diff render from
# the commit page -- each line concatenates the OLD and NEW column text, so
# this span is not valid Python as shown.  The NEW version of
# resnet_argscope() adds a tf.variable_scope(..., custom_getter=
# maybe_freeze_affine) to the argscope stack so FREEZE_AFFINE can take
# effect; confirm against the actual repository file before relying on it.
@contextmanager @contextmanager
def resnet_argscope(): def resnet_argscope():
with argscope([Conv2D, MaxPooling, BatchNorm], data_format='NCHW'), \ with argscope([Conv2D, MaxPooling, BatchNorm], data_format='NCHW'), \
argscope(Conv2D, use_bias=False), \ argscope(Conv2D, use_bias=False), \
argscope(BatchNorm, use_local_stat=False): argscope(BatchNorm, use_local_stat=False), \
tf.variable_scope(tf.get_variable_scope(),
custom_getter=maybe_freeze_affine):
yield yield
......
...@@ -17,6 +17,7 @@ CLASS_NAMES = [] # NUM_CLASS strings. Will be populated later by coco loader ...@@ -17,6 +17,7 @@ CLASS_NAMES = [] # NUM_CLASS strings. Will be populated later by coco loader
# basemodel ---------------------- # basemodel ----------------------
RESNET_NUM_BLOCK = [3, 4, 6, 3] # for resnet50 RESNET_NUM_BLOCK = [3, 4, 6, 3] # for resnet50
# RESNET_NUM_BLOCK = [3, 4, 23, 3] # for resnet101 # RESNET_NUM_BLOCK = [3, 4, 23, 3] # for resnet101
FREEZE_AFFINE = False # do not train affine parameters inside BN
# schedule ----------------------- # schedule -----------------------
BASE_LR = 1e-2 BASE_LR = 1e-2
......
...@@ -201,7 +201,7 @@ def BatchNorm(x, use_local_stat=None, momentum=0.9, epsilon=1e-5, ...@@ -201,7 +201,7 @@ def BatchNorm(x, use_local_stat=None, momentum=0.9, epsilon=1e-5,
'decay': 'momentum' 'decay': 'momentum'
}) })
def BatchRenorm(x, rmax, dmax, momentum=0.9, epsilon=1e-5, def BatchRenorm(x, rmax, dmax, momentum=0.9, epsilon=1e-5,
scale=True, bias=True, gamma_initializer=None, scale=True, center=True, gamma_initializer=None,
data_format='channels_last'): data_format='channels_last'):
""" """
Batch Renormalization layer, as described in the paper: Batch Renormalization layer, as described in the paper:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment