Commit 7cb2606c authored by Yuxin Wu

[MaskRCNN] better scope management

parent 85586fc5
 # -*- coding: utf-8 -*-
 # File: basemodel.py

-from contextlib import contextmanager
+from contextlib import contextmanager, ExitStack
 import tensorflow as tf
+from tensorflow.contrib.framework import add_model_variable

 from tensorpack.tfutils import argscope
 from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
-from tensorpack.tfutils.varreplace import custom_getter_scope
+from tensorpack.tfutils.varreplace import custom_getter_scope, freeze_variables
 from tensorpack.models import (
     Conv2D, MaxPooling, BatchNorm, layer_register)
@@ -40,13 +42,16 @@ def GroupNorm(x, group=32, gamma_initializer=tf.constant_initializer(1.)):
     return tf.reshape(out, orig_shape, name='output')


-def maybe_freeze_affine(getter, *args, **kwargs):
+def freeze_affine_getter(getter, *args, **kwargs):
     # custom getter to freeze affine params inside bn
     name = args[0] if len(args) else kwargs.get('name')
     if name.endswith('/gamma') or name.endswith('/beta'):
-        if cfg.BACKBONE.FREEZE_AFFINE:
-            kwargs['trainable'] = False
-    return getter(*args, **kwargs)
+        kwargs['trainable'] = False
+        ret = getter(*args, **kwargs)
+        add_model_variable(ret)
+    else:
+        ret = getter(*args, **kwargs)
+    return ret


 def maybe_reverse_pad(topleft, bottomright):
@@ -56,26 +61,33 @@ def maybe_reverse_pad(topleft, bottomright):

 @contextmanager
-def backbone_argscope():
+def backbone_scope(freeze):
+    """
+    Args:
+        freeze (bool): whether to freeze all the variables under the scope
+    """
     def nonlin(x):
         x = get_norm()(x)
         return tf.nn.relu(x)

     with argscope([Conv2D, MaxPooling, BatchNorm], data_format='channels_first'), \
-            argscope(Conv2D, use_bias=False, activation=nonlin), \
-            argscope(BatchNorm, training=False), \
-            custom_getter_scope(maybe_freeze_affine):
-        yield
-
-
-@contextmanager
-def maybe_syncbn_scope():
-    if cfg.BACKBONE.NORM == 'SyncBN':
-        assert cfg.BACKBONE.FREEZE_AT == 2  # TODO add better support
-        with argscope(BatchNorm, training=None,
-                      sync_statistics='nccl' if cfg.TRAINER == 'replicated' else 'horovod'):
-            yield
-    else:
+            argscope(Conv2D, use_bias=False, activation=nonlin,
+                     kernel_initializer=tf.variance_scaling_initializer(
+                         scale=2.0, mode='fan_out')), \
+            ExitStack() as stack:
+        if cfg.BACKBONE.NORM in ['FreezeBN', 'SyncBN']:
+            if freeze or cfg.BACKBONE.NORM == 'FreezeBN':
+                stack.enter_context(argscope(BatchNorm, training=False))
+            else:
+                stack.enter_context(argscope(
+                    BatchNorm, sync_statistics='nccl' if cfg.TRAINER == 'replicated' else 'horovod'))
+
+        if freeze:
+            stack.enter_context(freeze_variables(stop_gradient=False, skip_collection=True))
+        else:
+            # the layers are not completely freezed, but we may want to only freeze the affine
+            if cfg.BACKBONE.FREEZE_AFFINE:
+                stack.enter_context(custom_getter_scope(freeze_affine_getter))
         yield
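
Note: the new backbone_scope folds the old backbone_argscope and maybe_syncbn_scope into one
context manager by entering its inner contexts through contextlib.ExitStack. A minimal
standalone sketch of that pattern (the tag/scope helpers here are made up for illustration,
not repo code):

    from contextlib import ExitStack, contextmanager

    @contextmanager
    def tag(name):
        # stand-in for argscope(...) / freeze_variables(...)
        print('enter', name)
        yield
        print('exit', name)

    @contextmanager
    def scope(freeze):
        with ExitStack() as stack:
            stack.enter_context(tag('norm'))        # always entered
            if freeze:
                stack.enter_context(tag('freeze'))  # entered only when the condition holds
            yield                                   # every entered context is popped on exit

    with scope(freeze=True):
        print('build layers here')
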
@@ -147,36 +159,37 @@ def resnet_group(name, l, block_func, features, count, stride):

 def resnet_c4_backbone(image, num_blocks):
     assert len(num_blocks) == 3
-    with backbone_argscope():
+    freeze_at = cfg.BACKBONE.FREEZE_AT
+    with backbone_scope(freeze=freeze_at > 0):
         l = tf.pad(image, [[0, 0], [0, 0], maybe_reverse_pad(2, 3), maybe_reverse_pad(2, 3)])
         l = Conv2D('conv0', l, 64, 7, strides=2, padding='VALID')
         l = tf.pad(l, [[0, 0], [0, 0], maybe_reverse_pad(0, 1), maybe_reverse_pad(0, 1)])
         l = MaxPooling('pool0', l, 3, strides=2, padding='VALID')
+    with backbone_scope(freeze=freeze_at > 1):
         c2 = resnet_group('group0', l, resnet_bottleneck, 64, num_blocks[0], 1)
-        # TODO replace var by const to enable optimization
-        if cfg.BACKBONE.FREEZE_AT == 2:
-            c2 = tf.stop_gradient(c2)
-        with maybe_syncbn_scope():
-            c3 = resnet_group('group1', c2, resnet_bottleneck, 128, num_blocks[1], 2)
-            c4 = resnet_group('group2', c3, resnet_bottleneck, 256, num_blocks[2], 2)
+    with backbone_scope(freeze=False):
+        c3 = resnet_group('group1', c2, resnet_bottleneck, 128, num_blocks[1], 2)
+        c4 = resnet_group('group2', c3, resnet_bottleneck, 256, num_blocks[2], 2)
     # 16x downsampling up to now
     return c4


 @auto_reuse_variable_scope
 def resnet_conv5(image, num_block):
-    with backbone_argscope(), maybe_syncbn_scope():
+    with backbone_scope(freeze=False):
         l = resnet_group('group3', image, resnet_bottleneck, 512, num_block, 2)
         return l


 def resnet_fpn_backbone(image, num_blocks):
+    freeze_at = cfg.BACKBONE.FREEZE_AT
     shape2d = tf.shape(image)[2:]
     mult = float(cfg.FPN.RESOLUTION_REQUIREMENT)
     new_shape2d = tf.to_int32(tf.ceil(tf.to_float(shape2d) / mult) * mult)
     pad_shape2d = new_shape2d - shape2d
     assert len(num_blocks) == 4, num_blocks
-    with backbone_argscope():
+    with backbone_scope(freeze=freeze_at > 0):
         chan = image.shape[1]
         pad_base = maybe_reverse_pad(2, 3)
         l = tf.pad(image, tf.stack(
@@ -187,13 +200,12 @@ def resnet_fpn_backbone(image, num_blocks):
         l = Conv2D('conv0', l, 64, 7, strides=2, padding='VALID')
         l = tf.pad(l, [[0, 0], [0, 0], maybe_reverse_pad(0, 1), maybe_reverse_pad(0, 1)])
         l = MaxPooling('pool0', l, 3, strides=2, padding='VALID')
+    with backbone_scope(freeze=freeze_at > 1):
         c2 = resnet_group('group0', l, resnet_bottleneck, 64, num_blocks[0], 1)
-        if cfg.BACKBONE.FREEZE_AT == 2:
-            c2 = tf.stop_gradient(c2)
-        with maybe_syncbn_scope():
-            c3 = resnet_group('group1', c2, resnet_bottleneck, 128, num_blocks[1], 2)
-            c4 = resnet_group('group2', c3, resnet_bottleneck, 256, num_blocks[2], 2)
-            c5 = resnet_group('group3', c4, resnet_bottleneck, 512, num_blocks[3], 2)
+    with backbone_scope(freeze=False):
+        c3 = resnet_group('group1', c2, resnet_bottleneck, 128, num_blocks[1], 2)
+        c4 = resnet_group('group2', c3, resnet_bottleneck, 256, num_blocks[2], 2)
+        c5 = resnet_group('group3', c4, resnet_bottleneck, 512, num_blocks[3], 2)
     # 32x downsampling up to now
     # size of c5: ceil(input/32)
     return c2, c3, c4, c5
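
Note: with this change FREEZE_AT gains an intermediate level 1. A rough reading of the
backbone builders above (my summary, not text from the commit): 0 freezes nothing, 1 freezes
the conv0/pool0 stem, and 2 additionally freezes group0 (the c2 stage), while group1..group3
always stay trainable. Illustrative only:

    # how the per-stage freeze flags are derived from cfg.BACKBONE.FREEZE_AT above
    stem_frozen = freeze_at > 0    # conv0 / pool0
    group0_frozen = freeze_at > 1  # 'group0' (c2)
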
@@ -72,7 +72,7 @@ _C.BACKBONE.RESNET_NUM_BLOCK = [3, 4, 6, 3]     # for resnet50
 # RESNET_NUM_BLOCK = [3, 4, 23, 3]    # for resnet101
 _C.BACKBONE.FREEZE_AFFINE = False   # do not train affine parameters inside norm layers
 _C.BACKBONE.NORM = 'FreezeBN'  # options: FreezeBN, SyncBN, GN
-_C.BACKBONE.FREEZE_AT = 2  # options: 0, 2
+_C.BACKBONE.FREEZE_AT = 2  # options: 0, 1, 2

 # Use a base model with TF-preferred padding mode,
 # which may pad more pixels on right/bottom than top/left.
@@ -169,7 +169,7 @@ def finalize_configs(is_training):
     assert _C.BACKBONE.NORM in ['FreezeBN', 'SyncBN', 'GN'], _C.BACKBONE.NORM
     if _C.BACKBONE.NORM != 'FreezeBN':
         assert not _C.BACKBONE.FREEZE_AFFINE
-    assert _C.BACKBONE.FREEZE_AT in [0, 2]
+    assert _C.BACKBONE.FREEZE_AT in [0, 1, 2]

     _C.RPN.NUM_ANCHOR = len(_C.RPN.ANCHOR_SIZES) * len(_C.RPN.ANCHOR_RATIOS)
     assert len(_C.FPN.ANCHOR_STRIDES) == len(_C.RPN.ANCHOR_SIZES)
...
@@ -234,8 +234,7 @@ class ResNetC4Model(DetectionModel):
             mrcnn_loss = 0.0

         wd_cost = regularize_cost(
-            '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
-            l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost')
+            '.*/W', l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost')

         total_cost = tf.add_n([
             rpn_label_loss, rpn_box_loss,
@@ -372,8 +371,7 @@ class ResNetFPNModel(DetectionModel):
             mrcnn_loss = 0.0

         wd_cost = regularize_cost(
-            '(?:group1|group2|group3|rpn|fpn|fastrcnn|maskrcnn)/.*W',
-            l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost')
+            '.*/W', l2_regularizer(cfg.TRAIN.WEIGHT_DECAY), name='wd_cost')

         total_cost = tf.add_n([rpn_label_loss, rpn_box_loss,
                                fastrcnn_label_loss, fastrcnn_box_loss,
...
@@ -20,7 +20,7 @@ class ModelSaver(Callback):
     def __init__(self, max_to_keep=10,
                  keep_checkpoint_every_n_hours=0.5,
                  checkpoint_dir=None,
-                 var_collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.MODEL_VARIABLES]):
+                 var_collections=[tf.GraphKeys.GLOBAL_VARIABLES]):
         """
         Args:
             max_to_keep (int): the same as in ``tf.train.Saver``.
...
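
Note: ModelSaver now defaults to saving only GLOBAL_VARIABLES. If you still want the
MODEL_VARIABLES collection included, the previous behavior can presumably be restored by
passing the collections explicitly (a sketch using the parameter shown in the diff above):

    import tensorflow as tf
    from tensorpack.callbacks import ModelSaver

    # restore the old default explicitly
    saver_cb = ModelSaver(
        var_collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.MODEL_VARIABLES])
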
@@ -189,7 +189,8 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
             # because during training, EMA isn't used
             if ctx.is_main_training_tower:
                 for v in layer.non_trainable_variables:
-                    add_model_variable(v)
+                    if isinstance(v, tf.Variable):
+                        add_model_variable(v)
         if not ctx.is_main_training_tower or internal_update:
             restore_collection(coll_bk)
@@ -351,7 +352,8 @@ def BatchRenorm(x, rmax, dmax, momentum=0.9, epsilon=1e-5,
         if ctx.is_main_training_tower:
             for v in layer.non_trainable_variables:
-                add_model_variable(v)
+                if isinstance(v, tf.Variable):
+                    add_model_variable(v)
         else:
             # only run UPDATE_OPS in the first tower
             restore_collection(coll_bk)
...
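
Note: the added isinstance(v, tf.Variable) guard in BatchNorm/BatchRenorm is presumably needed
because, as the freeze_variables docstring below points out, a custom getter with
stop_gradient=True can hand back a Tensor rather than a Variable, and putting a plain Tensor
into the MODEL_VARIABLES collection would confuse downstream consumers such as savers.
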
@@ -3,11 +3,12 @@
 # Credit: Qinyao He

 import tensorflow as tf
+from tensorflow.contrib.framework import add_model_variable
 from contextlib import contextmanager

 from .common import get_tf_version_tuple

-__all__ = ['freeze_variables', 'remap_variables']
+__all__ = ['custom_getter_scope', 'freeze_variables', 'remap_variables']


 @contextmanager
@@ -65,19 +66,25 @@ def freeze_variables(stop_gradient=True, skip_collection=False):
     Args:
         stop_gradient (bool): if True, variables returned from `get_variable`
             will be wrapped with `tf.stop_gradient` and therefore has no
-            gradient when used later. Note that the created variables may
-            still have gradient when accessed by other approaches (e.g.
-            by name, or by collection).
+            gradient when used later.
+            Note that the created variables may still have gradient when accessed
+            by other approaches (e.g. by name, or by collection).
+            Also note that this makes `tf.get_variable` returns a Tensor instead of a Variable,
+            which may break existing code.
+            Therefore, it's recommended to use the `skip_collection` option instead.
         skip_collection (bool): if True, do not add the variable to
-            ``TRAINABLE_VARIABLES`` collection. As a result they will not be
-            trained by default.
+            ``TRAINABLE_VARIABLES`` collection, but to ``MODEL_VARIABLES``
+            collection. As a result they will not be trained by default.
     """
     def custom_getter(getter, *args, **kwargs):
         trainable = kwargs.get('trainable', True)
+        name = args[0] if len(args) else kwargs.get('name')
         if skip_collection:
             kwargs['trainable'] = False
         v = getter(*args, **kwargs)
+        if skip_collection:
+            add_model_variable(v)
         if trainable and stop_gradient:
-            v = tf.stop_gradient(v)
+            v = tf.stop_gradient(v, name='freezed_' + name)
         return v
     return custom_getter_scope(custom_getter)
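
Note: a minimal usage sketch of the skip_collection path that backbone_scope relies on
(the variable name here is made up; this is TF1-style code, as in the rest of the diff):

    import tensorflow as tf
    from tensorpack.tfutils.varreplace import freeze_variables

    with freeze_variables(stop_gradient=False, skip_collection=True):
        # 'frozen_w' stays out of TRAINABLE_VARIABLES and is added to MODEL_VARIABLES,
        # so it is not trained by default but is still tracked for saving/loading
        w = tf.get_variable('frozen_w', shape=[3, 3])
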
@@ -4,7 +4,8 @@
 import tensorflow as tf

 from ..input_source import (
-    InputSource, FeedInput, QueueInput, StagingInput, DummyConstantInput)
+    InputSource, FeedInput, FeedfreeInput,
+    QueueInput, StagingInput, DummyConstantInput)
 from ..utils import logger
 from .config import TrainConfig
@@ -40,7 +41,8 @@ def apply_default_prefetch(input_source_or_dataflow, trainer):
         # seem to only improve on >1 GPUs
         assert not isinstance(trainer, SimpleTrainer)

-        if not isinstance(input, (StagingInput, DummyConstantInput)):
+        if isinstance(input, FeedfreeInput) and \
+                not isinstance(input, (StagingInput, DummyConstantInput)):
             logger.info("Automatically applying StagingInput on the DataFlow.")
             input = StagingInput(input)
     return input
...
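
Note: the extra FeedfreeInput check presumably keeps apply_default_prefetch from wrapping
feed-based inputs (such as FeedInput) in StagingInput, which only makes sense for feedfree
input pipelines.
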