Commit 6728b686 authored by Yuxin Wu's avatar Yuxin Wu

get_dataset_path instead of dir

parent 807296b3
...@@ -10,7 +10,7 @@ from collections import deque ...@@ -10,7 +10,7 @@ from collections import deque
import threading import threading
import six import six
from six.moves import range from six.moves import range
from ..utils import get_rng, logger, memoized, get_dataset_dir from ..utils import get_rng, logger, memoized, get_dataset_path
from ..utils.stat import StatCounter from ..utils.stat import StatCounter
from .envbase import RLEnvironment, DiscreteActionSpace from .envbase import RLEnvironment, DiscreteActionSpace
...@@ -51,7 +51,7 @@ class AtariPlayer(RLEnvironment): ...@@ -51,7 +51,7 @@ class AtariPlayer(RLEnvironment):
""" """
super(AtariPlayer, self).__init__() super(AtariPlayer, self).__init__()
if not os.path.isfile(rom_file) and '/' not in rom_file: if not os.path.isfile(rom_file) and '/' not in rom_file:
rom_file = get_dataset_dir('atari_rom', rom_file) rom_file = get_dataset_path('atari_rom', rom_file)
assert os.path.isfile(rom_file), \ assert os.path.isfile(rom_file), \
"rom {} not found. Please download at {}".format(rom_file, ROM_URL) "rom {} not found. Please download at {}".format(rom_file, ROM_URL)
......
...@@ -7,7 +7,7 @@ import os, glob ...@@ -7,7 +7,7 @@ import os, glob
import cv2 import cv2
import numpy as np import numpy as np
from ...utils import logger, get_rng, get_dataset_dir from ...utils import logger, get_rng, get_dataset_path
from ...utils.fs import download from ...utils.fs import download
from ..base import RNGDataFlow from ..base import RNGDataFlow
...@@ -40,7 +40,7 @@ class BSDS500(RNGDataFlow): ...@@ -40,7 +40,7 @@ class BSDS500(RNGDataFlow):
""" """
# check and download data # check and download data
if data_dir is None: if data_dir is None:
data_dir = get_dataset_dir('bsds500_data') data_dir = get_dataset_path('bsds500_data')
if not os.path.isdir(os.path.join(data_dir, 'BSR')): if not os.path.isdir(os.path.join(data_dir, 'BSR')):
download(DATA_URL, data_dir) download(DATA_URL, data_dir)
filename = DATA_URL.split('/')[-1] filename = DATA_URL.split('/')[-1]
......
...@@ -13,7 +13,7 @@ from six.moves import urllib, range ...@@ -13,7 +13,7 @@ from six.moves import urllib, range
import copy import copy
import logging import logging
from ...utils import logger, get_rng, get_dataset_dir from ...utils import logger, get_rng, get_dataset_path
from ...utils.fs import download from ...utils.fs import download
from ..base import RNGDataFlow from ..base import RNGDataFlow
...@@ -92,7 +92,7 @@ class CifarBase(RNGDataFlow): ...@@ -92,7 +92,7 @@ class CifarBase(RNGDataFlow):
assert cifar_classnum == 10 or cifar_classnum == 100 assert cifar_classnum == 10 or cifar_classnum == 100
self.cifar_classnum = cifar_classnum self.cifar_classnum = cifar_classnum
if dir is None: if dir is None:
dir = get_dataset_dir('cifar{}_data'.format(cifar_classnum)) dir = get_dataset_path('cifar{}_data'.format(cifar_classnum))
maybe_download_and_extract(dir, self.cifar_classnum) maybe_download_and_extract(dir, self.cifar_classnum)
fnames = get_filenames(dir, cifar_classnum) fnames = get_filenames(dir, cifar_classnum)
if train_or_test == 'train': if train_or_test == 'train':
......
...@@ -9,7 +9,7 @@ import numpy as np ...@@ -9,7 +9,7 @@ import numpy as np
from six.moves import range from six.moves import range
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from ...utils import logger, get_rng, get_dataset_dir, memoized from ...utils import logger, get_rng, get_dataset_path, memoized
from ...utils.loadcaffe import get_caffe_pb from ...utils.loadcaffe import get_caffe_pb
from ...utils.fs import mkdir_p, download from ...utils.fs import mkdir_p, download
from ...utils.timer import timed_operation from ...utils.timer import timed_operation
...@@ -28,7 +28,7 @@ class ILSVRCMeta(object): ...@@ -28,7 +28,7 @@ class ILSVRCMeta(object):
""" """
def __init__(self, dir=None): def __init__(self, dir=None):
if dir is None: if dir is None:
dir = get_dataset_dir('ilsvrc_metadata') dir = get_dataset_path('ilsvrc_metadata')
self.dir = dir self.dir = dir
mkdir_p(self.dir) mkdir_p(self.dir)
self.caffepb = get_caffe_pb() self.caffepb = get_caffe_pb()
......
...@@ -9,7 +9,7 @@ import random ...@@ -9,7 +9,7 @@ import random
import numpy import numpy
from six.moves import urllib, range from six.moves import urllib, range
from ...utils import logger, get_dataset_dir from ...utils import logger, get_dataset_path
from ...utils.fs import download from ...utils.fs import download
from ..base import RNGDataFlow from ..base import RNGDataFlow
...@@ -103,7 +103,7 @@ class Mnist(RNGDataFlow): ...@@ -103,7 +103,7 @@ class Mnist(RNGDataFlow):
train_or_test: string either 'train' or 'test' train_or_test: string either 'train' or 'test'
""" """
if dir is None: if dir is None:
dir = get_dataset_dir('mnist_data') dir = get_dataset_path('mnist_data')
assert train_or_test in ['train', 'test'] assert train_or_test in ['train', 'test']
self.train_or_test = train_or_test self.train_or_test = train_or_test
self.shuffle = shuffle self.shuffle = shuffle
......
...@@ -8,7 +8,7 @@ import random ...@@ -8,7 +8,7 @@ import random
import numpy as np import numpy as np
from six.moves import range from six.moves import range
from ...utils import logger, get_rng, get_dataset_dir from ...utils import logger, get_rng, get_dataset_path
from ..base import RNGDataFlow from ..base import RNGDataFlow
try: try:
...@@ -38,7 +38,7 @@ class SVHNDigit(RNGDataFlow): ...@@ -38,7 +38,7 @@ class SVHNDigit(RNGDataFlow):
self.X, self.Y = SVHNDigit.Cache[name] self.X, self.Y = SVHNDigit.Cache[name]
return return
if data_dir is None: if data_dir is None:
data_dir = get_dataset_dir('svhn_data') data_dir = get_dataset_path('svhn_data')
assert name in ['train', 'test', 'extra'], name assert name in ['train', 'test', 'extra'], name
filename = os.path.join(data_dir, name + '_32x32.mat') filename = os.path.join(data_dir, name + '_32x32.mat')
assert os.path.isfile(filename), \ assert os.path.isfile(filename), \
......
...@@ -11,7 +11,7 @@ import os ...@@ -11,7 +11,7 @@ import os
from six.moves import zip from six.moves import zip
from .utils import change_env, get_dataset_dir from .utils import change_env, get_dataset_path
from .fs import download from .fs import download
from . import logger from . import logger
...@@ -74,7 +74,7 @@ def load_caffe(model_desc, model_file): ...@@ -74,7 +74,7 @@ def load_caffe(model_desc, model_file):
return param_dict return param_dict
def get_caffe_pb(): def get_caffe_pb():
dir = get_dataset_dir('caffe') dir = get_dataset_path('caffe')
caffe_pb_file = os.path.join(dir, 'caffe_pb2.py') caffe_pb_file = os.path.join(dir, 'caffe_pb2.py')
if not os.path.isfile(caffe_pb_file): if not os.path.isfile(caffe_pb_file):
proto_path = download(CAFFE_PROTO_URL, dir) proto_path = download(CAFFE_PROTO_URL, dir)
......
...@@ -16,7 +16,7 @@ from . import logger ...@@ -16,7 +16,7 @@ from . import logger
__all__ = ['change_env', __all__ = ['change_env',
'map_arg', 'map_arg',
'get_rng', 'memoized', 'get_rng', 'memoized',
'get_dataset_dir', 'get_dataset_path',
'get_tqdm_kwargs' 'get_tqdm_kwargs'
] ]
...@@ -95,7 +95,7 @@ def get_rng(obj=None): ...@@ -95,7 +95,7 @@ def get_rng(obj=None):
int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295 int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295
return np.random.RandomState(seed) return np.random.RandomState(seed)
def get_dataset_dir(*args): def get_dataset_path(*args):
d = os.environ.get('TENSORPACK_DATASET', None) d = os.environ.get('TENSORPACK_DATASET', None)
if d is None: if d is None:
d = os.path.abspath(os.path.join( d = os.path.abspath(os.path.join(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment