Commit e7ede3eb authored by Yuxin Wu's avatar Yuxin Wu

dataset dir

parent 4a59173c
......@@ -9,7 +9,7 @@ import cv2
from collections import deque
import six
from six.moves import range
from ..utils import get_rng, logger, memoized
from ..utils import get_rng, logger, memoized, get_dataset_dir
from ..utils.stat import StatCounter
from .envbase import RLEnvironment, DiscreteActionSpace
......@@ -46,6 +46,10 @@ class AtariPlayer(RLEnvironment):
:param live_losts_as_eoe: consider lost of lives as end of episode. useful for training.
"""
super(AtariPlayer, self).__init__()
if not os.path.isfile(rom_file) and '/' not in rom_file:
rom_file = os.path.join(get_dataset_dir('atari_rom'), rom_file)
assert os.path.isfile(rom_file), "rom {} not found".format(rom_file)
self.ale = ALEInterface()
self.rng = get_rng(self)
......
......@@ -7,10 +7,10 @@ import os, glob
import cv2
import numpy as np
from scipy.io import loadmat
from ...utils import logger, get_rng
from ...utils import logger, get_rng, get_dataset_dir
from ...utils.fs import download
from ..base import DataFlow
from .common import get_dataset_dir
__all__ = ['BSDS500']
......
......@@ -13,10 +13,9 @@ from six.moves import urllib, range
import copy
import logging
from ...utils import logger, get_rng
from ...utils import logger, get_rng, get_dataset_dir
from ...utils.fs import download
from ..base import DataFlow
from .common import get_dataset_dir
__all__ = ['Cifar10', 'Cifar100']
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
__all__ = ['get_dataset_dir']
def get_dataset_dir(name):
d = os.environ['TENSORPACK_DATASET']:
if d:
assert os.path.isdir(d)
else:
d = os.path.dirname(__file__)
return os.path.join(d, name)
......@@ -7,10 +7,9 @@ import tarfile
import cv2
import numpy as np
from ...utils import logger, get_rng
from ...utils import logger, get_rng, get_dataset_dir
from ...utils.fs import mkdir_p, download
from ..base import DataFlow
from .common import get_dataset_dir
__all__ = ['ILSVRCMeta', 'ILSVRC12']
......
......@@ -9,10 +9,9 @@ import random
import numpy
from six.moves import urllib, range
from ...utils import logger
from ...utils import logger, get_dataset_dir
from ...utils.fs import download
from ..base import DataFlow
from .common import get_dataset_dir
__all__ = ['Mnist']
......
......@@ -10,9 +10,8 @@ import scipy
import scipy.io
from six.moves import range
from ...utils import logger, get_rng
from ...utils import logger, get_rng, get_dataset_dir
from ..base import DataFlow
from .common import get_dataset_dir
__all__ = ['SVHNDigit']
......
......@@ -12,7 +12,10 @@ import numpy as np
from . import logger
__all__ = ['change_env',
'get_rng', 'memoized', 'get_nr_gpu', 'get_gpus']
'get_rng', 'memoized',
'get_nr_gpu',
'get_gpus',
'get_dataset_dir']
#def expand_dim_if_necessary(var, dp):
# """
......@@ -73,11 +76,20 @@ def get_rng(self):
return np.random.RandomState(seed)
def get_nr_gpu():
env = os.environ['CUDA_VISIBLE_DEVICES']
env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
assert env is not None # TODO
return len(env.split(','))
def get_gpus():
env = os.environ['CUDA_VISIBLE_DEVICES']
env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
assert env is not None # TODO
return map(int, env.strip().split(','))
def get_dataset_dir(name):
d = os.environ.get('TENSORPACK_DATASET', None)
if d:
assert os.path.isdir(d)
else:
d = os.path.dirname(__file__)
return os.path.join(d, name)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment