Commit 3c52b7c1 authored by Yuxin Wu's avatar Yuxin Wu

use environment variables when args.gpu is not set

parent dadd971c
......@@ -311,13 +311,12 @@ if __name__ == '__main__':
sys.exit()
assert args.gpu is not None, "Need to specify a list of gpu for training!"
NR_GPU = len(args.gpu.split(','))
BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU
nr_tower = max(get_nr_gpu(), 1)
BATCH_SIZE = TOTAL_BATCH_SIZE // nr_tower
logger.info("Batch per tower: {}".format(BATCH_SIZE))
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
config.nr_tower = nr_tower
SyncMultiGPUTrainer(config).train()
......@@ -190,6 +190,5 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
config.nr_tower = max(get_nr_gpu(), 1)
QueueInputTrainer(config).train()
......@@ -226,6 +226,5 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = get_model_loader(args.load)
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
config.nr_tower = max(get_nr_gpu(), 1)
SyncMultiGPUTrainer(config).train()
......@@ -175,6 +175,5 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
config.nr_tower = max(get_nr_gpu(), 1)
SyncMultiGPUTrainer(config).train()
......@@ -93,6 +93,5 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
config.nr_tower = max(get_nr_gpu(), 1)
SyncMultiGPUTrainer(config).train()
......@@ -256,12 +256,12 @@ if __name__ == '__main__':
viz_cam(args.load, args.data)
sys.exit()
NR_GPU = len(args.gpu.split(','))
BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU
nr_gpu = get_nr_gpu()
BATCH_SIZE = TOTAL_BATCH_SIZE // nr_gpu
logger.auto_set_dir()
config = get_config()
if args.load:
config.session_init = get_model_loader(args.load)
config.nr_tower = NR_GPU
config.nr_tower = nr_gpu
SyncMultiGPUTrainer(config).train()
......@@ -63,16 +63,17 @@ def get_config():
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.', required=True)
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model')
args = parser.parse_args()
NR_GPU = len(args.gpu.split(','))
with change_gpu(args.gpu):
config = get_config()
config.nr_tower = NR_GPU
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.load:
config.session_init = SaverRestore(args.load)
config = get_config()
config.nr_tower = get_nr_gpu()
SyncMultiGPUTrainer(config).train()
if args.load:
config.session_init = SaverRestore(args.load)
SyncMultiGPUTrainer(config).train()
......@@ -150,10 +150,8 @@ if __name__ == '__main__':
if args.load:
config.session_init = SaverRestore(args.load)
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
nr_gpu = get_nr_gpu()
if nr_gpu <= 1:
config.nr_tower = get_nr_gpu()
if config.nr_tower <= 1:
QueueInputTrainer(config).train()
else:
SyncMultiGPUTrainer(config).train()
......@@ -24,12 +24,12 @@ def change_gpu(val):
def get_nr_gpu():
"""
Returns:
int: the number of GPU from ``CUDA_VISIBLE_DEVICES``.
int: #available GPUs in CUDA_VISIBLE_DEVICES, or in the system.
"""
env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
if env is not None:
return len(env.split(','))
logger.info("Loading local devices by TensorFlow ...")
logger.info("Loading devices by TensorFlow ...")
from tensorflow.python.client import device_lib
device_protos = device_lib.list_local_devices()
gpus = [x.name for x in device_protos if x.device_type == 'GPU']
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment