Commit da9b1b2f authored by Yuxin Wu

eval with multithread

parent b61ba3c9
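The substance of this commit is replacing the multiprocess evaluator with an in-process, multithreaded one: each worker thread wraps one predict function and one game player, plays episodes in a loop, and pushes each episode score into a shared queue that the main thread drains while keeping statistics. Below is a minimal, self-contained sketch of that pattern; it uses a plain threading.Event in place of tensorpack's StoppableThread and a stub instead of a real Atari episode, so names like play_one_episode_stub and the episode count are illustrative only.

```python
import random
import threading
from six.moves import queue


def play_one_episode_stub(func):
    """Stand-in for running the agent for one full episode; returns a score."""
    return func(random.random())


def eval_with_funcs_sketch(predict_funcs, nr_episodes=20):
    """One worker thread per predict function; episode scores go into a shared queue."""
    q = queue.Queue(maxsize=100)
    stop_evt = threading.Event()

    def worker(func):
        while not stop_evt.is_set():
            score = play_one_episode_stub(func)
            while not stop_evt.is_set():
                try:
                    q.put(score, timeout=5)  # retry until delivered or asked to stop
                    break
                except queue.Full:
                    pass

    threads = [threading.Thread(target=worker, args=(f,)) for f in predict_funcs]
    for t in threads:
        t.start()

    scores = [q.get() for _ in range(nr_episodes)]
    stop_evt.set()  # workers finish their current episode, then exit
    for t in threads:
        t.join()
    return sum(scores) / len(scores), max(scores)


if __name__ == '__main__':
    mean, best = eval_with_funcs_sketch([lambda x: x * 100] * 4)
    print(mean, best)
```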
@@ -5,18 +5,22 @@
import numpy as np
import tensorflow as tf
import os, sys, re
import os, sys, re, time
import random
import argparse
from tqdm import tqdm
import subprocess
import multiprocessing
import multiprocessing, threading
from collections import deque
from six.moves import queue
from tqdm import tqdm
from tensorpack import *
from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.utils.concurrency import ensure_proc_terminate, subproc_call
from tensorpack.utils.concurrency import (ensure_proc_terminate, \
subproc_call, StoppableThread)
from tensorpack.utils.stat import *
from tensorpack.predict import PredictConfig, get_predict_func, MultiProcessPredictWorker
from tensorpack.tfutils import symbolic_functions as symbf
@@ -33,7 +37,7 @@ for atari games
BATCH_SIZE = 32
IMAGE_SIZE = (84, 84)
FRAME_HISTORY = 4
ACTION_REPEAT = 4
ACTION_REPEAT = 3
HEIGHT_RANGE = (36, 204) # for breakout
CHANNEL = FRAME_HISTORY
IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
@@ -64,7 +68,7 @@ class Model(ModelDesc):
def _get_input_vars(self):
assert NUM_ACTIONS is not None
return [InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
InputVar(tf.int32, (None,), 'action'),
InputVar(tf.int64, (None,), 'action'),
InputVar(tf.float32, (None,), 'reward'),
InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'next_state'),
InputVar(tf.bool, (None,), 'isOver') ]
@@ -72,7 +76,7 @@ class Model(ModelDesc):
def _get_DQN_prediction(self, image, is_training):
""" image: [0,255]"""
image = image / 255.0
with argscope(Conv2D, nl=PReLU.f, use_bias=True):
with argscope(Conv2D, nl=tf.nn.relu, use_bias=True):
l = Conv2D('conv0', image, out_channel=32, kernel_shape=5, stride=1)
l = MaxPooling('pool0', l, 2)
l = Conv2D('conv1', l, out_channel=32, kernel_shape=5, stride=1)
@@ -80,8 +84,11 @@ class Model(ModelDesc):
l = Conv2D('conv2', l, out_channel=64, kernel_shape=4)
l = MaxPooling('pool2', l, 2)
l = Conv2D('conv3', l, out_channel=64, kernel_shape=3)
#l = MaxPooling('pool3', l, 2)
#l = Conv2D('conv4', l, out_channel=64, kernel_shape=3)
# the original arch
#l = Conv2D('conv0', image, out_channel=32, kernel_shape=8, stride=4)
#l = Conv2D('conv1', l, out_channel=64, kernel_shape=4, stride=2)
#l = Conv2D('conv2', l, out_channel=64, kernel_shape=3)
l = FullyConnected('fc0', l, 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name))
l = FullyConnected('fct', l, out_dim=NUM_ACTIONS, nl=tf.identity)
@@ -101,11 +108,11 @@ class Model(ModelDesc):
targetQ_predict_value = self._get_DQN_prediction(next_state, False) # NxA
# DQN
best_v = tf.reduce_max(targetQ_predict_value, 1) # N,
#best_v = tf.reduce_max(targetQ_predict_value, 1) # N,
# Double-DQN
#predict_onehot = tf.one_hot(self.greedy_choice, NUM_ACTIONS, 1.0, 0.0)
#best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1)
predict_onehot = tf.one_hot(self.greedy_choice, NUM_ACTIONS, 1.0, 0.0)
best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1)
target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)
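The behavioural change in this hunk is switching the training target from vanilla DQN to Double-DQN: the action is chosen by the online network (self.greedy_choice, presumably the argmax of its Q-values) but evaluated with the target network, which reduces the overestimation that comes from taking a max over noisy estimates. A small NumPy sketch of the two targets, with made-up Q-values purely for illustration:

```python
import numpy as np

gamma = 0.99
reward = np.array([1.0, 0.0])
is_over = np.array([False, False])
q_online_next = np.array([[1.0, 3.0, 2.0],   # online network's Q(s', .)
                          [0.5, 0.9, 0.2]])
q_target_next = np.array([[1.5, 2.0, 4.0],   # target network's Q(s', .)
                          [0.4, 0.3, 0.8]])

# Vanilla DQN: take the max of the target network's own estimates
best_v_dqn = q_target_next.max(axis=1)

# Double-DQN: the online network picks the action, the target network scores it
greedy = q_online_next.argmax(axis=1)
best_v_ddqn = q_target_next[np.arange(len(greedy)), greedy]

target_dqn = reward + (1.0 - is_over) * gamma * best_v_dqn
target_ddqn = reward + (1.0 - is_over) * gamma * best_v_ddqn
print(target_dqn, target_ddqn)  # the Double-DQN targets are no larger than the DQN ones
```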
@@ -156,7 +163,7 @@ def play_one_episode(player, func, verbose=False):
return sc
def play_model(model_path):
player = PreventStuckPlayer(HistoryFramePlayer(get_player(0.01), FRAME_HISTORY), 30, 1)
player = PreventStuckPlayer(HistoryFramePlayer(get_player(0.013), FRAME_HISTORY), 30, 1)
cfg = PredictConfig(
model=Model(),
input_data_mapping=[0],
@@ -167,54 +174,61 @@ def play_model(model_path):
score = play_one_episode(player, predfunc)
print("Total:", score)
def eval_model_multiprocess(model_path):
cfg = PredictConfig(
model=Model(),
input_data_mapping=[0],
session_init=SaverRestore(model_path),
output_var_names=['fct/output:0'])
class Worker(MultiProcessPredictWorker):
def __init__(self, idx, gpuid, config, outqueue):
super(Worker, self).__init__(idx, gpuid, config)
self.outq = outqueue
def eval_with_funcs(predict_funcs):
class Worker(StoppableThread):
def __init__(self, func, queue):
super(Worker, self).__init__()
self.func = func
self.q = queue
def run(self):
player = PreventStuckPlayer(HistoryFramePlayer(get_player(), FRAME_HISTORY), 30, 1)
self._init_runtime()
while True:
while not self.stopped():
score = play_one_episode(player, self.func)
self.outq.put(score)
while not self.stopped():
try:
self.q.put(score, timeout=5)
break
except queue.Full:
pass
NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
q = multiprocessing.Queue()
gpuid = get_gpus()[0]
procs = [Worker(k, gpuid, cfg, q) for k in range(NR_PROC)]
ensure_proc_terminate(procs)
for k in procs:
q = queue.Queue()
threads = [Worker(f, q) for f in predict_funcs]
for k in threads:
k.start()
time.sleep(0.1) # avoid simulator bugs
stat = StatCounter()
try:
for _ in tqdm(range(EVAL_EPISODE)):
r = q.get()
stat.feed(r)
for k in threads: k.stop()
for k in threads: k.join()
finally:
logger.info("Average Score: {}; Max Score: {}".format(
stat.average, stat.max))
return (stat.average, stat.max)
def eval_model_multithread(model_path):
cfg = PredictConfig(
model=Model(),
input_data_mapping=[0],
session_init=SaverRestore(model_path),
output_var_names=['fct/output:0'])
p = get_player(); del p # set NUM_ACTIONS
func = get_predict_func(cfg)
NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
mean, max = eval_with_funcs([func] * NR_PROC)
logger.info("Average Score: {}; Max Score: {}".format(mean, max))
class Evaluator(Callback):
def _before_train(self):
NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
self.pred_funcs = [self.trainer.get_predict_func(
['state'], ['fct/output'])] * NR_PROC
def _trigger_epoch(self):
logger.info("Evaluating...")
output = subproc_call(
"{} --task eval --rom {} --load {}".format(
sys.argv[0], ROM_FILE, os.path.join(logger.LOG_DIR, 'checkpoint')),
timeout=10*60)
if output:
last = output.strip().split('\n')[-1]
last = last[last.find(']')+1:]
mean, maximum = re.findall('[0-9\.\-]+', last)[-2:]
self.trainer.write_scalar_summary('mean_score', mean)
self.trainer.write_scalar_summary('max_score', maximum)
mean, max = eval_with_funcs(self.pred_funcs)
self.trainer.write_scalar_summary('mean_score', mean)
self.trainer.write_scalar_summary('max_score', max)
def get_config():
basename = os.path.basename(__file__)
@@ -277,7 +291,7 @@ if __name__ == '__main__':
play_model(args.load)
sys.exit()
if args.task == 'eval':
eval_model_multiprocess(args.load)
eval_model_multithread(args.load)
sys.exit()
with tf.Graph().as_default():
......
@@ -42,6 +42,8 @@ class AtariPlayer(RLEnvironment):
self.rng = get_rng(self)
self.ale.setInt("random_seed", self.rng.randint(0, 10000))
self.ale.setBool("showinfo", False)
#ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
self.ale.setInt("frame_skip", 1)
self.ale.setBool('color_averaging', False)
# manual.pdf suggests otherwise. may need to check
......
@@ -89,7 +89,7 @@ class CallbackTimeLogger(object):
msgs = []
for name, t in self.times:
if t / self.tot > 0.3 and t > 1:
msgs.append("{}:{:.3f}sec".format(name, t))
msgs.append("{}: {:.3f}sec".format(name, t))
logger.info(
"Callbacks took {:.3f} sec in total. {}".format(
self.tot, '; '.join(msgs)))
......
@@ -79,8 +79,8 @@ def get_predict_func(config):
output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
for n in output_var_names]
# start with minimal memory, but allow growth
sess = tf.Session(config=get_default_sess_config(0.01))
# XXX does it work? start with minimal memory, but allow growth
sess = tf.Session(config=get_default_sess_config(0.3))
config.session_init.init(sess)
def run_input(dp):
......
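The last hunk only bumps the prediction session's initial GPU memory fraction from 1% to 30%. The helper get_default_sess_config is not shown in this diff; as an assumption about what it builds, a "start small, allow growth" session config in TF 1.x typically looks like this:

```python
import tensorflow as tf


def make_sess_config(mem_fraction=0.3):
    """Sketch of a TF 1.x session config; this is an assumption about
    get_default_sess_config, whose implementation is not part of this diff."""
    conf = tf.ConfigProto()
    conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
    conf.gpu_options.allow_growth = True   # start small, grab more memory as needed
    conf.allow_soft_placement = True
    return conf


sess = tf.Session(config=make_sess_config(0.3))
```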