Commit 1ed89bb6 authored by Yuxin Wu's avatar Yuxin Wu

get_dataset_dir args

parent 51535c94
......@@ -62,3 +62,7 @@ target/
*.dat
*.bin
*.tfmodel
*.meta
*.log*
model-*
......@@ -21,8 +21,7 @@ IMAGE_SIZE = 28
class Model(ModelDesc):
def _get_input_vars(self):
return [InputVar(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
InputVar(tf.int32, (None,), 'label')
]
InputVar(tf.int32, (None,), 'label') ]
def _build_graph(self, input_vars, is_training):
is_training = bool(is_training)
......
......@@ -51,7 +51,7 @@ class AtariPlayer(RLEnvironment):
"""
super(AtariPlayer, self).__init__()
if not os.path.isfile(rom_file) and '/' not in rom_file:
rom_file = os.path.join(get_dataset_dir('atari_rom'), rom_file)
rom_file = get_dataset_dir('atari_rom', rom_file)
assert os.path.isfile(rom_file), \
"rom {} not found. Please download at {}".format(rom_file, ROM_URL)
......
......@@ -86,13 +86,12 @@ def get_gpus():
assert env is not None # TODO
return map(int, env.strip().split(','))
def get_dataset_dir(name):
def get_dataset_dir(*args):
d = os.environ.get('TENSORPACK_DATASET', None)
if d:
assert os.path.isdir(d), d
else:
if d is None:
d = os.path.abspath(os.path.join(
os.path.dirname(__file__), '..', 'dataflow', 'dataset'))
logger.info("TENSORPACK_DATASET not set, using {} for dataset.".format(d))
return os.path.join(d, name)
assert os.path.isdir(d), d
return os.path.join(d, *args)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment