Commit ceb004a1 authored by Yuxin Wu's avatar Yuxin Wu

fix resnet on CPU

parent f5d5d4c2
......@@ -84,9 +84,9 @@ def get_data(train_or_test):
def get_config(fake=False, data_format='NCHW'):
nr_gpu = get_nr_gpu()
BATCH_SIZE = TOTAL_BATCH_SIZE // nr_gpu
logger.info("Running on {} GPUs. Batch size per GPU: {}".format(nr_gpu, BATCH_SIZE))
nr_tower = max(get_nr_gpu(), 1)
BATCH_SIZE = TOTAL_BATCH_SIZE // nr_tower
logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, BATCH_SIZE))
if fake:
dataset_train = dataset_val = FakeData(
......@@ -109,7 +109,7 @@ def get_config(fake=False, data_format='NCHW'):
],
steps_per_epoch=5000,
max_epoch=110,
nr_tower=nr_gpu
nr_tower=nr_tower
)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment