Commit 2b4f7b14 authored by Yuxin Wu

play_n_episodes for gym submission

parent 4414d3ba
......@@ -28,7 +28,8 @@ from tensorpack.tfutils.gradproc import MapGradient, SummaryGradient
from tensorpack.RL import *
from simulator import *
import common
from common import (play_model, Evaluator, eval_model_multithread, play_one_episode)
from common import (play_model, Evaluator, eval_model_multithread,
play_one_episode, play_n_episodes)
if six.PY3:
from concurrent import futures
......@@ -238,18 +239,6 @@ def get_config():
)
def run_submission(cfg, output, nr):
    """Play `nr` episodes for a gym submission and print every score.

    A player is built with ``get_player(train=False, dumpdir=output)`` and an
    ``OfflinePredictor`` is built from `cfg`; each episode is then run through
    ``play_one_episode``.

    Args:
        cfg: configuration handed to ``OfflinePredictor``.
        output: forwarded to ``get_player`` as ``dumpdir``.
        nr: number of episodes to play.
    """
    env = get_player(train=False, dumpdir=output)
    predictor = OfflinePredictor(cfg)
    logger.info("Start evaluation: ")
    for idx in range(nr):
        # Every episode after the first starts with an explicit restart.
        if idx:
            env.restart_episode()
        print("Score:", play_one_episode(env, predictor))
    # gym.upload(output, api_key='xxx')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
......@@ -283,7 +272,10 @@ if __name__ == '__main__':
elif args.task == 'eval':
eval_model_multithread(cfg, args.episode, get_player)
elif args.task == 'gen_submit':
run_submission(cfg, args.output, args.episode)
play_n_episodes(
get_player(train=False, dumpdir=args.output),
OfflinePredictor(cfg), args.episode)
# gym.upload(output, api_key='xxx')
else:
nr_gpu = get_nr_gpu()
if nr_gpu > 0:
......
......@@ -62,6 +62,9 @@ def get_player(viz=False, train=False):
class Model(DQNModel):
def __init__(self):
super(Model, self).__init__(IMAGE_SIZE, CHANNEL, METHOD, NUM_ACTIONS, GAMMA)
def _get_DQN_prediction(self, image):
""" image: [0,255]"""
image = image / 255.0
......@@ -95,7 +98,7 @@ class Model(DQNModel):
def get_config():
logger.auto_set_dir()
M = Model(IMAGE_SIZE, CHANNEL, METHOD, NUM_ACTIONS, GAMMA)
M = Model()
expreplay = ExpReplay(
predictor_io_names=(['state'], ['Qvalue']),
player=get_player(train=True),
......
......@@ -93,4 +93,3 @@ class Model(ModelDesc):
logger.info("{} <- {}".format(target_name, new_name))
ops.append(v.assign(G.get_tensor_by_name(new_name + ':0')))
return tf.group(*ops, name='update_target_network')
......@@ -112,3 +112,12 @@ class Evaluator(Triggerable):
self.eval_episode = int(self.eval_episode * 0.94)
self.trainer.monitors.put('mean_score', mean)
self.trainer.monitors.put('max_score', max)
def play_n_episodes(player, predfunc, nr):
    """Run `nr` episodes on `player` using `predfunc`, printing each score.

    Args:
        player: object passed to ``play_one_episode``; its
            ``restart_episode()`` is called before every episode but the first.
        predfunc: predictor passed to ``play_one_episode``.
        nr: number of episodes to play.
    """
    logger.info("Start evaluation: ")
    episode = 0
    while episode < nr:
        if episode > 0:
            # Only episodes after the first need an explicit restart.
            player.restart_episode()
        result = play_one_episode(player, predfunc)
        print("Score:", result)
        episode += 1
......@@ -5,5 +5,6 @@ cd examples
# Point git at the repository one directory up (this snippet runs after `cd examples`).
GIT_ARG="--git-dir ../.git --work-tree .."
# find out modified python files
# NOTE(review): the two MOD assignments below look like the before/after
# versions of a single line from the diff; if both are executed, the second
# simply overwrites the first -- confirm which one belongs in the script.
MOD=$(git $GIT_ARG status -s | grep -E '\.py$' | grep -E '^\b+M\b+' | cut -c 4-)
# Keep `git status -s` lines ending in .py whose start is optional spaces
# followed by `M`, or optional spaces followed by `A `; `cut -c 4-` then keeps
# everything from the 4th character on (presumably the file path -- verify
# against the short-status format).
MOD=$(git $GIT_ARG status -s | grep -E '\.py$' | grep -E '^ *M|^ *A ' | cut -c 4-)
# git $GIT_ARG status -s | grep -E '\.py$'
# Lint only the selected files; $MOD is left unquoted so each path becomes
# its own argument.
flake8 $MOD
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment