Commit 2b4f7b14 authored by Yuxin Wu

play_n_episodes for gym submission

parent 4414d3ba
...@@ -28,7 +28,8 @@ from tensorpack.tfutils.gradproc import MapGradient, SummaryGradient ...@@ -28,7 +28,8 @@ from tensorpack.tfutils.gradproc import MapGradient, SummaryGradient
from tensorpack.RL import * from tensorpack.RL import *
from simulator import * from simulator import *
import common import common
from common import (play_model, Evaluator, eval_model_multithread, play_one_episode) from common import (play_model, Evaluator, eval_model_multithread,
play_one_episode, play_n_episodes)
if six.PY3: if six.PY3:
from concurrent import futures from concurrent import futures
...@@ -238,18 +239,6 @@ def get_config(): ...@@ -238,18 +239,6 @@ def get_config():
) )
def run_submission(cfg, output, nr):
    """Play ``nr`` episodes with a non-training player and print each score.

    Args:
        cfg: configuration handed to ``OfflinePredictor`` to build the
            prediction function.
        output: forwarded to ``get_player`` as ``dumpdir``.
        nr: number of episodes to play.
    """
    env = get_player(train=False, dumpdir=output)
    predict = OfflinePredictor(cfg)
    logger.info("Start evaluation: ")
    for episode in range(nr):
        # The first episode runs on the freshly built player; every later
        # one restarts the player before playing.
        if episode:
            env.restart_episode()
        print("Score:", play_one_episode(env, predict))
    # gym.upload(output, api_key='xxx')
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
...@@ -283,7 +272,10 @@ if __name__ == '__main__': ...@@ -283,7 +272,10 @@ if __name__ == '__main__':
elif args.task == 'eval': elif args.task == 'eval':
eval_model_multithread(cfg, args.episode, get_player) eval_model_multithread(cfg, args.episode, get_player)
elif args.task == 'gen_submit': elif args.task == 'gen_submit':
run_submission(cfg, args.output, args.episode) play_n_episodes(
get_player(train=False, dumpdir=args.output),
OfflinePredictor(cfg), args.episode)
# gym.upload(args.output, api_key='xxx')
else: else:
nr_gpu = get_nr_gpu() nr_gpu = get_nr_gpu()
if nr_gpu > 0: if nr_gpu > 0:
......
...@@ -62,6 +62,9 @@ def get_player(viz=False, train=False): ...@@ -62,6 +62,9 @@ def get_player(viz=False, train=False):
class Model(DQNModel): class Model(DQNModel):
def __init__(self):
    """Build the model without arguments.

    Forwards IMAGE_SIZE, CHANNEL, METHOD, NUM_ACTIONS and GAMMA (names
    resolved outside this method) to the parent constructor, so callers
    can simply write ``Model()``.
    """
    super(Model, self).__init__(IMAGE_SIZE, CHANNEL, METHOD, NUM_ACTIONS, GAMMA)
def _get_DQN_prediction(self, image): def _get_DQN_prediction(self, image):
""" image: [0,255]""" """ image: [0,255]"""
image = image / 255.0 image = image / 255.0
...@@ -95,7 +98,7 @@ class Model(DQNModel): ...@@ -95,7 +98,7 @@ class Model(DQNModel):
def get_config(): def get_config():
logger.auto_set_dir() logger.auto_set_dir()
M = Model(IMAGE_SIZE, CHANNEL, METHOD, NUM_ACTIONS, GAMMA) M = Model()
expreplay = ExpReplay( expreplay = ExpReplay(
predictor_io_names=(['state'], ['Qvalue']), predictor_io_names=(['state'], ['Qvalue']),
player=get_player(train=True), player=get_player(train=True),
......
...@@ -93,4 +93,3 @@ class Model(ModelDesc): ...@@ -93,4 +93,3 @@ class Model(ModelDesc):
logger.info("{} <- {}".format(target_name, new_name)) logger.info("{} <- {}".format(target_name, new_name))
ops.append(v.assign(G.get_tensor_by_name(new_name + ':0'))) ops.append(v.assign(G.get_tensor_by_name(new_name + ':0')))
return tf.group(*ops, name='update_target_network') return tf.group(*ops, name='update_target_network')
...@@ -112,3 +112,12 @@ class Evaluator(Triggerable): ...@@ -112,3 +112,12 @@ class Evaluator(Triggerable):
self.eval_episode = int(self.eval_episode * 0.94) self.eval_episode = int(self.eval_episode * 0.94)
self.trainer.monitors.put('mean_score', mean) self.trainer.monitors.put('mean_score', mean)
self.trainer.monitors.put('max_score', max) self.trainer.monitors.put('max_score', max)
def play_n_episodes(player, predfunc, nr):
    """Play ``nr`` episodes on ``player`` and print the score of each one.

    Args:
        player: object passed to ``play_one_episode``; it must provide
            ``restart_episode()``.
        predfunc: prediction function passed to ``play_one_episode``.
        nr: number of episodes to play.
    """
    logger.info("Start evaluation: ")
    for episode in range(nr):
        # Only episodes after the first need an explicit restart.
        if episode:
            player.restart_episode()
        print("Score:", play_one_episode(player, predfunc))
...@@ -5,5 +5,6 @@ cd examples ...@@ -5,5 +5,6 @@ cd examples
GIT_ARG="--git-dir ../.git --work-tree .." GIT_ARG="--git-dir ../.git --work-tree .."
# find out modified python files # find out modified python files
MOD=$(git $GIT_ARG status -s | grep -E '\.py$' | grep -E '^\b+M\b+' | cut -c 4-) MOD=$(git $GIT_ARG status -s | grep -E '\.py$' | grep -E '^ *M|^ *A ' | cut -c 4-)
# git $GIT_ARG status -s | grep -E '\.py$'
flake8 $MOD flake8 $MOD
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment