Commit 4a59173c authored by Yuxin Wu

dataset dir

parent 194cda0b
@@ -10,6 +10,7 @@ from scipy.io import loadmat
 from ...utils import logger, get_rng
 from ...utils.fs import download
 from ..base import DataFlow
+from .common import get_dataset_dir
 __all__ = ['BSDS500']
@@ -36,7 +37,7 @@ class BSDS500(DataFlow):
         """
         # check and download data
         if data_dir is None:
-            data_dir = os.path.join(os.path.dirname(__file__), 'bsds500_data')
+            data_dir = get_dataset_dir('bsds500_data')
         if not os.path.isdir(os.path.join(data_dir, 'BSR')):
             download(DATA_URL, data_dir)
             filename = DATA_URL.split('/')[-1]
......
@@ -16,6 +16,7 @@ import logging
 from ...utils import logger, get_rng
 from ...utils.fs import download
 from ..base import DataFlow
+from .common import get_dataset_dir
 __all__ = ['Cifar10', 'Cifar100']
@@ -92,8 +93,7 @@ class CifarBase(DataFlow):
         assert cifar_classnum == 10 or cifar_classnum == 100
         self.cifar_classnum = cifar_classnum
         if dir is None:
-            dir = os.path.join(os.path.dirname(__file__),
-                               'cifar{}_data'.format(cifar_classnum))
+            dir = get_dataset_dir('cifar{}_data'.format(cifar_classnum))
         maybe_download_and_extract(dir, self.cifar_classnum)
         fnames = get_filenames(dir, cifar_classnum)
         if train_or_test == 'train':
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
__all__ = ['get_dataset_dir']
def get_dataset_dir(name):
    # Use $TENSORPACK_DATASET as the base directory when it is set,
    # otherwise fall back to the directory of this package.
    d = os.environ.get('TENSORPACK_DATASET', None)
    if d:
        assert os.path.isdir(d)
    else:
        d = os.path.dirname(__file__)
    return os.path.join(d, name)
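
For reference, a minimal sketch of how the new helper resolves paths. The absolute import path tensorpack.dataflow.dataset.common is an assumption inferred from the relative imports in this commit, and the temporary directory merely stands in for a real dataset location:

import os
import tempfile

# Assumed import path, inferred from the package layout of this commit.
from tensorpack.dataflow.dataset.common import get_dataset_dir

# Without TENSORPACK_DATASET, datasets stay next to the package sources.
os.environ.pop('TENSORPACK_DATASET', None)
print(get_dataset_dir('mnist_data'))      # .../dataflow/dataset/mnist_data

# With TENSORPACK_DATASET pointing at an existing directory, everything
# that asks for the default directory resolves under it instead.
shared = tempfile.mkdtemp()               # stands in for e.g. a shared data disk
os.environ['TENSORPACK_DATASET'] = shared
print(get_dataset_dir('mnist_data'))      # <shared>/mnist_data
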
@@ -8,8 +8,9 @@ import cv2
 import numpy as np
 from ...utils import logger, get_rng
-from ..base import DataFlow
 from ...utils.fs import mkdir_p, download
+from ..base import DataFlow
+from .common import get_dataset_dir
 __all__ = ['ILSVRCMeta', 'ILSVRC12']
@@ -28,7 +29,7 @@ class ILSVRCMeta(object):
     """
     def __init__(self, dir=None):
         if dir is None:
-            dir = os.path.join(os.path.dirname(__file__), 'ilsvrc_metadata')
+            dir = get_dataset_dir('ilsvrc_metadata')
         self.dir = dir
         mkdir_p(self.dir)
         self.caffe_pb_file = os.path.join(self.dir, 'caffe_pb2.py')
@@ -91,8 +92,7 @@ class ILSVRC12(DataFlow):
             name: 'train' or 'val' or 'test'
         """
         assert name in ['train', 'test', 'val']
-        self.dir = dir
-        self.name = name
+        self.full_dir = os.path.join(dir, name)
         self.shuffle = shuffle
         self.meta = ILSVRCMeta(meta_dir)
         self.imglist = self.meta.get_image_list(name)
@@ -116,7 +116,7 @@ class ILSVRC12(DataFlow):
             self.rng.shuffle(idxs)
         for k in idxs:
             tp = self.imglist[k]
-            fname = os.path.join(self.dir, self.name, tp[0]).strip()
+            fname = os.path.join(self.full_dir, tp[0]).strip()
             im = cv2.imread(fname, cv2.IMREAD_COLOR)
             assert im is not None, fname
             if im.ndim == 2:
......
@@ -12,6 +12,7 @@ from six.moves import urllib, range
 from ...utils import logger
 from ...utils.fs import download
 from ..base import DataFlow
+from .common import get_dataset_dir
 __all__ = ['Mnist']
@@ -103,7 +104,7 @@ class Mnist(DataFlow):
             train_or_test: string either 'train' or 'test'
         """
         if dir is None:
-            dir = os.path.join(os.path.dirname(__file__), 'mnist_data')
+            dir = get_dataset_dir('mnist_data')
         assert train_or_test in ['train', 'test']
         self.train_or_test = train_or_test
         self.shuffle = shuffle
......
@@ -12,6 +12,7 @@ from six.moves import range
 from ...utils import logger, get_rng
 from ..base import DataFlow
+from .common import get_dataset_dir
 __all__ = ['SVHNDigit']
@@ -36,9 +37,7 @@ class SVHNDigit(DataFlow):
             self.X, self.Y = SVHNDigit.Cache[name]
             return
         if data_dir is None:
-            data_dir = os.path.join(
-                os.path.dirname(__file__),
-                'svhn_data')
+            data_dir = get_dataset_dir('svhn_data')
         assert name in ['train', 'test', 'extra'], name
         filename = os.path.join(data_dir, name + '_32x32.mat')
         assert os.path.isfile(filename), \
......
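
Taken together with the helper above, every dataset touched by this commit honours the same environment variable whenever its directory argument is left as None. A minimal usage sketch; the public import path and the (image, label) datapoint layout are assumptions based on tensorpack's usual API, and ~/tensorpack_data is just an illustrative location:

import os

from tensorpack.dataflow import dataset    # assumed public import path

# Point every dataset at one shared, existing directory.  get_dataset_dir
# asserts that the directory exists, so create it first.
shared = os.path.expanduser('~/tensorpack_data')
os.makedirs(shared, exist_ok=True)
os.environ['TENSORPACK_DATASET'] = shared

# dir defaults to None, so MNIST is downloaded to / read from
# ~/tensorpack_data/mnist_data rather than the package directory.
ds = dataset.Mnist('train')
for img, label in ds.get_data():           # assumed (image, label) datapoints
    print(img.shape, label)
    break
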