Commit ceb004a1 authored by Yuxin Wu's avatar Yuxin Wu

fix resnet on CPU

parent f5d5d4c2
...@@ -84,9 +84,9 @@ def get_data(train_or_test): ...@@ -84,9 +84,9 @@ def get_data(train_or_test):
def get_config(fake=False, data_format='NCHW'): def get_config(fake=False, data_format='NCHW'):
nr_gpu = get_nr_gpu() nr_tower = max(get_nr_gpu(), 1)
BATCH_SIZE = TOTAL_BATCH_SIZE // nr_gpu BATCH_SIZE = TOTAL_BATCH_SIZE // nr_tower
logger.info("Running on {} GPUs. Batch size per GPU: {}".format(nr_gpu, BATCH_SIZE)) logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, BATCH_SIZE))
if fake: if fake:
dataset_train = dataset_val = FakeData( dataset_train = dataset_val = FakeData(
...@@ -109,7 +109,7 @@ def get_config(fake=False, data_format='NCHW'): ...@@ -109,7 +109,7 @@ def get_config(fake=False, data_format='NCHW'):
], ],
steps_per_epoch=5000, steps_per_epoch=5000,
max_epoch=110, max_epoch=110,
nr_tower=nr_gpu nr_tower=nr_tower
) )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment