Commit 0086e156 authored by Yuxin Wu

LR warmup in imagenet-resnet

parent 9bfbb0a4
# Faster-RCNN / Mask-RCNN on COCO
-This example aims to provide a minimal (1.3k lines) multi-GPU implementation of
-Faster-RCNN & Mask-RCNN (with ResNet backbones) on COCO.
+This example aims to provide a minimal (1.3k lines) implementation of
+end-to-end Faster-RCNN & Mask-RCNN (with ResNet backbones) on COCO.

## Dependencies
+ Python 3; TensorFlow >= 1.4.0
@@ -62,7 +62,7 @@ MaskRCNN results contain both bbox and segm mAP.
The two 360k models have configurations identical to the
`R50-C4-2x` configuration in the
-[Detectron Model Zoo](https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md#end-to-end-faster--mask-r-cnn-baselines0).
+[Detectron Model Zoo](https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md#end-to-end-faster--mask-r-cnn-baselines).
They get the __same performance__ as the official models, and are about 14% slower than the official implementation, due to the lack of specialized ops.
## Notes
......
@@ -23,8 +23,6 @@ from resnet_model import (
    resnet_group, resnet_basicblock, resnet_bottleneck, se_resnet_bottleneck,
    resnet_backbone)

-TOTAL_BATCH_SIZE = 256

class Model(ImageNetModel):
    def __init__(self, depth, data_format='NCHW', mode='resnet'):
@@ -63,7 +61,7 @@ def get_data(name, batch):

def get_config(model, fake=False):
    nr_tower = max(get_nr_gpu(), 1)
-    batch = TOTAL_BATCH_SIZE // nr_tower
+    batch = args.batch // nr_tower
    if fake:
        logger.info("For benchmark, batch size is fixed to 64 per tower.")
@@ -74,12 +72,19 @@ def get_config(model, fake=False):
    logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))
    dataset_train = get_data('train', batch)
    dataset_val = get_data('val', batch)

+    BASE_LR = 0.1 * (args.batch // 256)
    callbacks = [
        ModelSaver(),
-        ScheduledHyperParamSetter('learning_rate',
-                                  [(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5), (105, 1e-6)]),
-        HumanHyperParamSetter('learning_rate'),
+        ScheduledHyperParamSetter(
+            'learning_rate', [(30, BASE_LR * 1e-1), (60, BASE_LR * 1e-2),
+                              (85, BASE_LR * 1e-3), (95, BASE_LR * 1e-4), (105, BASE_LR * 1e-5)]),
    ]
+    if BASE_LR != 0.1:
+        callbacks.append(
+            ScheduledHyperParamSetter(
+                'learning_rate', [(0, 0.1), (3, BASE_LR)], interp='linear'))
    infs = [ClassificationError('wrong-top1', 'val-error-top1'),
            ClassificationError('wrong-top5', 'val-error-top5')]
    if nr_tower == 1:
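Two things happen in this hunk: `BASE_LR` follows the linear scaling rule (0.1 for every 256 images in the total batch), and whenever `BASE_LR` differs from the default 0.1, an extra setter warms the rate up linearly from 0.1 to `BASE_LR` over the first 3 epochs. A self-contained sketch of the resulting schedule as a pure function (illustrative only; the script itself applies it through tensorpack's `ScheduledHyperParamSetter` callbacks):

```python
# Illustrative pure-function version of the schedule assembled above.
def learning_rate(epoch, total_batch):
    base_lr = 0.1 * (total_batch // 256)  # linear scaling rule
    if base_lr != 0.1 and epoch < 3:
        # warmup: interpolate from 0.1 at epoch 0 to base_lr at epoch 3
        return 0.1 + (base_lr - 0.1) * epoch / 3.0
    lr = base_lr
    for boundary in (30, 60, 85, 95, 105):  # divide by 10 at each boundary
        if epoch >= boundary:
            lr *= 0.1
    return lr

for e in (0, 2, 3, 30, 60):
    print(e, round(learning_rate(e, 512), 4))
# -> 0 0.1, 2 0.1667, 3 0.2, 30 0.02, 60 0.002
```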
@@ -94,7 +99,7 @@ def get_config(model, fake=False):
        model=model,
        dataflow=dataset_train,
        callbacks=callbacks,
-        steps_per_epoch=100 if args.fake else 5000,    # 5000 ~= 1.28M / TOTAL_BATCH_SIZE
+        steps_per_epoch=100 if args.fake else 1280000 // args.batch,
        max_epoch=110,
        nr_tower=nr_tower
    )
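Replacing the hard-coded 5000 steps keeps one epoch equal to one pass over the ~1.28M ImageNet training images for any batch size, so the epoch-based schedule above keeps its meaning:

```python
# One epoch stays one pass over ~1.28M images for any total batch size.
for batch in (256, 512, 1024):
    print(batch, 1280000 // batch)  # 256 -> 5000, 512 -> 2500, 1024 -> 1250
```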
@@ -111,6 +116,8 @@ if __name__ == '__main__':
    parser.add_argument('-d', '--depth', help='resnet depth',
                        type=int, default=18, choices=[18, 34, 50, 101, 152])
    parser.add_argument('--eval', action='store_true')
+    parser.add_argument('--batch', help='total batch size; needs to be a multiple of 256 to get similar accuracy',
+                        default=256, type=int)
    parser.add_argument('--mode', choices=['resnet', 'preact', 'se'],
                        help='variants of resnet to use', default='resnet')
    args = parser.parse_args()
......
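Taken together, a hypothetical run such as `python imagenet-resnet.py -d 50 --batch 512` (script name and any flags beyond those shown in the diff are assumed from context) would train with BASE_LR = 0.2, warm up linearly from 0.1 to 0.2 over the first 3 epochs, and run 2500 steps per epoch.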