Commit e7ede3eb authored by Yuxin Wu's avatar Yuxin Wu

dataset dir

parent 4a59173c
...@@ -9,7 +9,7 @@ import cv2 ...@@ -9,7 +9,7 @@ import cv2
from collections import deque from collections import deque
import six import six
from six.moves import range from six.moves import range
from ..utils import get_rng, logger, memoized from ..utils import get_rng, logger, memoized, get_dataset_dir
from ..utils.stat import StatCounter from ..utils.stat import StatCounter
from .envbase import RLEnvironment, DiscreteActionSpace from .envbase import RLEnvironment, DiscreteActionSpace
...@@ -46,6 +46,10 @@ class AtariPlayer(RLEnvironment): ...@@ -46,6 +46,10 @@ class AtariPlayer(RLEnvironment):
:param live_losts_as_eoe: consider lost of lives as end of episode. useful for training. :param live_losts_as_eoe: consider lost of lives as end of episode. useful for training.
""" """
super(AtariPlayer, self).__init__() super(AtariPlayer, self).__init__()
if not os.path.isfile(rom_file) and '/' not in rom_file:
rom_file = os.path.join(get_dataset_dir('atari_rom'), rom_file)
assert os.path.isfile(rom_file), "rom {} not found".format(rom_file)
self.ale = ALEInterface() self.ale = ALEInterface()
self.rng = get_rng(self) self.rng = get_rng(self)
......
...@@ -7,10 +7,10 @@ import os, glob ...@@ -7,10 +7,10 @@ import os, glob
import cv2 import cv2
import numpy as np import numpy as np
from scipy.io import loadmat from scipy.io import loadmat
from ...utils import logger, get_rng
from ...utils import logger, get_rng, get_dataset_dir
from ...utils.fs import download from ...utils.fs import download
from ..base import DataFlow from ..base import DataFlow
from .common import get_dataset_dir
__all__ = ['BSDS500'] __all__ = ['BSDS500']
......
...@@ -13,10 +13,9 @@ from six.moves import urllib, range ...@@ -13,10 +13,9 @@ from six.moves import urllib, range
import copy import copy
import logging import logging
from ...utils import logger, get_rng from ...utils import logger, get_rng, get_dataset_dir
from ...utils.fs import download from ...utils.fs import download
from ..base import DataFlow from ..base import DataFlow
from .common import get_dataset_dir
__all__ = ['Cifar10', 'Cifar100'] __all__ = ['Cifar10', 'Cifar100']
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
__all__ = ['get_dataset_dir']
def get_dataset_dir(name):
d = os.environ['TENSORPACK_DATASET']:
if d:
assert os.path.isdir(d)
else:
d = os.path.dirname(__file__)
return os.path.join(d, name)
...@@ -7,10 +7,9 @@ import tarfile ...@@ -7,10 +7,9 @@ import tarfile
import cv2 import cv2
import numpy as np import numpy as np
from ...utils import logger, get_rng from ...utils import logger, get_rng, get_dataset_dir
from ...utils.fs import mkdir_p, download from ...utils.fs import mkdir_p, download
from ..base import DataFlow from ..base import DataFlow
from .common import get_dataset_dir
__all__ = ['ILSVRCMeta', 'ILSVRC12'] __all__ = ['ILSVRCMeta', 'ILSVRC12']
......
...@@ -9,10 +9,9 @@ import random ...@@ -9,10 +9,9 @@ import random
import numpy import numpy
from six.moves import urllib, range from six.moves import urllib, range
from ...utils import logger from ...utils import logger, get_dataset_dir
from ...utils.fs import download from ...utils.fs import download
from ..base import DataFlow from ..base import DataFlow
from .common import get_dataset_dir
__all__ = ['Mnist'] __all__ = ['Mnist']
......
...@@ -10,9 +10,8 @@ import scipy ...@@ -10,9 +10,8 @@ import scipy
import scipy.io import scipy.io
from six.moves import range from six.moves import range
from ...utils import logger, get_rng from ...utils import logger, get_rng, get_dataset_dir
from ..base import DataFlow from ..base import DataFlow
from .common import get_dataset_dir
__all__ = ['SVHNDigit'] __all__ = ['SVHNDigit']
......
...@@ -12,7 +12,10 @@ import numpy as np ...@@ -12,7 +12,10 @@ import numpy as np
from . import logger from . import logger
__all__ = ['change_env', __all__ = ['change_env',
'get_rng', 'memoized', 'get_nr_gpu', 'get_gpus'] 'get_rng', 'memoized',
'get_nr_gpu',
'get_gpus',
'get_dataset_dir']
#def expand_dim_if_necessary(var, dp): #def expand_dim_if_necessary(var, dp):
# """ # """
...@@ -73,11 +76,20 @@ def get_rng(self): ...@@ -73,11 +76,20 @@ def get_rng(self):
return np.random.RandomState(seed) return np.random.RandomState(seed)
def get_nr_gpu(): def get_nr_gpu():
env = os.environ['CUDA_VISIBLE_DEVICES'] env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
assert env is not None # TODO assert env is not None # TODO
return len(env.split(',')) return len(env.split(','))
def get_gpus(): def get_gpus():
env = os.environ['CUDA_VISIBLE_DEVICES'] env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
assert env is not None # TODO assert env is not None # TODO
return map(int, env.strip().split(',')) return map(int, env.strip().split(','))
def get_dataset_dir(name):
d = os.environ.get('TENSORPACK_DATASET', None)
if d:
assert os.path.isdir(d)
else:
d = os.path.dirname(__file__)
return os.path.join(d, name)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment