Commit 3c52b7c1 authored by Yuxin Wu's avatar Yuxin Wu

use environment variables when args.gpu is not set

parent dadd971c
...@@ -311,13 +311,12 @@ if __name__ == '__main__': ...@@ -311,13 +311,12 @@ if __name__ == '__main__':
sys.exit() sys.exit()
assert args.gpu is not None, "Need to specify a list of gpu for training!" assert args.gpu is not None, "Need to specify a list of gpu for training!"
NR_GPU = len(args.gpu.split(',')) nr_tower = max(get_nr_gpu(), 1)
BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU BATCH_SIZE = TOTAL_BATCH_SIZE // nr_tower
logger.info("Batch per tower: {}".format(BATCH_SIZE)) logger.info("Batch per tower: {}".format(BATCH_SIZE))
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
if args.gpu: config.nr_tower = nr_tower
config.nr_tower = len(args.gpu.split(','))
SyncMultiGPUTrainer(config).train() SyncMultiGPUTrainer(config).train()
...@@ -190,6 +190,5 @@ if __name__ == '__main__': ...@@ -190,6 +190,5 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
if args.gpu: config.nr_tower = max(get_nr_gpu(), 1)
config.nr_tower = len(args.gpu.split(','))
QueueInputTrainer(config).train() QueueInputTrainer(config).train()
...@@ -226,6 +226,5 @@ if __name__ == '__main__': ...@@ -226,6 +226,5 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = get_model_loader(args.load) config.session_init = get_model_loader(args.load)
if args.gpu: config.nr_tower = max(get_nr_gpu(), 1)
config.nr_tower = len(args.gpu.split(','))
SyncMultiGPUTrainer(config).train() SyncMultiGPUTrainer(config).train()
...@@ -175,6 +175,5 @@ if __name__ == '__main__': ...@@ -175,6 +175,5 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
if args.gpu: config.nr_tower = max(get_nr_gpu(), 1)
config.nr_tower = len(args.gpu.split(','))
SyncMultiGPUTrainer(config).train() SyncMultiGPUTrainer(config).train()
...@@ -93,6 +93,5 @@ if __name__ == '__main__': ...@@ -93,6 +93,5 @@ if __name__ == '__main__':
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
if args.gpu: config.nr_tower = max(get_nr_gpu(), 1)
config.nr_tower = len(args.gpu.split(','))
SyncMultiGPUTrainer(config).train() SyncMultiGPUTrainer(config).train()
...@@ -256,12 +256,12 @@ if __name__ == '__main__': ...@@ -256,12 +256,12 @@ if __name__ == '__main__':
viz_cam(args.load, args.data) viz_cam(args.load, args.data)
sys.exit() sys.exit()
NR_GPU = len(args.gpu.split(',')) nr_gpu = get_nr_gpu()
BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU BATCH_SIZE = TOTAL_BATCH_SIZE // nr_gpu
logger.auto_set_dir() logger.auto_set_dir()
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = get_model_loader(args.load) config.session_init = get_model_loader(args.load)
config.nr_tower = NR_GPU config.nr_tower = nr_gpu
SyncMultiGPUTrainer(config).train() SyncMultiGPUTrainer(config).train()
...@@ -63,14 +63,15 @@ def get_config(): ...@@ -63,14 +63,15 @@ def get_config():
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.', required=True) parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
args = parser.parse_args() args = parser.parse_args()
NR_GPU = len(args.gpu.split(',')) if args.gpu:
with change_gpu(args.gpu): os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
config = get_config() config = get_config()
config.nr_tower = NR_GPU config.nr_tower = get_nr_gpu()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
......
...@@ -150,10 +150,8 @@ if __name__ == '__main__': ...@@ -150,10 +150,8 @@ if __name__ == '__main__':
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
if args.gpu: config.nr_tower = get_nr_gpu()
config.nr_tower = len(args.gpu.split(',')) if config.nr_tower <= 1:
nr_gpu = get_nr_gpu()
if nr_gpu <= 1:
QueueInputTrainer(config).train() QueueInputTrainer(config).train()
else: else:
SyncMultiGPUTrainer(config).train() SyncMultiGPUTrainer(config).train()
...@@ -24,12 +24,12 @@ def change_gpu(val): ...@@ -24,12 +24,12 @@ def change_gpu(val):
def get_nr_gpu(): def get_nr_gpu():
""" """
Returns: Returns:
int: the number of GPU from ``CUDA_VISIBLE_DEVICES``. int: #available GPUs in CUDA_VISIBLE_DEVICES, or in the system.
""" """
env = os.environ.get('CUDA_VISIBLE_DEVICES', None) env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
if env is not None: if env is not None:
return len(env.split(',')) return len(env.split(','))
logger.info("Loading local devices by TensorFlow ...") logger.info("Loading devices by TensorFlow ...")
from tensorflow.python.client import device_lib from tensorflow.python.client import device_lib
device_protos = device_lib.list_local_devices() device_protos = device_lib.list_local_devices()
gpus = [x.name for x in device_protos if x.device_type == 'GPU'] gpus = [x.name for x in device_protos if x.device_type == 'GPU']
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment