Commit ddf737d7 authored by Yuxin Wu

better infra for evaluation

Factor per-episode play into play_one_episode(), run evaluation periodically via a new Evaluator callback, and add Trainer.write_scalar_summary() as a shared way for callbacks and inferencers to report scalars.

parent 961b0ee4
......@@ -5,10 +5,11 @@
import tensorflow as tf
import numpy as np
import os, sys
import os, sys, re
import random
import argparse
from tqdm import tqdm
import subprocess
import multiprocessing
from collections import deque
......@@ -22,7 +23,7 @@ from tensorpack.tfutils import symbolic_functions as symbf
from tensorpack.callbacks import *
from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
from exp_replay import AtariExpReplay
from tensorpack.dataflow.RL import ExpReplay
"""
Implement DQN in:
......@@ -44,6 +45,8 @@ END_EXPLORATION = 0.1
INIT_MEMORY_SIZE = 50000
MEMORY_SIZE = 1e6
STEP_PER_EPOCH = 10000
EVAL_EPISODE = 100
class Model(ModelDesc):
......@@ -138,41 +141,47 @@ class ExpReplayController(Callback):
def _trigger_epoch(self):
if self.d.exploration > END_EXPLORATION:
self.d.exploration -= EXPLORATION_EPOCH_ANNEAL
logger.info("Exploration: {}".format(self.d.exploration))
logger.info("Exploration changed to {}".format(self.d.exploration))
def play_model(model_path, romfile):
player = AtariPlayer(AtariDriver(romfile, viz=0.01),
action_repeat=ACTION_REPEAT)
global NUM_ACTIONS
NUM_ACTIONS = player.driver.get_num_actions()
M = Model()
cfg = PredictConfig(
model=M,
input_data_mapping=[0],
session_init=SaverRestore(model_path),
output_var_names=['fct/output:0'])
predfunc = get_predict_func(cfg)
def play_one_episode(player, func, verbose=False):
tot_reward = 0
que = deque(maxlen=30)
while True:
s = player.current_state()
outputs = predfunc([[s]])
outputs = func([[s]])
action_value = outputs[0][0]
act = action_value.argmax()
if verbose:
print(action_value, act)
if random.random() < 0.01:
act = random.choice(range(player.driver.get_num_actions()))
if len(que) == que.maxlen \
and que.count(que[0]) == que.maxlen:
act = 1
act = 1 # hack: break out when stuck repeating the same action
que.append(act)
if verbose:
print(act)
reward, isOver = player.action(act)
tot_reward += reward
if isOver:
print("Total:", tot_reward)
tot_reward = 0
return tot_reward
def play_model(model_path, romfile):
player = AtariPlayer(AtariDriver(romfile, viz=0.01),
action_repeat=ACTION_REPEAT)
global NUM_ACTIONS
NUM_ACTIONS = player.driver.get_num_actions()
M = Model()
cfg = PredictConfig(
model=M,
input_data_mapping=[0],
session_init=SaverRestore(model_path),
output_var_names=['fct/output:0'])
predfunc = get_predict_func(cfg)
while True:
score = play_one_episode(player, predfunc)
print("Total:", score)
def eval_model_multiprocess(model_path, romfile):
M = Model()
......@@ -192,29 +201,10 @@ def eval_model_multiprocess(model_path, romfile):
action_repeat=ACTION_REPEAT)
global NUM_ACTIONS
NUM_ACTIONS = player.driver.get_num_actions()
self._init_runtime()
tot_reward = 0
que = deque(maxlen=30)
while True:
s = player.current_state()
outputs = self.func([[s]])
action_value = outputs[0][0]
act = action_value.argmax()
#print action_value, act
if random.random() < 0.01:
act = random.choice(range(player.driver.get_num_actions()))
if len(que) == que.maxlen \
and que.count(que[0]) == que.maxlen:
act = 1
que.append(act)
#print(act)
reward, isOver = player.action(act)
tot_reward += reward
if isOver:
self.outq.put(tot_reward)
tot_reward = 0
score = play_one_episode(player, self.func)
self.outq.put(score)
NR_PROC = min(multiprocessing.cpu_count() // 2, 10)
procs = []
......@@ -226,13 +216,19 @@ def eval_model_multiprocess(model_path, romfile):
k.start()
stat = StatCounter()
try:
EVAL_EPISODE = 50
for _ in tqdm(range(EVAL_EPISODE)):
r = q.get()
stat.feed(r)
finally:
logger.info("Average Score: {}. Max Score: {}".format(
for p in procs:
p.terminate()
p.join()
if stat.count() > 0:
logger.info("Average Score: {}; Max Score: {}".format(
stat.average, stat.max))
return (stat.average, stat.max)
else:
return (0, 0)
def get_config(romfile):
......@@ -260,6 +256,18 @@ def get_config(romfile):
lr = tf.Variable(0.0025, trainable=False, name='learning_rate')
tf.scalar_summary('learning_rate', lr)
class Evaluator(Callback):
def _trigger_epoch(self):
logger.info("Evaluating...")
output = subprocess.check_output(
"""{} --task eval --rom {} --load {} 2>&1 | grep Average""".format(
sys.argv[0], romfile, os.path.join(logger.LOG_DIR, 'checkpoint')), shell=True)
output = output.strip()
output = output[output.find(']')+1:]
mean, maximum = re.findall(r'[0-9\.]+', output)
self.trainer.write_scalar_summary('mean_score', float(mean))
self.trainer.write_scalar_summary('max_score', float(maximum))
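# For reference, the command assembled above expands to roughly (script name,
# rom, and log directory are placeholders):
#   ./DQN.py --task eval --rom breakout.bin --load train_log/DQN/checkpoint 2>&1 | grep Average
# The grep'd line is the one logged by eval_model_multiprocess, e.g.
#   [...] Average Score: 123.4; Max Score: 321.0
# output.find(']') strips the logger prefix, and re.findall() recovers the two numbers.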
return TrainConfig(
dataset=dataset_train,
optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
......@@ -269,11 +277,12 @@ def get_config(romfile):
HumanHyperParamSetter('learning_rate', 'hyper.txt'),
HumanHyperParamSetter((dataset_train, 'exploration'), 'hyper.txt'),
TargetNetworkUpdator(M),
ExpReplayController(dataset_train)
ExpReplayController(dataset_train),
PeriodicCallback(Evaluator(), 1),
]),
session_config=get_default_sess_config(0.5),
model=M,
step_per_epoch=10000,
step_per_epoch=STEP_PER_EPOCH,
max_epoch=10000,
)
......
......@@ -59,12 +59,6 @@ class Inferencer(object):
def _get_output_tensors(self):
pass
def _scalar_summary(self, name, val):
self.trainer.summary_writer.add_summary(
create_summary(name, val),
get_global_step())
self.trainer.stat_holder.add_stat(name, val)
class InferenceRunner(Callback):
"""
A callback that runs different kinds of inferencer.
......@@ -161,9 +155,7 @@ class ScalarStats(Inferencer):
for stat, name in zip(self.stats, self.names):
opname, _ = get_op_var_name(name)
name = '{}_{}'.format(self.prefix, opname) if self.prefix else opname
self.trainer.summary_writer.add_summary(
create_summary(name, stat), get_global_step())
self.trainer.stat_holder.add_stat(name, stat)
self.trainer.write_scalar_summary(name, stat)
class ClassificationError(Inferencer):
"""
......@@ -197,7 +189,7 @@ class ClassificationError(Inferencer):
self.err_stat.feed(wrong, batch_size)
def _after_inference(self):
self._scalar_summary(self.summary_name, self.err_stat.accuracy)
self.trainer.write_scalar_summary(self.summary_name, self.err_stat.accuracy)
class BinaryClassificationStats(Inferencer):
......@@ -221,5 +213,5 @@ class BinaryClassificationStats(Inferencer):
self.stat.feed(pred, label)
def _after_inference(self):
self._scalar_summary(self.prefix + '_precision', self.stat.precision)
self._scalar_summary(self.prefix + '_recall', self.stat.recall)
self.trainer.write_scalar_summary(self.prefix + '_precision', self.stat.precision)
self.trainer.write_scalar_summary(self.prefix + '_recall', self.stat.recall)
......@@ -58,8 +58,7 @@ class PredictConfig(object):
"""
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
self.session_config = kwargs.pop('session_config', get_default_sess_config())
assert_type(self.session_config, tf.ConfigProto)
self.session_config = kwargs.pop('session_config', None)
self.session_init = kwargs.pop('session_init')
self.model = kwargs.pop('model')
self.input_data_mapping = kwargs.pop('input_data_mapping', None)
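# Usage sketch: with session_config now optional, the DQN example above can build
# its predictor with only the required fields, e.g.
#   cfg = PredictConfig(model=M, input_data_mapping=[0],
#                       session_init=SaverRestore(model_path),
#                       output_var_names=['fct/output:0'])
# and get_predict_func(cfg) falls back to a plain tf.Session() (see below).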
......@@ -87,7 +86,10 @@ def get_predict_func(config):
output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
for n in output_var_names]
if config.session_config:
sess = tf.Session(config=config.session_config)
else:
sess = tf.Session()
config.session_init.init(sess)
def run_input(dp):
......@@ -116,7 +118,7 @@ class ParallelPredictWorker(multiprocessing.Process):
os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid
else:
logger.info("Worker {} uses CPU".format(self.idx))
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
G = tf.Graph() # build a graph for each process, because they don't need to share anything
with G.as_default(), tf.device('/gpu:0' if self.gpuid >= 0 else '/cpu:0'):
if self.idx != 0:
......
......@@ -12,6 +12,7 @@ from .config import TrainConfig
from ..utils import *
from ..callbacks import StatHolder
from ..tfutils import *
from ..tfutils.summary import create_summary
from ..tfutils.modelutils import describe_model
__all__ = ['Trainer']
......@@ -76,6 +77,12 @@ class Trainer(object):
self.stat_holder.add_stat(val.tag, val.simple_value)
self.summary_writer.add_summary(summary, self.global_step)
def write_scalar_summary(self, name, val):
self.summary_writer.add_summary(
create_summary(name, val),
get_global_step())
self.stat_holder.add_stat(name, val)
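# Usage sketch (mirrors the call sites updated in this commit): a Callback or
# Inferencer holding a trainer reference reports a scalar with
#   self.trainer.write_scalar_summary('mean_score', mean)
# which writes a summary at the current global step and also feeds the value
# into the trainer's StatHolder.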
def main_loop(self):
# some final operations that might modify the graph
logger.info("Preparing for training...")
......