Commit a4ae0f2d authored by Yuxin Wu's avatar Yuxin Wu

hide gpu.* from automatic import

parent ab0048cc
......@@ -24,6 +24,8 @@ from tensorpack.utils.serialize import *
from tensorpack.utils.stats import *
from tensorpack.tfutils import symbolic_functions as symbf
from tensorpack.tfutils.gradproc import MapGradient, SummaryGradient
from tensorpack.utils.gpu import get_nr_gpu
from tensorpack.RL import *
from simulator import *
......
......@@ -16,6 +16,8 @@ from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.tfutils.varreplace import remap_variables
from tensorpack.dataflow import dataset
from tensorpack.utils.gpu import get_nr_gpu
from dorefa import get_dorefa
"""
......
......@@ -191,5 +191,4 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.nr_tower = max(get_nr_gpu(), 1)
QueueInputTrainer(config).train()
......@@ -12,6 +12,7 @@ import multiprocessing
from tensorpack import *
from tensorpack.utils import logger
from tensorpack.utils.gpu import get_nr_gpu
from tensorpack.utils.viz import *
from tensorpack.utils.argtools import shape2d, shape4d
from tensorpack.dataflow import dataset
......
......@@ -8,6 +8,7 @@ import argparse
from tensorpack import *
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.gpu import get_nr_gpu
from tensorpack.utils.globvars import globalns as G
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
import tensorflow as tf
......
......@@ -14,6 +14,7 @@ import sys
from tensorpack import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.dataflow import dataset
from tensorpack.utils.gpu import get_nr_gpu
from tensorpack.tfutils import optimizer
from tensorpack.tfutils.summary import *
......
......@@ -10,6 +10,7 @@ import os
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.utils.gpu import get_nr_gpu
from tensorpack.dataflow import dataset
import tensorflow as tf
......
......@@ -14,6 +14,7 @@ from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.utils.gpu import get_nr_gpu
from imagenet_resnet_utils import (
fbresnet_augmentor, apply_preactivation, resnet_shortcut, resnet_backbone,
......
......@@ -14,6 +14,7 @@ from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import dataset
from tensorpack.utils.gpu import get_nr_gpu
from imagenet_resnet_utils import (
fbresnet_augmentor, resnet_basicblock, resnet_bottleneck, resnet_backbone,
......
......@@ -11,6 +11,7 @@ from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import dataset
from tensorpack.utils.gpu import get_nr_gpu
import tensorflow as tf
"""
......
......@@ -16,6 +16,7 @@ from tensorpack.dataflow import dataset
from tensorpack.tfutils import optimizer
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.utils.gpu import get_nr_gpu
from tensorpack.utils import viz
from imagenet_resnet_utils import (
......
......@@ -6,13 +6,16 @@
import numpy as np
import os
from tensorpack import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary
import argparse
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorpack import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.gpu import change_gpu
from embedding_data import get_test_data, MnistPairs, MnistTriplets
MATPLOTLIB_AVAIBLABLE = False
......@@ -211,7 +214,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model')
parser.add_argument('-a', '--algorithm', help='used algorithm', type=str,
parser.add_argument('-a', '--algorithm', help='used algorithm', required=True,
choices=["siamese", "cosine", "triplet", "softtriplet"])
parser.add_argument('--visualize', help='export embeddings into an image', action='store_true')
args = parser.parse_args()
......
......@@ -5,6 +5,7 @@
import os
import argparse
from tensorpack import *
from tensorpack.utils.gpu import get_nr_gpu
import tensorflow as tf
"""
......@@ -71,8 +72,9 @@ if __name__ == '__main__':
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
config = get_config()
config.nr_tower = get_nr_gpu()
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
if args.load:
config.session_init = SaverRestore(args.load)
......
......@@ -11,7 +11,6 @@ import os
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import dataset
from tensorpack.utils.gpu import get_nr_gpu
"""
A small convnet model for Cifar10 or Cifar100 dataset.
......@@ -151,7 +150,7 @@ if __name__ == '__main__':
if args.load:
config.session_init = SaverRestore(args.load)
config.nr_tower = max(get_nr_gpu(), 1)
config.nr_tower = max(len(args.gpu.split(',')), 1)
if config.nr_tower <= 1:
QueueInputTrainer(config).train()
else:
......
......@@ -132,7 +132,7 @@ if __name__ == '__main__':
if args.load:
config.session_init = SaverRestore(args.load)
if args.gpu:
config.nr_tower = get_nr_gpu()
config.nr_tower = len(args.gpu.split(','))
if config.nr_tower > 1:
SyncMultiGPUTrainer(config).train()
else:
......
......@@ -24,9 +24,26 @@ def _global_import(name):
_TO_IMPORT = set([
'naming',
'utils',
'gpu' # TODO don't export it
])
# this two functions for back-compat only
def get_nr_gpu():
from .gpu import get_nr_gpu
logger.warn( # noqa
"get_nr_gpu will not be automatically imported any more! "
"Please do `from tensorpack.utils.gpu import get_nr_gpu`")
return get_nr_gpu()
def change_gpu(val):
from .gpu import change_gpu as cg
logger.warn( # noqa
"change_gpu will not be automatically imported any more! "
"Please do `from tensorpack.utils.gpu import change_gpu`")
return cg(val)
_CURR_DIR = os.path.dirname(__file__)
for _, module_name, _ in iter_modules(
[_CURR_DIR]):
......@@ -37,4 +54,6 @@ for _, module_name, _ in iter_modules(
continue
if module_name in _TO_IMPORT:
_global_import(module_name)
__all__.extend(['logger'])
__all__.extend([
'logger',
'get_nr_gpu', 'change_gpu'])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment