Commit b2a396e5 authored by Yuxin Wu's avatar Yuxin Wu

[FasterRCNN] configurable LR

parent 57a66ff4
...@@ -12,13 +12,21 @@ BASEDIR = '/path/to/your/COCO/DIR' ...@@ -12,13 +12,21 @@ BASEDIR = '/path/to/your/COCO/DIR'
TRAIN_DATASET = ['train2014', 'valminusminival2014'] TRAIN_DATASET = ['train2014', 'valminusminival2014']
VAL_DATASET = 'minival2014' # only support evaluation on single dataset VAL_DATASET = 'minival2014' # only support evaluation on single dataset
NUM_CLASS = 81 NUM_CLASS = 81
CLASS_NAMES = [] # NUM_CLASS strings CLASS_NAMES = [] # NUM_CLASS strings. Will be populated later by coco loader
# basemodel ---------------------- # basemodel ----------------------
RESNET_NUM_BLOCK = [3, 4, 6, 3] # resnet50 RESNET_NUM_BLOCK = [3, 4, 6, 3] # for resnet50
# RESNET_NUM_BLOCK = [3, 4, 23, 3] # resnet101 # RESNET_NUM_BLOCK = [3, 4, 23, 3] # for resnet101
# preprocessing -------------------- # schedule -----------------------
BASE_LR = 1e-2
WARMUP = 500
STEPS_PER_EPOCH = 500
LR_SCHEDULE = [150000, 230000, 280000]
# LR_SCHEDULE = [120000, 160000, 180000] # "1x" schedule in detectron
# LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectron
# image resolution --------------------
SHORT_EDGE_SIZE = 800 SHORT_EDGE_SIZE = 800
MAX_SIZE = 1333 MAX_SIZE = 1333
# alternative (worse & faster) setting: 600, 1024 # alternative (worse & faster) setting: 600, 1024
......
...@@ -325,7 +325,7 @@ if __name__ == '__main__': ...@@ -325,7 +325,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
parser.add_argument('--logdir', help='logdir', default='train_log/fastrcnn') parser.add_argument('--logdir', help='logdir', default='train_log/maskrcnn')
parser.add_argument('--datadir', help='override config.BASEDIR') parser.add_argument('--datadir', help='override config.BASEDIR')
parser.add_argument('--visualize', action='store_true') parser.add_argument('--visualize', action='store_true')
parser.add_argument('--evaluate', help='path to the output json eval file') parser.add_argument('--evaluate', help='path to the output json eval file')
...@@ -360,30 +360,31 @@ if __name__ == '__main__': ...@@ -360,30 +360,31 @@ if __name__ == '__main__':
else: else:
logger.set_logger_dir(args.logdir) logger.set_logger_dir(args.logdir)
print_config() print_config()
stepnum = 500
warmup_epoch = 3
factor = get_batch_factor() factor = get_batch_factor()
stepnum = config.STEPS_PER_EPOCH
warmup_epoch = max(1, config.WARMUP / stepnum)
warmup_schedule = [(0, config.BASE_LR / 3), (warmup_epoch * factor, config.BASE_LR)]
lr_schedule = [warmup_schedule[-1]]
for idx, steps in enumerate(config.LR_SCHEDULE[:-1]):
mult = 0.1 ** (idx + 1)
lr_schedule.append(
(steps * factor // stepnum, config.BASE_LR * mult))
cfg = TrainConfig( cfg = TrainConfig(
model=Model(), model=Model(),
data=QueueInput(get_train_dataflow(add_mask=config.MODE_MASK)), data=QueueInput(get_train_dataflow(add_mask=config.MODE_MASK)),
callbacks=[ callbacks=[
ModelSaver(max_to_keep=10, keep_checkpoint_every_n_hours=1), ModelSaver(max_to_keep=10, keep_checkpoint_every_n_hours=1),
# linear warmup # linear warmup # TODO step-wise linear warmup
ScheduledHyperParamSetter(
'learning_rate',
[(0, 3e-3), (warmup_epoch * factor, 1e-2)], interp='linear'),
# step decay
ScheduledHyperParamSetter( ScheduledHyperParamSetter(
'learning_rate', 'learning_rate', warmup_schedule, interp='linear'),
[(warmup_epoch * factor, 1e-2), ScheduledHyperParamSetter('learning_rate', lr_schedule),
(150000 * factor // stepnum, 1e-3),
(230000 * factor // stepnum, 1e-4)]),
EvalCallback(), EvalCallback(),
GPUUtilizationTracker(), GPUUtilizationTracker(),
], ],
steps_per_epoch=stepnum, steps_per_epoch=stepnum,
max_epoch=280000 * factor // stepnum, max_epoch=config.LR_SCHEDULE[2] * factor // stepnum,
session_init=get_model_loader(args.load) if args.load else None, session_init=get_model_loader(args.load) if args.load else None,
) )
trainer = SyncMultiGPUTrainerReplicated(get_nr_gpu()) trainer = SyncMultiGPUTrainerReplicated(get_nr_gpu())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment