Commit 1ed89bb6 authored by Yuxin Wu's avatar Yuxin Wu

get_dataset_dir args

parent 51535c94
...@@ -62,3 +62,7 @@ target/ ...@@ -62,3 +62,7 @@ target/
*.dat *.dat
*.bin *.bin
*.tfmodel
*.meta
*.log*
model-*
...@@ -21,8 +21,7 @@ IMAGE_SIZE = 28 ...@@ -21,8 +21,7 @@ IMAGE_SIZE = 28
class Model(ModelDesc): class Model(ModelDesc):
def _get_input_vars(self): def _get_input_vars(self):
return [InputVar(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'), return [InputVar(tf.float32, (None, IMAGE_SIZE, IMAGE_SIZE), 'input'),
InputVar(tf.int32, (None,), 'label') InputVar(tf.int32, (None,), 'label') ]
]
def _build_graph(self, input_vars, is_training): def _build_graph(self, input_vars, is_training):
is_training = bool(is_training) is_training = bool(is_training)
......
...@@ -51,7 +51,7 @@ class AtariPlayer(RLEnvironment): ...@@ -51,7 +51,7 @@ class AtariPlayer(RLEnvironment):
""" """
super(AtariPlayer, self).__init__() super(AtariPlayer, self).__init__()
if not os.path.isfile(rom_file) and '/' not in rom_file: if not os.path.isfile(rom_file) and '/' not in rom_file:
rom_file = os.path.join(get_dataset_dir('atari_rom'), rom_file) rom_file = get_dataset_dir('atari_rom', rom_file)
assert os.path.isfile(rom_file), \ assert os.path.isfile(rom_file), \
"rom {} not found. Please download at {}".format(rom_file, ROM_URL) "rom {} not found. Please download at {}".format(rom_file, ROM_URL)
......
...@@ -86,13 +86,12 @@ def get_gpus(): ...@@ -86,13 +86,12 @@ def get_gpus():
assert env is not None # TODO assert env is not None # TODO
return map(int, env.strip().split(',')) return map(int, env.strip().split(','))
def get_dataset_dir(name): def get_dataset_dir(*args):
d = os.environ.get('TENSORPACK_DATASET', None) d = os.environ.get('TENSORPACK_DATASET', None)
if d: if d is None:
assert os.path.isdir(d), d
else:
d = os.path.abspath(os.path.join( d = os.path.abspath(os.path.join(
os.path.dirname(__file__), '..', 'dataflow', 'dataset')) os.path.dirname(__file__), '..', 'dataflow', 'dataset'))
logger.info("TENSORPACK_DATASET not set, using {} for dataset.".format(d)) logger.info("TENSORPACK_DATASET not set, using {} for dataset.".format(d))
return os.path.join(d, name) assert os.path.isdir(d), d
return os.path.join(d, *args)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment