Commit 57af140a authored by Yuxin Wu's avatar Yuxin Wu

change the default dataset location because tensorpack may be installed to system

parent 1fa7ce9a
...@@ -11,7 +11,8 @@ from collections import deque ...@@ -11,7 +11,8 @@ from collections import deque
import threading import threading
import six import six
from six.moves import range from six.moves import range
from tensorpack.utils import (get_rng, logger, get_dataset_path, execute_only_once) from tensorpack.utils import (get_rng, logger, execute_only_once)
from tensorpack.utils.fs import get_dataset_path
from tensorpack.utils.stats import StatCounter from tensorpack.utils.stats import StatCounter
from tensorpack.RL.envbase import RLEnvironment, DiscreteActionSpace from tensorpack.RL.envbase import RLEnvironment, DiscreteActionSpace
......
...@@ -10,8 +10,8 @@ import argparse ...@@ -10,8 +10,8 @@ import argparse
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.gradproc import * from tensorpack.tfutils.gradproc import *
from tensorpack.utils import logger, get_dataset_path from tensorpack.utils import logger
from tensorpack.utils.fs import download from tensorpack.utils.fs import download, get_dataset_path
from tensorpack.utils.argtools import memoized_ignoreargs from tensorpack.utils.argtools import memoized_ignoreargs
import reader as tfreader import reader as tfreader
......
...@@ -8,8 +8,7 @@ import glob ...@@ -8,8 +8,7 @@ import glob
import cv2 import cv2
import numpy as np import numpy as np
from ...utils import get_dataset_path from ...utils.fs import download, get_dataset_path
from ...utils.fs import download
from ..base import RNGDataFlow from ..base import RNGDataFlow
__all__ = ['BSDS500'] __all__ = ['BSDS500']
......
...@@ -11,8 +11,8 @@ import six ...@@ -11,8 +11,8 @@ import six
from six.moves import range from six.moves import range
import copy import copy
from ...utils import logger, get_dataset_path from ...utils import logger
from ...utils.fs import download from ...utils.fs import download, get_dataset_path
from ..base import RNGDataFlow from ..base import RNGDataFlow
__all__ = ['Cifar10', 'Cifar100'] __all__ = ['Cifar10', 'Cifar100']
......
...@@ -9,9 +9,9 @@ import six ...@@ -9,9 +9,9 @@ import six
import numpy as np import numpy as np
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from ...utils import logger, get_dataset_path from ...utils import logger
from ...utils.loadcaffe import get_caffe_pb from ...utils.loadcaffe import get_caffe_pb
from ...utils.fs import mkdir_p, download from ...utils.fs import mkdir_p, download, get_dataset_path
from ...utils.timer import timed_operation from ...utils.timer import timed_operation
from ..base import RNGDataFlow from ..base import RNGDataFlow
......
...@@ -8,8 +8,8 @@ import gzip ...@@ -8,8 +8,8 @@ import gzip
import numpy import numpy
from six.moves import range from six.moves import range
from ...utils import logger, get_dataset_path from ...utils import logger
from ...utils.fs import download from ...utils.fs import download, get_dataset_path
from ..base import RNGDataFlow from ..base import RNGDataFlow
__all__ = ['Mnist'] __all__ = ['Mnist']
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
import os import os
import numpy as np import numpy as np
from ...utils import logger, get_dataset_path from ...utils import logger
from ...utils.fs import get_dataset_path
from ..base import RNGDataFlow from ..base import RNGDataFlow
__all__ = ['SVHNDigit'] __all__ = ['SVHNDigit']
......
...@@ -24,7 +24,7 @@ def _global_import(name): ...@@ -24,7 +24,7 @@ def _global_import(name):
_TO_IMPORT = set([ _TO_IMPORT = set([
'naming', 'naming',
'utils', 'utils',
'gpu' 'gpu' # TODO don't export it
]) ])
_CURR_DIR = os.path.dirname(__file__) _CURR_DIR = os.path.dirname(__file__)
......
...@@ -8,8 +8,9 @@ import sys ...@@ -8,8 +8,9 @@ import sys
from six.moves import urllib from six.moves import urllib
import errno import errno
from . import logger from . import logger
from .utils import execute_only_once
__all__ = ['mkdir_p', 'download', 'recursive_walk'] __all__ = ['mkdir_p', 'download', 'recursive_walk', 'get_dataset_path']
def mkdir_p(dirname): def mkdir_p(dirname):
...@@ -67,5 +68,37 @@ def recursive_walk(rootdir): ...@@ -67,5 +68,37 @@ def recursive_walk(rootdir):
yield os.path.join(r, f) yield os.path.join(r, f)
def get_dataset_path(*args):
    """
    Get the path to some dataset under ``$TENSORPACK_DATASET``.

    If the environment variable is not set, defaults to ``~/tensorpack_data``
    (created on demand) and warns once per process.

    Args:
        args: strings to be joined to form path.

    Returns:
        str: path to the dataset.
    """
    d = os.environ.get('TENSORPACK_DATASET', None)
    if d is None:
        # Location used before the default changed; only checked to detect
        # data left behind by older installs.
        old_d = os.path.abspath(os.path.join(
            os.path.dirname(__file__), '..', 'dataflow', 'dataset'))
        # BUG FIX: the original joined against ``d``, which is None on this
        # branch, raising TypeError whenever $TENSORPACK_DATASET is unset.
        # Join against the old default directory instead.
        old_d_ret = os.path.join(old_d, *args)
        new_d = os.path.expanduser('~/tensorpack_data')
        if os.path.isdir(old_d_ret):
            # there is an old dir containing data, use it for back-compat
            # (implicit string concatenation instead of '\' continuations,
            # which leaked source indentation into the logged message)
            logger.warn("You seem to have old data at {}. This is no longer "
                        "the default location. You'll need to move it to {} "
                        "(the new default location) or another directory set by "
                        "$TENSORPACK_DATASET.".format(old_d, new_d))
        d = new_d
        if execute_only_once():
            logger.warn("$TENSORPACK_DATASET not set, using {} for dataset.".format(d))
    # Create the directory (whether defaulted or user-set) so the assert holds.
    if not os.path.isdir(d):
        mkdir_p(d)
        logger.info("Created the directory {}.".format(d))
    assert os.path.isdir(d), d
    return os.path.join(d, *args)
if __name__ == '__main__': if __name__ == '__main__':
download('http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz', '.') download('http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz', '.')
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
import os import os
from .utils import change_env from .utils import change_env
__all__ = ['change_gpu', 'get_nr_gpu', 'get_gpus'] __all__ = ['change_gpu', 'get_nr_gpu']
def change_gpu(val): def change_gpu(val):
...@@ -28,13 +28,3 @@ def get_nr_gpu(): ...@@ -28,13 +28,3 @@ def get_nr_gpu():
env = os.environ.get('CUDA_VISIBLE_DEVICES', None) env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
assert env is not None, 'gpu not set!' # TODO assert env is not None, 'gpu not set!' # TODO
return len(env.split(',')) return len(env.split(','))
def get_gpus():
    """
    Returns:
        list: a list of int of GPU id, parsed from ``$CUDA_VISIBLE_DEVICES``.

    Raises:
        AssertionError: if ``$CUDA_VISIBLE_DEVICES`` is not set.
    """
    env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    assert env is not None, 'gpu not set!'  # TODO
    # BUG FIX: wrap in list() -- on Python 3, map() returns a lazy iterator,
    # contradicting the documented return type (and iterating it consumes it).
    return list(map(int, env.strip().split(',')))
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
import numpy as np import numpy as np
import os import os
from .utils import change_env, get_dataset_path from .utils import change_env
from .fs import download from .fs import download, get_dataset_path
from . import logger from . import logger
__all__ = ['load_caffe', 'get_caffe_pb'] __all__ = ['load_caffe', 'get_caffe_pb']
......
...@@ -10,9 +10,9 @@ from datetime import datetime ...@@ -10,9 +10,9 @@ from datetime import datetime
from tqdm import tqdm from tqdm import tqdm
import numpy as np import numpy as np
__all__ = ['change_env', __all__ = ['change_env',
'get_rng', 'get_rng',
'get_dataset_path',
'get_tqdm_kwargs', 'get_tqdm_kwargs',
'get_tqdm', 'get_tqdm',
'execute_only_once', 'execute_only_once',
...@@ -79,27 +79,6 @@ def execute_only_once(): ...@@ -79,27 +79,6 @@ def execute_only_once():
return True return True
def get_dataset_path(*args):
    """
    Get the path to some dataset under ``$TENSORPACK_DATASET``.

    Falls back to the bundled ``dataflow/dataset`` directory when the
    environment variable is not set, warning once per process.

    Args:
        args: strings to be joined to form path.

    Returns:
        str: path to the dataset.
    """
    root = os.environ.get('TENSORPACK_DATASET', None)
    if root is None:
        # Default to <package root>/dataflow/dataset relative to this file.
        here = os.path.dirname(__file__)
        root = os.path.abspath(
            os.path.join(here, '..', 'dataflow', 'dataset'))
        if execute_only_once():
            from . import logger
            logger.warn("TENSORPACK_DATASET not set, using {} for dataset.".format(root))
    assert os.path.isdir(root), root
    return os.path.join(root, *args)
def get_tqdm_kwargs(**kwargs): def get_tqdm_kwargs(**kwargs):
""" """
Return default arguments to be used with tqdm. Return default arguments to be used with tqdm.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment