Commit 4fd1db97 authored by Yuxin Wu's avatar Yuxin Wu

visualize player

parent 4d33715d
...@@ -10,6 +10,7 @@ import random ...@@ -10,6 +10,7 @@ import random
import argparse import argparse
from tqdm import tqdm from tqdm import tqdm
import multiprocessing import multiprocessing
from collections import deque
from tensorpack import * from tensorpack import *
from tensorpack.models import * from tensorpack.models import *
...@@ -153,6 +154,7 @@ def play_model(model_path, romfile): ...@@ -153,6 +154,7 @@ def play_model(model_path, romfile):
output_var_names=['fct/output:0']) output_var_names=['fct/output:0'])
predfunc = get_predict_func(cfg) predfunc = get_predict_func(cfg)
tot_reward = 0 tot_reward = 0
que = deque(maxlen=30)
while True: while True:
s = player.current_state() s = player.current_state()
outputs = predfunc([[s]]) outputs = predfunc([[s]])
...@@ -161,13 +163,16 @@ def play_model(model_path, romfile): ...@@ -161,13 +163,16 @@ def play_model(model_path, romfile):
print action_value, act print action_value, act
if random.random() < 0.01: if random.random() < 0.01:
act = random.choice(range(player.driver.get_num_actions())) act = random.choice(range(player.driver.get_num_actions()))
if len(que) == que.maxlen \
and que.count(que[0]) == que.maxlen:
act = 1
que.append(act)
print(act) print(act)
_, reward, isOver = player.action(act) _, reward, isOver = player.action(act)
tot_reward += reward tot_reward += reward
if isOver: if isOver:
print("Total:", tot_reward) print("Total:", tot_reward)
tot_reward = 0 tot_reward = 0
pbar.update()
def eval_model_multiprocess(model_path, romfile): def eval_model_multiprocess(model_path, romfile):
M = Model() M = Model()
...@@ -191,6 +196,7 @@ def eval_model_multiprocess(model_path, romfile): ...@@ -191,6 +196,7 @@ def eval_model_multiprocess(model_path, romfile):
self._init_runtime() self._init_runtime()
tot_reward = 0 tot_reward = 0
que = deque(maxlen=30)
while True: while True:
s = player.current_state() s = player.current_state()
outputs = self.func([[s]]) outputs = self.func([[s]])
...@@ -199,6 +205,10 @@ def eval_model_multiprocess(model_path, romfile): ...@@ -199,6 +205,10 @@ def eval_model_multiprocess(model_path, romfile):
#print action_value, act #print action_value, act
if random.random() < 0.01: if random.random() < 0.01:
act = random.choice(range(player.driver.get_num_actions())) act = random.choice(range(player.driver.get_num_actions()))
if len(que) == que.maxlen \
and que.count(que[0]) == que.maxlen:
act = 1
que.append(act)
#print(act) #print(act)
_, reward, isOver = player.action(act) _, reward, isOver = player.action(act)
tot_reward += reward tot_reward += reward
...@@ -215,16 +225,14 @@ def eval_model_multiprocess(model_path, romfile): ...@@ -215,16 +225,14 @@ def eval_model_multiprocess(model_path, romfile):
for k in procs: for k in procs:
k.start() k.start()
stat = StatCounter() stat = StatCounter()
EVAL_EPISODE = 50 try:
with tqdm(total=EVAL_EPISODE) as pbar: EVAL_EPISODE = 50
while True: for _ in tqdm(range(EVAL_EPISODE)):
r = q.get() r = q.get()
stat.feed(r) stat.feed(r)
pbar.update() finally:
if stat.count() == EVAL_EPISODE: logger.info("Average Score: {}. Max Score: {}".format(
logger.info("Average Score: {}. Max Score: {}".format( stat.average, stat.max))
stat.average, stat.max))
break
def get_config(romfile): def get_config(romfile):
......
...@@ -33,12 +33,13 @@ class AtariDriver(object): ...@@ -33,12 +33,13 @@ class AtariDriver(object):
self.viz = viz self.viz = viz
self.romname = os.path.basename(rom_file) self.romname = os.path.basename(rom_file)
if self.viz: if isinstance(self.viz, float):
cv2.startWindowThread() cv2.startWindowThread()
cv2.namedWindow(self.romname) cv2.namedWindow(self.romname)
self._reset() self._reset()
self.last_image = self._grab_raw_image() self.last_image = self._grab_raw_image()
self.framenum = 0
def _grab_raw_image(self): def _grab_raw_image(self):
""" """
...@@ -55,9 +56,12 @@ class AtariDriver(object): ...@@ -55,9 +56,12 @@ class AtariDriver(object):
now = self._grab_raw_image() now = self._grab_raw_image()
ret = np.maximum(now, self.last_image) ret = np.maximum(now, self.last_image)
self.last_image = now self.last_image = now
if self.viz: if isinstance(self.viz, float):
cv2.imshow(self.romname, ret) cv2.imshow(self.romname, ret)
time.sleep(self.viz) time.sleep(self.viz)
else:
cv2.imwrite("{}/{:06d}.jpg".format(self.viz, self.framenum), ret)
self.framenum += 1
ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:,:,0] ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:,:,0]
return ret return ret
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment