Commit 844d8e69 authored by Yuxin Wu's avatar Yuxin Wu

fix resnet

parent 2a1af832
......@@ -82,6 +82,7 @@ def get_data(train_or_test):
def get_config():
assert tf.test.is_gpu_available()
nr_gpu = get_nr_gpu()
global BATCH_SIZE
BATCH_SIZE = TOTAL_BATCH_SIZE // nr_gpu
logger.info("Running on {} GPUs. Batch size per GPU: {}".format(nr_gpu, BATCH_SIZE))
......
......@@ -85,6 +85,7 @@ def get_data(train_or_test):
def get_config(fake=False, data_format='NCHW'):
nr_tower = max(get_nr_gpu(), 1)
global BATCH_SIZE
BATCH_SIZE = TOTAL_BATCH_SIZE // nr_tower
logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, BATCH_SIZE))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment