Commit 7cb2606c authored by Yuxin Wu

[MaskRCNN] better scope management

parent 85586fc5
# -*- coding: utf-8 -*-
# File: basemodel.py
from contextlib import contextmanager
from contextlib import contextmanager, ExitStack
import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from tensorpack.tfutils import argscope
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from tensorpack.tfutils.varreplace import custom_getter_scope
from tensorpack.tfutils.varreplace import custom_getter_scope, freeze_variables
from tensorpack.models import (
Conv2D, MaxPooling, BatchNorm, layer_register)
@@ -40,13 +42,16 @@ def GroupNorm(x, group=32, gamma_initializer=tf.constant_initializer(1.)):
return tf.reshape(out, orig_shape, name='output')
def maybe_freeze_affine(getter, *args, **kwargs):
def freeze_affine_getter(getter, *args, **kwargs):
# custom getter to freeze affine params inside bn
name = args[0] if len(args) else kwargs.get('name')
if name.endswith('/gamma') or name.endswith('/beta'):
if cfg.BACKBONE.FREEZE_AFFINE:
kwargs['trainable'] = False
return getter(*args, **kwargs)
ret = getter(*args, **kwargs)
add_model_variable(ret)
else:
ret = getter(*args, **kwargs)
return ret
def maybe_reverse_pad(topleft, bottomright):
@@ -56,26 +61,33 @@ def maybe_reverse_pad(topleft, bottomright):
@contextmanager
def backbone_argscope():
def backbone_scope(freeze):
"""
Args:
freeze (bool): whether to freeze all the variables under the scope
"""
def nonlin(x):
x = get_norm()(x)
return tf.nn.relu(x)
with argscope([Conv2D, MaxPooling, BatchNorm], data_format='channels_first'), \
argscope(Conv2D, use_bias=False, activation=nonlin), \
argscope(BatchNorm, training=False), \
custom_getter_scope(maybe_freeze_affine):
yield
argscope(Conv2D, use_bias=False, activation=nonlin,
kernel_initializer=tf.variance_scaling_initializer(
scale=2.0, mode='fan_out')), \
ExitStack() as stack:
if cfg.BACKBONE.NORM in ['FreezeBN', 'SyncBN']:
if freeze or cfg.BACKBONE.NORM == 'FreezeBN':
stack.enter_context(argscope(BatchNorm, training=False))
else:
stack.enter_context(argscope(
BatchNorm, sync_statistics='nccl' if cfg.TRAINER == 'replicated' else 'horovod'))
@contextmanager
def maybe_syncbn_scope():
if cfg.BACKBONE.NORM == 'SyncBN':
assert cfg.BACKBONE.FREEZE_AT == 2 # TODO add better support
with argscope(BatchNorm, training=None,
sync_statistics='nccl' if cfg.TRAINER == 'replicated' else 'horovod'):
yield
if freeze:
stack.enter_context(freeze_variables(stop_gradient=False, skip_collection=True))
else:
# the layers are not completely frozen, but we may want to freeze only the affine params
if cfg.BACKBONE.FREEZE_AFFINE:
stack.enter_context(custom_getter_scope(freeze_affine_getter))
yield
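
To make the freeze levels concrete, the backbone functions further down nest this scope as follows (a condensed sketch of resnet_c4_backbone below, not additional diff content; the explicit padding is elided):

    freeze_at = cfg.BACKBONE.FREEZE_AT            # 0, 1 or 2
    with backbone_scope(freeze=freeze_at > 0):
        # conv0 + pool0: frozen whenever FREEZE_AT >= 1
        l = Conv2D('conv0', image, 64, 7, strides=2)
        l = MaxPooling('pool0', l, 3, strides=2)
    with backbone_scope(freeze=freeze_at > 1):
        # group0 (producing c2): frozen only when FREEZE_AT == 2
        c2 = resnet_group('group0', l, resnet_bottleneck, 64, num_blocks[0], 1)
    with backbone_scope(freeze=False):
        # group1 and later are never frozen
        c3 = resnet_group('group1', c2, resnet_bottleneck, 128, num_blocks[1], 2)
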
@@ -147,16 +159,16 @@ def resnet_group(name, l, block_func, features, count, stride):
def resnet_c4_backbone(image, num_blocks):
assert len(num_blocks) == 3
with backbone_argscope():
freeze_at = cfg.BACKBONE.FREEZE_AT
with backbone_scope(freeze=freeze_at > 0):
l = tf.pad(image, [[0, 0], [0, 0], maybe_reverse_pad(2, 3), maybe_reverse_pad(2, 3)])
l = Conv2D('conv0', l, 64, 7, strides=2, padding='VALID')
l = tf.pad(l, [[0, 0], [0, 0], maybe_reverse_pad(0, 1), maybe_reverse_pad(0, 1)])
l = MaxPooling('pool0', l, 3, strides=2, padding='VALID')
with backbone_scope(freeze=freeze_at > 1):
c2 = resnet_group('group0', l, resnet_bottleneck, 64, num_blocks[0], 1)
# TODO replace var by const to enable optimization
if cfg.BACKBONE.FREEZE_AT == 2:
c2 = tf.stop_gradient(c2)
with maybe_syncbn_scope():
with backbone_scope(freeze=False):
c3 = resnet_group('group1', c2, resnet_bottleneck, 128, num_blocks[1], 2)
c4 = resnet_group('group2', c3, resnet_bottleneck, 256, num_blocks[2], 2)
# 16x downsampling up to now
@@ -165,18 +177,19 @@ def resnet_c4_backbone(image, num_blocks):
@auto_reuse_variable_scope
def resnet_conv5(image, num_block):
with backbone_argscope(), maybe_syncbn_scope():
with backbone_scope(freeze=False):
l = resnet_group('group3', image, resnet_bottleneck, 512, num_block, 2)
return l
def resnet_fpn_backbone(image, num_blocks):
freeze_at = cfg.BACKBONE.FREEZE_AT
shape2d = tf.shape(image)[2:]
mult = float(cfg.FPN.RESOLUTION_REQUIREMENT)
new_shape2d = tf.to_int32(tf.ceil(tf.to_float(shape2d) / mult) * mult)
pad_shape2d = new_shape2d - shape2d
assert len(num_blocks) == 4, num_blocks
with backbone_argscope():
with backbone_scope(freeze=freeze_at > 0):
chan = image.shape[1]
pad_base = maybe_reverse_pad(2, 3)
l = tf.pad(image, tf.stack(
@@ -187,10 +200,9 @@ def resnet_fpn_backbone(image, num_blocks):
l = Conv2D('conv0', l, 64, 7, strides=2, padding='VALID')
l = tf.pad(l, [[0, 0], [0, 0], maybe_reverse_pad(0, 1), maybe_reverse_pad(0, 1)])
l = MaxPooling('pool0', l, 3, strides=2, padding='VALID')
with backbone_scope(freeze=freeze_at > 1):
c2 = resnet_group('group0', l, resnet_bottleneck, 64, num_blocks[0], 1)
if cfg.BACKBONE.FREEZE_AT == 2:
c2 = tf.stop_gradient(c2)
with maybe_syncbn_scope():
with backbone_scope(freeze=False):
c3 = resnet_group('group1', c2, resnet_bottleneck, 128, num_blocks[1], 2)
c4 = resnet_group('group2', c3, resnet_bottleneck, 256, num_blocks[2], 2)
c5 = resnet_group('group3', c4, resnet_bottleneck, 512, num_blocks[3], 2)
......
@@ -72,7 +72,7 @@ _C.BACKBONE.RESNET_NUM_BLOCK = [3, 4, 6, 3] # for resnet50
# RESNET_NUM_BLOCK = [3, 4, 23, 3] # for resnet101
_C.BACKBONE.FREEZE_AFFINE = False # do not train affine parameters inside norm layers
_C.BACKBONE.NORM = 'FreezeBN' # options: FreezeBN, SyncBN, GN
_C.BACKBONE.FREEZE_AT = 2 # options: 0, 2
_C.BACKBONE.FREEZE_AT = 2 # options: 0, 1, 2
# Use a base model with TF-preferred padding mode,
# which may pad more pixels on right/bottom than top/left.
@@ -169,7 +169,7 @@ def finalize_configs(is_training):
assert _C.BACKBONE.NORM in ['FreezeBN', 'SyncBN', 'GN'], _C.BACKBONE.NORM
if _C.BACKBONE.NORM != 'FreezeBN':
assert not _C.BACKBONE.FREEZE_AFFINE
assert _C.BACKBONE.FREEZE_AT in [0, 2]
assert _C.BACKBONE.FREEZE_AT in [0, 1, 2]
_C.RPN.NUM_ANCHOR = len(_C.RPN.ANCHOR_SIZES) * len(_C.RPN.ANCHOR_RATIOS)
assert len(_C.FPN.ANCHOR_STRIDES) == len(_C.RPN.ANCHOR_SIZES)
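
The newly allowed value 1 sits between the two existing behaviors:

    # BACKBONE.FREEZE_AT = 0 : train the whole backbone
    # BACKBONE.FREEZE_AT = 1 : freeze only the conv0/pool0 stem (new option)
    # BACKBONE.FREEZE_AT = 2 : freeze conv0/pool0 and group0 (previous behavior)

Assuming the example's usual command-line override mechanism (not shown in this diff), it can be selected with something like ./train.py --config BACKBONE.FREEZE_AT=1.
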
......
@@ -234,8 +234,7 @@ class ResNetC4Model(DetectionModel):
mrcnn_loss = 0.0
wd_cost = regularize_cost(
'(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost')
'.*/W', l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost')
total_cost = tf.add_n([
rpn_label_loss, rpn_box_loss,
@@ -372,8 +371,7 @@ class ResNetFPNModel(DetectionModel):
mrcnn_loss = 0.0
wd_cost = regularize_cost(
'(?:group1|group2|group3|rpn|fpn|fastrcnn|maskrcnn)/.*W',
l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost')
'.*/W', l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost')
total_cost = tf.add_n([rpn_label_loss, rpn_box_loss,
fastrcnn_label_loss, fastrcnn_box_loss,
......
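
The simpler '.*/W' pattern relies on the freezing changes above: with skip_collection=True the frozen backbone weights are no longer in TRAINABLE_VARIABLES, and regularize_cost matches only trainable variables (its usual tensorpack semantics), so the unfrozen scopes no longer need to be enumerated in the regex. A sketch under that assumption:

    # Match every W; weights frozen by backbone_scope are not trainable,
    # hence never picked up by regularize_cost.
    wd_cost = regularize_cost('.*/W', l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost')
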
@@ -20,7 +20,7 @@ class ModelSaver(Callback):
def __init__(self, max_to_keep=10,
keep_checkpoint_every_n_hours=0.5,
checkpoint_dir=None,
var_collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.MODEL_VARIABLES]):
var_collections=[tf.GraphKeys.GLOBAL_VARIABLES]):
"""
Args:
max_to_keep (int): the same as in ``tf.train.Saver``.
......
@@ -189,6 +189,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
# because during training, EMA isn't used
if ctx.is_main_training_tower:
for v in layer.non_trainable_variables:
if isinstance(v, tf.Variable):
add_model_variable(v)
if not ctx.is_main_training_tower or internal_update:
restore_collection(coll_bk)
@@ -351,6 +352,7 @@ def BatchRenorm(x, rmax, dmax, momentum=0.9, epsilon=1e-5,
if ctx.is_main_training_tower:
for v in layer.non_trainable_variables:
if isinstance(v, tf.Variable):
add_model_variable(v)
else:
# only run UPDATE_OPS in the first tower
......
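
A plausible reading of the new isinstance check (an inference, not stated in the commit message): custom getters such as freeze_variables(stop_gradient=True) make get_variable return the output of tf.stop_gradient, i.e. a Tensor rather than a tf.Variable, and such objects can end up in layer.non_trainable_variables, where add_model_variable cannot accept them:

    for v in layer.non_trainable_variables:
        if isinstance(v, tf.Variable):  # skip Tensors produced by variable-replacing getters
            add_model_variable(v)
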
@@ -3,11 +3,12 @@
# Credit: Qinyao He
import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from contextlib import contextmanager
from .common import get_tf_version_tuple
__all__ = ['freeze_variables', 'remap_variables']
__all__ = ['custom_getter_scope', 'freeze_variables', 'remap_variables']
@contextmanager
@@ -65,19 +66,25 @@ def freeze_variables(stop_gradient=True, skip_collection=False):
Args:
stop_gradient (bool): if True, variables returned from `get_variable`
will be wrapped with `tf.stop_gradient` and therefore have no
gradient when used later. Note that the created variables may
still have gradient when accessed by other approaches (e.g.
by name, or by collection).
gradient when used later.
Note that the created variables may still have gradient when accessed
by other approaches (e.g. by name, or by collection).
Also note that this makes `tf.get_variable` return a Tensor instead of a Variable,
which may break existing code.
Therefore, it's recommended to use the `skip_collection` option instead.
skip_collection (bool): if True, do not add the variable to
``TRAINABLE_VARIABLES`` collection. As a result they will not be
trained by default.
``TRAINABLE_VARIABLES`` collection, but to ``MODEL_VARIABLES``
collection. As a result they will not be trained by default.
"""
def custom_getter(getter, *args, **kwargs):
trainable = kwargs.get('trainable', True)
name = args[0] if len(args) else kwargs.get('name')
if skip_collection:
kwargs['trainable'] = False
v = getter(*args, **kwargs)
if skip_collection:
add_model_variable(v)
if trainable and stop_gradient:
v = tf.stop_gradient(v)
v = tf.stop_gradient(v, name='freezed_' + name)
return v
return custom_getter_scope(custom_getter)
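
A minimal usage sketch of the updated helper, matching how backbone_scope uses it above (the layer name and input tensor x are illustrative only):

    with freeze_variables(stop_gradient=False, skip_collection=True):
        # Variables created here are made trainable=False and added to
        # MODEL_VARIABLES instead of TRAINABLE_VARIABLES, so they are not
        # trained by default but remain ordinary Variables (no tf.stop_gradient wrapping).
        x = Conv2D('frozen_conv', x, 64, 3)
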
@@ -4,7 +4,8 @@
import tensorflow as tf
from ..input_source import (
InputSource, FeedInput, QueueInput, StagingInput, DummyConstantInput)
InputSource, FeedInput, FeedfreeInput,
QueueInput, StagingInput, DummyConstantInput)
from ..utils import logger
from .config import TrainConfig
@@ -40,7 +41,8 @@ def apply_default_prefetch(input_source_or_dataflow, trainer):
# seems to improve only with >1 GPUs
assert not isinstance(trainer, SimpleTrainer)
if not isinstance(input, (StagingInput, DummyConstantInput)):
if isinstance(input, FeedfreeInput) and \
not isinstance(input, (StagingInput, DummyConstantInput)):
logger.info("Automatically applying StagingInput on the DataFlow.")
input = StagingInput(input)
return input
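
The extra FeedfreeInput check presumably reflects that StagingInput expects a feed-free input pipeline; a plain FeedInput (driven through feed_dict) is now left untouched instead of being wrapped. A condensed view of the resulting logic:

    if isinstance(input, FeedfreeInput) and \
            not isinstance(input, (StagingInput, DummyConstantInput)):
        input = StagingInput(input)   # GPU prefetch only for feed-free pipelines
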
......