Commit 12223dfb authored by Yuxin Wu's avatar Yuxin Wu

fix #796

parent 820bcac1
......@@ -88,7 +88,7 @@ def get_data(name, batch):
def get_config():
nr_tower = max(get_nr_gpu(), 1)
nr_tower = max(get_num_gpu(), 1)
batch = args.batch
total_batch = batch * nr_tower
if total_batch != 128:
......@@ -107,7 +107,7 @@ def get_config():
EstimatedTimeLeft(),
ScheduledHyperParamSetter(
'learning_rate',
[(30, BASE_LR * 1e-1), (60, BASE_LR * 1e-2), (80, BASE_LR * 1e-3)]),
[(0, BASE_LR), (30, BASE_LR * 1e-1), (60, BASE_LR * 1e-2), (80, BASE_LR * 1e-3)]),
DataParallelInferenceRunner(
dataset_val, infs, list(range(nr_tower))),
]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment