Commit da9b1b2f authored by Yuxin Wu

eval with multithread

parent b61ba3c9
...@@ -5,18 +5,22 @@ ...@@ -5,18 +5,22 @@
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import os, sys, re
import os, sys, re, time
import random import random
import argparse import argparse
from tqdm import tqdm
import subprocess import subprocess
import multiprocessing import multiprocessing, threading
from collections import deque from collections import deque
from six.moves import queue
from tqdm import tqdm
from tensorpack import * from tensorpack import *
from tensorpack.models import * from tensorpack.models import *
from tensorpack.utils import * from tensorpack.utils import *
from tensorpack.utils.concurrency import ensure_proc_terminate, subproc_call from tensorpack.utils.concurrency import (ensure_proc_terminate, \
subproc_call, StoppableThread)
from tensorpack.utils.stat import * from tensorpack.utils.stat import *
from tensorpack.predict import PredictConfig, get_predict_func, MultiProcessPredictWorker from tensorpack.predict import PredictConfig, get_predict_func, MultiProcessPredictWorker
from tensorpack.tfutils import symbolic_functions as symbf from tensorpack.tfutils import symbolic_functions as symbf
...@@ -33,7 +37,7 @@ for atari games ...@@ -33,7 +37,7 @@ for atari games
BATCH_SIZE = 32 BATCH_SIZE = 32
IMAGE_SIZE = (84, 84) IMAGE_SIZE = (84, 84)
FRAME_HISTORY = 4 FRAME_HISTORY = 4
ACTION_REPEAT = 4 ACTION_REPEAT = 3
HEIGHT_RANGE = (36, 204) # for breakout HEIGHT_RANGE = (36, 204) # for breakout
CHANNEL = FRAME_HISTORY CHANNEL = FRAME_HISTORY
IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,) IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
...@@ -64,7 +68,7 @@ class Model(ModelDesc): ...@@ -64,7 +68,7 @@ class Model(ModelDesc):
def _get_input_vars(self): def _get_input_vars(self):
assert NUM_ACTIONS is not None assert NUM_ACTIONS is not None
return [InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'state'), return [InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
InputVar(tf.int32, (None,), 'action'), InputVar(tf.int64, (None,), 'action'),
InputVar(tf.float32, (None,), 'reward'), InputVar(tf.float32, (None,), 'reward'),
InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'next_state'), InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'next_state'),
InputVar(tf.bool, (None,), 'isOver') ] InputVar(tf.bool, (None,), 'isOver') ]
...@@ -72,7 +76,7 @@ class Model(ModelDesc): ...@@ -72,7 +76,7 @@ class Model(ModelDesc):
def _get_DQN_prediction(self, image, is_training): def _get_DQN_prediction(self, image, is_training):
""" image: [0,255]""" """ image: [0,255]"""
image = image / 255.0 image = image / 255.0
with argscope(Conv2D, nl=PReLU.f, use_bias=True): with argscope(Conv2D, nl=tf.nn.relu, use_bias=True):
l = Conv2D('conv0', image, out_channel=32, kernel_shape=5, stride=1) l = Conv2D('conv0', image, out_channel=32, kernel_shape=5, stride=1)
l = MaxPooling('pool0', l, 2) l = MaxPooling('pool0', l, 2)
l = Conv2D('conv1', l, out_channel=32, kernel_shape=5, stride=1) l = Conv2D('conv1', l, out_channel=32, kernel_shape=5, stride=1)
...@@ -80,8 +84,11 @@ class Model(ModelDesc): ...@@ -80,8 +84,11 @@ class Model(ModelDesc):
l = Conv2D('conv2', l, out_channel=64, kernel_shape=4) l = Conv2D('conv2', l, out_channel=64, kernel_shape=4)
l = MaxPooling('pool2', l, 2) l = MaxPooling('pool2', l, 2)
l = Conv2D('conv3', l, out_channel=64, kernel_shape=3) l = Conv2D('conv3', l, out_channel=64, kernel_shape=3)
#l = MaxPooling('pool3', l, 2)
#l = Conv2D('conv4', l, out_channel=64, kernel_shape=3) # the original arch
#l = Conv2D('conv0', image, out_channel=32, kernel_shape=8, stride=4)
#l = Conv2D('conv1', l, out_channel=64, kernel_shape=4, stride=2)
#l = Conv2D('conv2', l, out_channel=64, kernel_shape=3)
l = FullyConnected('fc0', l, 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name)) l = FullyConnected('fc0', l, 512, nl=lambda x, name: LeakyReLU.f(x, 0.01, name))
l = FullyConnected('fct', l, out_dim=NUM_ACTIONS, nl=tf.identity) l = FullyConnected('fct', l, out_dim=NUM_ACTIONS, nl=tf.identity)
...@@ -101,11 +108,11 @@ class Model(ModelDesc): ...@@ -101,11 +108,11 @@ class Model(ModelDesc):
targetQ_predict_value = self._get_DQN_prediction(next_state, False) # NxA targetQ_predict_value = self._get_DQN_prediction(next_state, False) # NxA
# DQN # DQN
best_v = tf.reduce_max(targetQ_predict_value, 1) # N, #best_v = tf.reduce_max(targetQ_predict_value, 1) # N,
# Double-DQN # Double-DQN
#predict_onehot = tf.one_hot(self.greedy_choice, NUM_ACTIONS, 1.0, 0.0) predict_onehot = tf.one_hot(self.greedy_choice, NUM_ACTIONS, 1.0, 0.0)
#best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1) best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1)
target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v) target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)
...@@ -156,7 +163,7 @@ def play_one_episode(player, func, verbose=False): ...@@ -156,7 +163,7 @@ def play_one_episode(player, func, verbose=False):
return sc return sc
def play_model(model_path): def play_model(model_path):
player = PreventStuckPlayer(HistoryFramePlayer(get_player(0.01), FRAME_HISTORY), 30, 1) player = PreventStuckPlayer(HistoryFramePlayer(get_player(0.013), FRAME_HISTORY), 30, 1)
cfg = PredictConfig( cfg = PredictConfig(
model=Model(), model=Model(),
input_data_mapping=[0], input_data_mapping=[0],
...@@ -167,54 +174,61 @@ def play_model(model_path): ...@@ -167,54 +174,61 @@ def play_model(model_path):
score = play_one_episode(player, predfunc) score = play_one_episode(player, predfunc)
print("Total:", score) print("Total:", score)
def eval_with_funcs(predict_funcs):
    """Play EVAL_EPISODE episodes with worker threads and summarize the scores.

    One worker thread is started per entry of ``predict_funcs``. Each worker
    builds its own player, repeatedly plays one episode with its predict
    function and pushes the resulting score onto a shared queue. The calling
    thread consumes EVAL_EPISODE scores from that queue.

    Args:
        predict_funcs: non-empty sequence of predict functions, one per
            worker thread (the same function object may appear repeatedly).

    Returns:
        Tuple ``(average score, max score)`` as reported by ``StatCounter``.

    Raises:
        ValueError: if ``predict_funcs`` is empty. With no workers nothing
            would ever be put on the queue and the caller would block forever.
    """
    class Worker(StoppableThread):
        """Thread that plays episodes and reports each score on a queue."""

        def __init__(self, func, score_queue):
            # The parameter is deliberately not called ``queue`` so that it
            # cannot be confused with the ``queue`` module used in run().
            super(Worker, self).__init__()
            self.func = func
            self.q = score_queue

        def run(self):
            player = PreventStuckPlayer(
                HistoryFramePlayer(get_player(), FRAME_HISTORY), 30, 1)
            while not self.stopped():
                score = play_one_episode(player, self.func)
                # Retry with a timeout so that a stop request is noticed
                # even if the queue cannot accept the score.
                while not self.stopped():
                    try:
                        self.q.put(score, timeout=5)
                        break
                    except queue.Full:
                        pass

    if not predict_funcs:
        raise ValueError("eval_with_funcs needs at least one predict function")

    q = queue.Queue()
    threads = [Worker(f, q) for f in predict_funcs]

    for k in threads:
        k.start()
        time.sleep(0.1)  # avoid simulator bugs

    stat = StatCounter()
    try:
        for _ in tqdm(range(EVAL_EPISODE)):
            stat.feed(q.get())
    finally:
        # Always shut the workers down, even when evaluation is interrupted.
        # No ``return`` here: returning from ``finally`` would discard the
        # exception that is currently propagating.
        for k in threads:
            k.stop()
        for k in threads:
            k.join()
    return (stat.average, stat.max)
def eval_model_multithread(model_path):
    """Evaluate a saved model with several threads and log the scores.

    Builds a predict function that restores the model from ``model_path``,
    shares it between the worker threads of ``eval_with_funcs`` and logs the
    average and maximum score.

    Args:
        model_path: checkpoint path handed to ``SaverRestore``.
    """
    cfg = PredictConfig(
        model=Model(),
        input_data_mapping=[0],
        session_init=SaverRestore(model_path),
        output_var_names=['fct/output:0'])
    # Creating a player is done only for its side effect; per the original
    # comment it sets NUM_ACTIONS. The player itself is not needed here.
    p = get_player()
    del p
    func = get_predict_func(cfg)
    # cpu_count() // 2 is 0 on a single-core machine; keep at least one
    # worker, otherwise no score is ever produced and evaluation hangs.
    nr_proc = max(1, min(multiprocessing.cpu_count() // 2, 8))
    mean_score, max_score = eval_with_funcs([func] * nr_proc)
    logger.info("Average Score: {}; Max Score: {}".format(mean_score, max_score))
class Evaluator(Callback):
    """Callback that evaluates the model in-process after every epoch.

    The predict functions are obtained from the trainer once before training
    and reused for every evaluation; the mean and max scores are written as
    scalar summaries.
    """

    def _before_train(self):
        # cpu_count() // 2 is 0 on a single-core machine; keep at least one
        # worker so that evaluation can produce scores.
        nr_proc = max(1, min(multiprocessing.cpu_count() // 2, 8))
        # The same predict function object is shared by all worker threads.
        self.pred_funcs = [self.trainer.get_predict_func(
            ['state'], ['fct/output'])] * nr_proc

    def _trigger_epoch(self):
        mean_score, max_score = eval_with_funcs(self.pred_funcs)
        self.trainer.write_scalar_summary('mean_score', mean_score)
        self.trainer.write_scalar_summary('max_score', max_score)
def get_config(): def get_config():
basename = os.path.basename(__file__) basename = os.path.basename(__file__)
...@@ -277,7 +291,7 @@ if __name__ == '__main__': ...@@ -277,7 +291,7 @@ if __name__ == '__main__':
play_model(args.load) play_model(args.load)
sys.exit() sys.exit()
if args.task == 'eval': if args.task == 'eval':
eval_model_multiprocess(args.load) eval_model_multithread(args.load)
sys.exit() sys.exit()
with tf.Graph().as_default(): with tf.Graph().as_default():
......
...@@ -42,6 +42,8 @@ class AtariPlayer(RLEnvironment): ...@@ -42,6 +42,8 @@ class AtariPlayer(RLEnvironment):
self.rng = get_rng(self) self.rng = get_rng(self)
self.ale.setInt("random_seed", self.rng.randint(0, 10000)) self.ale.setInt("random_seed", self.rng.randint(0, 10000))
self.ale.setBool("showinfo", False)
#ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
self.ale.setInt("frame_skip", 1) self.ale.setInt("frame_skip", 1)
self.ale.setBool('color_averaging', False) self.ale.setBool('color_averaging', False)
# manual.pdf suggests otherwise. may need to check # manual.pdf suggests otherwise. may need to check
......
...@@ -89,7 +89,7 @@ class CallbackTimeLogger(object): ...@@ -89,7 +89,7 @@ class CallbackTimeLogger(object):
msgs = [] msgs = []
for name, t in self.times: for name, t in self.times:
if t / self.tot > 0.3 and t > 1: if t / self.tot > 0.3 and t > 1:
msgs.append("{}:{:.3f}sec".format(name, t)) msgs.append("{}: {:.3f}sec".format(name, t))
logger.info( logger.info(
"Callbacks took {:.3f} sec in total. {}".format( "Callbacks took {:.3f} sec in total. {}".format(
self.tot, '; '.join(msgs))) self.tot, '; '.join(msgs)))
......
...@@ -79,8 +79,8 @@ def get_predict_func(config): ...@@ -79,8 +79,8 @@ def get_predict_func(config):
output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1]) output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
for n in output_var_names] for n in output_var_names]
# start with minimal memory, but allow growth # XXX does it work? start with minimal memory, but allow growth
sess = tf.Session(config=get_default_sess_config(0.01)) sess = tf.Session(config=get_default_sess_config(0.3))
config.session_init.init(sess) config.session_init.init(sess)
def run_input(dp): def run_input(dp):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment