Commit ebd7332d authored by Yuxin Wu's avatar Yuxin Wu

[ResNet] Correct learning rate for batch<256

parent 49ab85e8
...@@ -77,8 +77,9 @@ def get_config(model, fake=False): ...@@ -77,8 +77,9 @@ def get_config(model, fake=False):
ModelSaver(), ModelSaver(),
EstimatedTimeLeft(), EstimatedTimeLeft(),
ScheduledHyperParamSetter( ScheduledHyperParamSetter(
'learning_rate', [(30, BASE_LR * 1e-1), (60, BASE_LR * 1e-2), 'learning_rate', [
(90, BASE_LR * 1e-3), (100, BASE_LR * 1e-4)]), (0, min(START_LR, BASE_LR)), (30, BASE_LR * 1e-1), (60, BASE_LR * 1e-2),
(90, BASE_LR * 1e-3), (100, BASE_LR * 1e-4)]),
] ]
if BASE_LR > START_LR: if BASE_LR > START_LR:
callbacks.append( callbacks.append(
...@@ -138,7 +139,7 @@ if __name__ == '__main__': ...@@ -138,7 +139,7 @@ if __name__ == '__main__':
logger.set_logger_dir(os.path.join('train_log', 'tmp'), 'd') logger.set_logger_dir(os.path.join('train_log', 'tmp'), 'd')
else: else:
logger.set_logger_dir( logger.set_logger_dir(
os.path.join('train_log', 'imagenet-{}-d{}'.format(args.mode, args.depth))) os.path.join('train_log', 'imagenet-{}-d{}-batch{}'.format(args.mode, args.depth, args.batch)))
config = get_config(model, fake=args.fake) config = get_config(model, fake=args.fake)
if args.load: if args.load:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment