Commit 0086e156 authored by Yuxin Wu

LR warmup in imagenet-resnet

parent 9bfbb0a4
# Faster-RCNN / Mask-RCNN on COCO # Faster-RCNN / Mask-RCNN on COCO
This example aims to provide a minimal (1.3k lines) multi-GPU implementation of This example aims to provide a minimal (1.3k lines) implementation of
Faster-RCNN & Mask-RCNN (with ResNet backbones) on COCO. end-to-end Faster-RCNN & Mask-RCNN (with ResNet backbones) on COCO.
## Dependencies ## Dependencies
+ Python 3; TensorFlow >= 1.4.0 + Python 3; TensorFlow >= 1.4.0
...@@ -62,7 +62,7 @@ MaskRCNN results contain both bbox and segm mAP. ...@@ -62,7 +62,7 @@ MaskRCNN results contain both bbox and segm mAP.
The two 360k models have identical configurations with The two 360k models have identical configurations with
`R50-C4-2x` configuration in `R50-C4-2x` configuration in
[Detectron Model Zoo](https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md#end-to-end-faster--mask-r-cnn-baselines0). [Detectron Model Zoo](https://github.com/facebookresearch/Detectron/blob/master/MODEL_ZOO.md#end-to-end-faster--mask-r-cnn-baselines).
They get the __same performance__ as the official models, and are about 14% slower than the official implementation, due to the lack of specialized ops. They get the __same performance__ as the official models, and are about 14% slower than the official implementation, due to the lack of specialized ops.
## Notes ## Notes
......
...@@ -23,8 +23,6 @@ from resnet_model import ( ...@@ -23,8 +23,6 @@ from resnet_model import (
resnet_group, resnet_basicblock, resnet_bottleneck, se_resnet_bottleneck, resnet_group, resnet_basicblock, resnet_bottleneck, se_resnet_bottleneck,
resnet_backbone) resnet_backbone)
TOTAL_BATCH_SIZE = 256
class Model(ImageNetModel): class Model(ImageNetModel):
def __init__(self, depth, data_format='NCHW', mode='resnet'): def __init__(self, depth, data_format='NCHW', mode='resnet'):
...@@ -63,7 +61,7 @@ def get_data(name, batch): ...@@ -63,7 +61,7 @@ def get_data(name, batch):
def get_config(model, fake=False): def get_config(model, fake=False):
nr_tower = max(get_nr_gpu(), 1) nr_tower = max(get_nr_gpu(), 1)
batch = TOTAL_BATCH_SIZE // nr_tower batch = args.batch // nr_tower
if fake: if fake:
logger.info("For benchmark, batch size is fixed to 64 per tower.") logger.info("For benchmark, batch size is fixed to 64 per tower.")
...@@ -74,12 +72,19 @@ def get_config(model, fake=False): ...@@ -74,12 +72,19 @@ def get_config(model, fake=False):
logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch)) logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))
dataset_train = get_data('train', batch) dataset_train = get_data('train', batch)
dataset_val = get_data('val', batch) dataset_val = get_data('val', batch)
BASE_LR = 0.1 * (args.batch // 256)
callbacks = [ callbacks = [
ModelSaver(), ModelSaver(),
ScheduledHyperParamSetter('learning_rate', ScheduledHyperParamSetter(
[(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5), (105, 1e-6)]), 'learning_rate', [(30, BASE_LR * 1e-1), (60, BASE_LR * 1e-2),
HumanHyperParamSetter('learning_rate'), (85, BASE_LR * 1e-3), (95, BASE_LR * 1e-4), (105, BASE_LR * 1e-5)]),
] ]
if BASE_LR != 0.1:
callbacks.append(
ScheduledHyperParamSetter(
'learning_rate', [(0, 0.1), (3, BASE_LR)], interp='linear'))
infs = [ClassificationError('wrong-top1', 'val-error-top1'), infs = [ClassificationError('wrong-top1', 'val-error-top1'),
ClassificationError('wrong-top5', 'val-error-top5')] ClassificationError('wrong-top5', 'val-error-top5')]
if nr_tower == 1: if nr_tower == 1:
...@@ -94,7 +99,7 @@ def get_config(model, fake=False): ...@@ -94,7 +99,7 @@ def get_config(model, fake=False):
model=model, model=model,
dataflow=dataset_train, dataflow=dataset_train,
callbacks=callbacks, callbacks=callbacks,
steps_per_epoch=100 if args.fake else 5000, # 5000 ~= 1.28M / TOTAL_BATCH_SIZE steps_per_epoch=100 if args.fake else 1280000 // args.batch,
max_epoch=110, max_epoch=110,
nr_tower=nr_tower nr_tower=nr_tower
) )
...@@ -111,6 +116,8 @@ if __name__ == '__main__': ...@@ -111,6 +116,8 @@ if __name__ == '__main__':
parser.add_argument('-d', '--depth', help='resnet depth', parser.add_argument('-d', '--depth', help='resnet depth',
type=int, default=18, choices=[18, 34, 50, 101, 152]) type=int, default=18, choices=[18, 34, 50, 101, 152])
parser.add_argument('--eval', action='store_true') parser.add_argument('--eval', action='store_true')
parser.add_argument('--batch', help='total batch size. need to be multiple of 256 to get similar accuracy.',
default=256, type=int)
parser.add_argument('--mode', choices=['resnet', 'preact', 'se'], parser.add_argument('--mode', choices=['resnet', 'preact', 'se'],
help='variants of resnet to use', default='resnet') help='variants of resnet to use', default='resnet')
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment