Commit 4a59173c authored by Yuxin Wu's avatar Yuxin Wu

dataset dir

parent 194cda0b
...@@ -10,6 +10,7 @@ from scipy.io import loadmat ...@@ -10,6 +10,7 @@ from scipy.io import loadmat
from ...utils import logger, get_rng from ...utils import logger, get_rng
from ...utils.fs import download from ...utils.fs import download
from ..base import DataFlow from ..base import DataFlow
from .common import get_dataset_dir
__all__ = ['BSDS500'] __all__ = ['BSDS500']
...@@ -36,7 +37,7 @@ class BSDS500(DataFlow): ...@@ -36,7 +37,7 @@ class BSDS500(DataFlow):
""" """
# check and download data # check and download data
if data_dir is None: if data_dir is None:
data_dir = os.path.join(os.path.dirname(__file__), 'bsds500_data') data_dir = get_dataset_dir('bsds500_data')
if not os.path.isdir(os.path.join(data_dir, 'BSR')): if not os.path.isdir(os.path.join(data_dir, 'BSR')):
download(DATA_URL, data_dir) download(DATA_URL, data_dir)
filename = DATA_URL.split('/')[-1] filename = DATA_URL.split('/')[-1]
......
...@@ -16,6 +16,7 @@ import logging ...@@ -16,6 +16,7 @@ import logging
from ...utils import logger, get_rng from ...utils import logger, get_rng
from ...utils.fs import download from ...utils.fs import download
from ..base import DataFlow from ..base import DataFlow
from .common import get_dataset_dir
__all__ = ['Cifar10', 'Cifar100'] __all__ = ['Cifar10', 'Cifar100']
...@@ -92,8 +93,7 @@ class CifarBase(DataFlow): ...@@ -92,8 +93,7 @@ class CifarBase(DataFlow):
assert cifar_classnum == 10 or cifar_classnum == 100 assert cifar_classnum == 10 or cifar_classnum == 100
self.cifar_classnum = cifar_classnum self.cifar_classnum = cifar_classnum
if dir is None: if dir is None:
dir = os.path.join(os.path.dirname(__file__), dir = get_dataset_dir('cifar{}_data'.format(cifar_classnum))
'cifar{}_data'.format(cifar_classnum))
maybe_download_and_extract(dir, self.cifar_classnum) maybe_download_and_extract(dir, self.cifar_classnum)
fnames = get_filenames(dir, cifar_classnum) fnames = get_filenames(dir, cifar_classnum)
if train_or_test == 'train': if train_or_test == 'train':
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
__all__ = ['get_dataset_dir']
def get_dataset_dir(name):
d = os.environ['TENSORPACK_DATASET']:
if d:
assert os.path.isdir(d)
else:
d = os.path.dirname(__file__)
return os.path.join(d, name)
...@@ -8,8 +8,9 @@ import cv2 ...@@ -8,8 +8,9 @@ import cv2
import numpy as np import numpy as np
from ...utils import logger, get_rng from ...utils import logger, get_rng
from ..base import DataFlow
from ...utils.fs import mkdir_p, download from ...utils.fs import mkdir_p, download
from ..base import DataFlow
from .common import get_dataset_dir
__all__ = ['ILSVRCMeta', 'ILSVRC12'] __all__ = ['ILSVRCMeta', 'ILSVRC12']
...@@ -28,7 +29,7 @@ class ILSVRCMeta(object): ...@@ -28,7 +29,7 @@ class ILSVRCMeta(object):
""" """
def __init__(self, dir=None): def __init__(self, dir=None):
if dir is None: if dir is None:
dir = os.path.join(os.path.dirname(__file__), 'ilsvrc_metadata') dir = get_dataset_dir('ilsvrc_metadata')
self.dir = dir self.dir = dir
mkdir_p(self.dir) mkdir_p(self.dir)
self.caffe_pb_file = os.path.join(self.dir, 'caffe_pb2.py') self.caffe_pb_file = os.path.join(self.dir, 'caffe_pb2.py')
...@@ -91,8 +92,7 @@ class ILSVRC12(DataFlow): ...@@ -91,8 +92,7 @@ class ILSVRC12(DataFlow):
name: 'train' or 'val' or 'test' name: 'train' or 'val' or 'test'
""" """
assert name in ['train', 'test', 'val'] assert name in ['train', 'test', 'val']
self.dir = dir self.full_dir = os.path.join(dir, name)
self.name = name
self.shuffle = shuffle self.shuffle = shuffle
self.meta = ILSVRCMeta(meta_dir) self.meta = ILSVRCMeta(meta_dir)
self.imglist = self.meta.get_image_list(name) self.imglist = self.meta.get_image_list(name)
...@@ -116,7 +116,7 @@ class ILSVRC12(DataFlow): ...@@ -116,7 +116,7 @@ class ILSVRC12(DataFlow):
self.rng.shuffle(idxs) self.rng.shuffle(idxs)
for k in idxs: for k in idxs:
tp = self.imglist[k] tp = self.imglist[k]
fname = os.path.join(self.dir, self.name, tp[0]).strip() fname = os.path.join(self.full_dir, tp[0]).strip()
im = cv2.imread(fname, cv2.IMREAD_COLOR) im = cv2.imread(fname, cv2.IMREAD_COLOR)
assert im is not None, fname assert im is not None, fname
if im.ndim == 2: if im.ndim == 2:
......
...@@ -12,6 +12,7 @@ from six.moves import urllib, range ...@@ -12,6 +12,7 @@ from six.moves import urllib, range
from ...utils import logger from ...utils import logger
from ...utils.fs import download from ...utils.fs import download
from ..base import DataFlow from ..base import DataFlow
from .common import get_dataset_dir
__all__ = ['Mnist'] __all__ = ['Mnist']
...@@ -103,7 +104,7 @@ class Mnist(DataFlow): ...@@ -103,7 +104,7 @@ class Mnist(DataFlow):
train_or_test: string either 'train' or 'test' train_or_test: string either 'train' or 'test'
""" """
if dir is None: if dir is None:
dir = os.path.join(os.path.dirname(__file__), 'mnist_data') dir = get_dataset_dir('mnist_data')
assert train_or_test in ['train', 'test'] assert train_or_test in ['train', 'test']
self.train_or_test = train_or_test self.train_or_test = train_or_test
self.shuffle = shuffle self.shuffle = shuffle
......
...@@ -12,6 +12,7 @@ from six.moves import range ...@@ -12,6 +12,7 @@ from six.moves import range
from ...utils import logger, get_rng from ...utils import logger, get_rng
from ..base import DataFlow from ..base import DataFlow
from .common import get_dataset_dir
__all__ = ['SVHNDigit'] __all__ = ['SVHNDigit']
...@@ -36,9 +37,7 @@ class SVHNDigit(DataFlow): ...@@ -36,9 +37,7 @@ class SVHNDigit(DataFlow):
self.X, self.Y = SVHNDigit.Cache[name] self.X, self.Y = SVHNDigit.Cache[name]
return return
if data_dir is None: if data_dir is None:
data_dir = os.path.join( data_dir = get_dataset_dir('svhn_data')
os.path.dirname(__file__),
'svhn_data')
assert name in ['train', 'test', 'extra'], name assert name in ['train', 'test', 'extra'], name
filename = os.path.join(data_dir, name + '_32x32.mat') filename = os.path.join(data_dir, name + '_32x32.mat')
assert os.path.isfile(filename), \ assert os.path.isfile(filename), \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment