Commit 6728b686 authored by Yuxin Wu's avatar Yuxin Wu

get_dataset_path instead of dir

parent 807296b3
......@@ -10,7 +10,7 @@ from collections import deque
import threading
import six
from six.moves import range
from ..utils import get_rng, logger, memoized, get_dataset_dir
from ..utils import get_rng, logger, memoized, get_dataset_path
from ..utils.stat import StatCounter
from .envbase import RLEnvironment, DiscreteActionSpace
......@@ -51,7 +51,7 @@ class AtariPlayer(RLEnvironment):
"""
super(AtariPlayer, self).__init__()
if not os.path.isfile(rom_file) and '/' not in rom_file:
rom_file = get_dataset_dir('atari_rom', rom_file)
rom_file = get_dataset_path('atari_rom', rom_file)
assert os.path.isfile(rom_file), \
"rom {} not found. Please download at {}".format(rom_file, ROM_URL)
......
......@@ -7,7 +7,7 @@ import os, glob
import cv2
import numpy as np
from ...utils import logger, get_rng, get_dataset_dir
from ...utils import logger, get_rng, get_dataset_path
from ...utils.fs import download
from ..base import RNGDataFlow
......@@ -40,7 +40,7 @@ class BSDS500(RNGDataFlow):
"""
# check and download data
if data_dir is None:
data_dir = get_dataset_dir('bsds500_data')
data_dir = get_dataset_path('bsds500_data')
if not os.path.isdir(os.path.join(data_dir, 'BSR')):
download(DATA_URL, data_dir)
filename = DATA_URL.split('/')[-1]
......
......@@ -13,7 +13,7 @@ from six.moves import urllib, range
import copy
import logging
from ...utils import logger, get_rng, get_dataset_dir
from ...utils import logger, get_rng, get_dataset_path
from ...utils.fs import download
from ..base import RNGDataFlow
......@@ -92,7 +92,7 @@ class CifarBase(RNGDataFlow):
assert cifar_classnum == 10 or cifar_classnum == 100
self.cifar_classnum = cifar_classnum
if dir is None:
dir = get_dataset_dir('cifar{}_data'.format(cifar_classnum))
dir = get_dataset_path('cifar{}_data'.format(cifar_classnum))
maybe_download_and_extract(dir, self.cifar_classnum)
fnames = get_filenames(dir, cifar_classnum)
if train_or_test == 'train':
......
......@@ -9,7 +9,7 @@ import numpy as np
from six.moves import range
import xml.etree.ElementTree as ET
from ...utils import logger, get_rng, get_dataset_dir, memoized
from ...utils import logger, get_rng, get_dataset_path, memoized
from ...utils.loadcaffe import get_caffe_pb
from ...utils.fs import mkdir_p, download
from ...utils.timer import timed_operation
......@@ -28,7 +28,7 @@ class ILSVRCMeta(object):
"""
def __init__(self, dir=None):
if dir is None:
dir = get_dataset_dir('ilsvrc_metadata')
dir = get_dataset_path('ilsvrc_metadata')
self.dir = dir
mkdir_p(self.dir)
self.caffepb = get_caffe_pb()
......
......@@ -9,7 +9,7 @@ import random
import numpy
from six.moves import urllib, range
from ...utils import logger, get_dataset_dir
from ...utils import logger, get_dataset_path
from ...utils.fs import download
from ..base import RNGDataFlow
......@@ -103,7 +103,7 @@ class Mnist(RNGDataFlow):
train_or_test: string either 'train' or 'test'
"""
if dir is None:
dir = get_dataset_dir('mnist_data')
dir = get_dataset_path('mnist_data')
assert train_or_test in ['train', 'test']
self.train_or_test = train_or_test
self.shuffle = shuffle
......
......@@ -8,7 +8,7 @@ import random
import numpy as np
from six.moves import range
from ...utils import logger, get_rng, get_dataset_dir
from ...utils import logger, get_rng, get_dataset_path
from ..base import RNGDataFlow
try:
......@@ -38,7 +38,7 @@ class SVHNDigit(RNGDataFlow):
self.X, self.Y = SVHNDigit.Cache[name]
return
if data_dir is None:
data_dir = get_dataset_dir('svhn_data')
data_dir = get_dataset_path('svhn_data')
assert name in ['train', 'test', 'extra'], name
filename = os.path.join(data_dir, name + '_32x32.mat')
assert os.path.isfile(filename), \
......
......@@ -11,7 +11,7 @@ import os
from six.moves import zip
from .utils import change_env, get_dataset_dir
from .utils import change_env, get_dataset_path
from .fs import download
from . import logger
......@@ -74,7 +74,7 @@ def load_caffe(model_desc, model_file):
return param_dict
def get_caffe_pb():
dir = get_dataset_dir('caffe')
dir = get_dataset_path('caffe')
caffe_pb_file = os.path.join(dir, 'caffe_pb2.py')
if not os.path.isfile(caffe_pb_file):
proto_path = download(CAFFE_PROTO_URL, dir)
......
......@@ -16,7 +16,7 @@ from . import logger
__all__ = ['change_env',
'map_arg',
'get_rng', 'memoized',
'get_dataset_dir',
'get_dataset_path',
'get_tqdm_kwargs'
]
......@@ -95,7 +95,7 @@ def get_rng(obj=None):
int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295
return np.random.RandomState(seed)
def get_dataset_dir(*args):
def get_dataset_path(*args):
d = os.environ.get('TENSORPACK_DATASET', None)
if d is None:
d = os.path.abspath(os.path.join(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment