Commit f870a49c authored by Yuxin Wu

[FasterRCNN] allow FREEZE_AFFINE

parent 43d2bffb
@@ -9,12 +9,25 @@ from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
 from tensorpack.models import (
     Conv2D, MaxPooling, BatchNorm, BNReLU)
 
 import config
 
+
+def maybe_freeze_affine(getter, *args, **kwargs):
+    # custom getter to freeze affine params inside bn
+    name = args[0] if len(args) else kwargs.get('name')
+    if name.endswith('/gamma') or name.endswith('/beta'):
+        if config.FREEZE_AFFINE:
+            kwargs['trainable'] = False
+    return getter(*args, **kwargs)
+
+
 @contextmanager
 def resnet_argscope():
     with argscope([Conv2D, MaxPooling, BatchNorm], data_format='NCHW'), \
             argscope(Conv2D, use_bias=False), \
-            argscope(BatchNorm, use_local_stat=False):
+            argscope(BatchNorm, use_local_stat=False), \
+            tf.variable_scope(tf.get_variable_scope(),
+                              custom_getter=maybe_freeze_affine):
         yield
......
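The mechanism here is TF1's custom getter: every tf.get_variable() call made inside the wrapping tf.variable_scope is routed through maybe_freeze_affine, which can rewrite the creation kwargs before the variable exists. Below is a minimal standalone sketch of that pattern (not part of this commit; the names are illustrative, and it assumes TF1-style graph mode as used in the diff):

    import tensorflow as tf

    def freeze_gamma_beta(getter, *args, **kwargs):
        # Every tf.get_variable() call inside the scope passes through here.
        name = args[0] if len(args) else kwargs.get('name')
        if name.endswith('/gamma') or name.endswith('/beta'):
            # Keep the variable, but exclude it from TRAINABLE_VARIABLES
            # so the optimizer never receives its gradient.
            kwargs['trainable'] = False
        return getter(*args, **kwargs)

    with tf.variable_scope(tf.get_variable_scope(), custom_getter=freeze_gamma_beta):
        gamma = tf.get_variable('bn/gamma', shape=[64],
                                initializer=tf.ones_initializer())

    assert gamma not in tf.trainable_variables()   # created as non-trainable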
@@ -17,6 +17,7 @@ CLASS_NAMES = []  # NUM_CLASS strings. Will be populated later by coco loader
 # basemodel ----------------------
 RESNET_NUM_BLOCK = [3, 4, 6, 3]     # for resnet50
 # RESNET_NUM_BLOCK = [3, 4, 23, 3]  # for resnet101
+FREEZE_AFFINE = False   # do not train affine parameters inside BN
 
 # schedule -----------------------
 BASE_LR = 1e-2
......
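Because FREEZE_AFFINE is read at variable-creation time by the custom getter above, flipping it only takes effect before the graph is built. A hypothetical usage (the flag is from this commit; the snippet is illustrative):

    import config
    config.FREEZE_AFFINE = True   # must be set before the ResNet graph is built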
@@ -201,7 +201,7 @@ def BatchNorm(x, use_local_stat=None, momentum=0.9, epsilon=1e-5,
         'decay': 'momentum'
     })
 def BatchRenorm(x, rmax, dmax, momentum=0.9, epsilon=1e-5,
-                scale=True, bias=True, gamma_initializer=None,
+                scale=True, center=True, gamma_initializer=None,
                 data_format='channels_last'):
     """
     Batch Renormalization layer, as described in the paper:
......
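This hunk renames BatchRenorm's bias argument to center, matching the center/scale naming of tf.layers.batch_normalization. A hedged sketch of a call site after the rename (argument values are illustrative, assuming tensorpack's registered-layer calling convention):

    # was: BatchRenorm('bnr', x, rmax=3.0, dmax=5.0, scale=True, bias=True)
    x = BatchRenorm('bnr', x, rmax=3.0, dmax=5.0, scale=True, center=True)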