Commit 4fd1db97 authored by Yuxin Wu

visualize player

parent 4d33715d
......@@ -10,6 +10,7 @@ import random
import argparse
from tqdm import tqdm
import multiprocessing
from collections import deque
from tensorpack import *
from tensorpack.models import *
......@@ -153,6 +154,7 @@ def play_model(model_path, romfile):
output_var_names=['fct/output:0'])
predfunc = get_predict_func(cfg)
tot_reward = 0
que = deque(maxlen=30)
while True:
s = player.current_state()
outputs = predfunc([[s]])
......@@ -161,13 +163,16 @@ def play_model(model_path, romfile):
print action_value, act
if random.random() < 0.01:
act = random.choice(range(player.driver.get_num_actions()))
if len(que) == que.maxlen \
and que.count(que[0]) == que.maxlen:
act = 1
que.append(act)
print(act)
_, reward, isOver = player.action(act)
tot_reward += reward
if isOver:
print("Total:", tot_reward)
tot_reward = 0
pbar.update()
def eval_model_multiprocess(model_path, romfile):
M = Model()
......@@ -191,6 +196,7 @@ def eval_model_multiprocess(model_path, romfile):
self._init_runtime()
tot_reward = 0
que = deque(maxlen=30)
while True:
s = player.current_state()
outputs = self.func([[s]])
......@@ -199,6 +205,10 @@ def eval_model_multiprocess(model_path, romfile):
#print action_value, act
if random.random() < 0.01:
act = random.choice(range(player.driver.get_num_actions()))
if len(que) == que.maxlen \
and que.count(que[0]) == que.maxlen:
act = 1
que.append(act)
#print(act)
_, reward, isOver = player.action(act)
tot_reward += reward
......@@ -215,16 +225,14 @@ def eval_model_multiprocess(model_path, romfile):
for k in procs:
k.start()
stat = StatCounter()
try:
EVAL_EPISODE = 50
with tqdm(total=EVAL_EPISODE) as pbar:
while True:
for _ in tqdm(range(EVAL_EPISODE)):
r = q.get()
stat.feed(r)
pbar.update()
if stat.count() == EVAL_EPISODE:
finally:
logger.info("Average Score: {}. Max Score: {}".format(
stat.average, stat.max))
break
def get_config(romfile):
......
......@@ -33,12 +33,13 @@ class AtariDriver(object):
self.viz = viz
self.romname = os.path.basename(rom_file)
if self.viz:
if isinstance(self.viz, float):
cv2.startWindowThread()
cv2.namedWindow(self.romname)
self._reset()
self.last_image = self._grab_raw_image()
self.framenum = 0
def _grab_raw_image(self):
"""
......@@ -55,9 +56,12 @@ class AtariDriver(object):
now = self._grab_raw_image()
ret = np.maximum(now, self.last_image)
self.last_image = now
if self.viz:
if isinstance(self.viz, float):
cv2.imshow(self.romname, ret)
time.sleep(self.viz)
else:
cv2.imwrite("{}/{:06d}.jpg".format(self.viz, self.framenum), ret)
self.framenum += 1
ret = cv2.cvtColor(ret, cv2.COLOR_BGR2YUV)[:,:,0]
return ret
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment