Commit 21a6984c authored by Yuxin Wu's avatar Yuxin Wu

[a3c] specify dir to save train logs

parent 963e5100
...@@ -215,10 +215,6 @@ class MySimulatorMaster(SimulatorMaster, Callback): ...@@ -215,10 +215,6 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def train(): def train():
assert tf.test.is_gpu_available(), "Training requires GPUs!"
dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
logger.set_logger_dir(dirname)
# assign GPUs for training & inference # assign GPUs for training & inference
num_gpu = get_num_gpu() num_gpu = get_num_gpu()
global PREDICTOR_THREAD global PREDICTOR_THREAD
...@@ -275,9 +271,11 @@ if __name__ == '__main__': ...@@ -275,9 +271,11 @@ if __name__ == '__main__':
parser.add_argument('--env', help='env', required=True) parser.add_argument('--env', help='env', required=True)
parser.add_argument('--task', help='task to perform', parser.add_argument('--task', help='task to perform',
choices=['play', 'eval', 'train', 'dump_video'], default='train') choices=['play', 'eval', 'train', 'dump_video'], default='train')
parser.add_argument('--output', help='output directory for submission', default='output_dir') parser.add_argument('--output', help='output directory for logs and videos')
parser.add_argument('--episode', help='number of episode to eval', default=100, type=int) parser.add_argument('--episode', help='number of episode to eval', default=100, type=int)
args = parser.parse_args() args = parser.parse_args()
if args.output is None:
args.output = os.path.join('train_log', 'train-atari-{}'.format(args.env))
ENV_NAME = args.env ENV_NAME = args.env
NUM_ACTIONS = get_player().action_space.n NUM_ACTIONS = get_player().action_space.n
...@@ -303,4 +301,6 @@ if __name__ == '__main__': ...@@ -303,4 +301,6 @@ if __name__ == '__main__':
get_player(train=False, dumpdir=args.output), get_player(train=False, dumpdir=args.output),
pred, args.episode) pred, args.episode)
else: else:
assert tf.test.is_gpu_available(), "Training requires GPUs!"
logger.set_logger_dir(args.output)
train() train()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment