Commit ddf737d7 authored by Yuxin Wu

better infra for evaluate

parent 961b0ee4
@@ -5,10 +5,11 @@
 import tensorflow as tf
 import numpy as np
-import os, sys
+import os, sys, re
 import random
 import argparse
 from tqdm import tqdm
+import subprocess
 import multiprocessing
 from collections import deque
@@ -22,7 +23,7 @@ from tensorpack.tfutils import symbolic_functions as symbf
 from tensorpack.callbacks import *
 from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
-from exp_replay import AtariExpReplay
+from tensorpack.dataflow.RL import ExpReplay

 """
 Implement DQN in:
@@ -44,6 +45,8 @@ END_EXPLORATION = 0.1
 INIT_MEMORY_SIZE = 50000
 MEMORY_SIZE = 1e6
+STEP_PER_EPOCH = 10000
+EVAL_EPISODE = 100

 class Model(ModelDesc):
@@ -138,41 +141,47 @@ class ExpReplayController(Callback):
     def _trigger_epoch(self):
         if self.d.exploration > END_EXPLORATION:
             self.d.exploration -= EXPLORATION_EPOCH_ANNEAL
-            logger.info("Exploration: {}".format(self.d.exploration))
+            logger.info("Exploration changed to {}".format(self.d.exploration))

-def play_model(model_path, romfile):
-    player = AtariPlayer(AtariDriver(romfile, viz=0.01),
-            action_repeat=ACTION_REPEAT)
-    global NUM_ACTIONS
-    NUM_ACTIONS = player.driver.get_num_actions()
-    M = Model()
-    cfg = PredictConfig(
-            model=M,
-            input_data_mapping=[0],
-            session_init=SaverRestore(model_path),
-            output_var_names=['fct/output:0'])
-    predfunc = get_predict_func(cfg)
+def play_one_episode(player, func, verbose=False):
     tot_reward = 0
     que = deque(maxlen=30)
     while True:
         s = player.current_state()
-        outputs = predfunc([[s]])
+        outputs = func([[s]])
         action_value = outputs[0][0]
         act = action_value.argmax()
-        print action_value, act
+        if verbose:
+            print action_value, act
         if random.random() < 0.01:
             act = random.choice(range(player.driver.get_num_actions()))
         if len(que) == que.maxlen \
                 and que.count(que[0]) == que.maxlen:
-            act = 1
+            act = 1  # hack, avoid stuck
         que.append(act)
-        print(act)
+        if verbose:
+            print(act)
         reward, isOver = player.action(act)
         tot_reward += reward
         if isOver:
-            print("Total:", tot_reward)
-            tot_reward = 0
+            return tot_reward
+
+def play_model(model_path, romfile):
+    player = AtariPlayer(AtariDriver(romfile, viz=0.01),
+            action_repeat=ACTION_REPEAT)
+    global NUM_ACTIONS
+    NUM_ACTIONS = player.driver.get_num_actions()
+    M = Model()
+    cfg = PredictConfig(
+            model=M,
+            input_data_mapping=[0],
+            session_init=SaverRestore(model_path),
+            output_var_names=['fct/output:0'])
+    predfunc = get_predict_func(cfg)
+    while True:
+        score = play_one_episode(player, predfunc)
+        print("Total:", score)

 def eval_model_multiprocess(model_path, romfile):
     M = Model()
@@ -192,29 +201,10 @@ def eval_model_multiprocess(model_path, romfile):
                     action_repeat=ACTION_REPEAT)
             global NUM_ACTIONS
             NUM_ACTIONS = player.driver.get_num_actions()
             self._init_runtime()
-            tot_reward = 0
-            que = deque(maxlen=30)
             while True:
-                s = player.current_state()
-                outputs = self.func([[s]])
-                action_value = outputs[0][0]
-                act = action_value.argmax()
-                #print action_value, act
-                if random.random() < 0.01:
-                    act = random.choice(range(player.driver.get_num_actions()))
-                if len(que) == que.maxlen \
-                        and que.count(que[0]) == que.maxlen:
-                    act = 1
-                que.append(act)
-                #print(act)
-                reward, isOver = player.action(act)
-                tot_reward += reward
-                if isOver:
-                    self.outq.put(tot_reward)
-                    tot_reward = 0
+                score = play_one_episode(player, self.func)
+                self.outq.put(score)

     NR_PROC = min(multiprocessing.cpu_count() // 2, 10)
     procs = []
@@ -226,13 +216,19 @@ def eval_model_multiprocess(model_path, romfile):
         k.start()
     stat = StatCounter()
     try:
-        EVAL_EPISODE = 50
         for _ in tqdm(range(EVAL_EPISODE)):
             r = q.get()
             stat.feed(r)
     finally:
-        logger.info("Average Score: {}. Max Score: {}".format(
-            stat.average, stat.max))
+        for p in procs:
+            p.terminate()
+            p.join()
+        if stat.count() > 0:
+            logger.info("Average Score: {}; Max Score: {}".format(
+                stat.average, stat.max))
+            return (stat.average, stat.max)
+        else:
+            return (0, 0)

 def get_config(romfile):
@@ -260,6 +256,18 @@ def get_config(romfile):
     lr = tf.Variable(0.0025, trainable=False, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)

+    class Evaluator(Callback):
+        def _trigger_epoch(self):
+            logger.info("Evaluating...")
+            output = subprocess.check_output(
+                """{} --task eval --rom {} --load {} 2>&1 | grep Average""".format(
+                    sys.argv[0], romfile, os.path.join(logger.LOG_DIR, 'checkpoint')), shell=True)
+            output = output.strip()
+            output = output[output.find(']')+1:]
+            mean, maximum = re.findall('[0-9\.]+', output)
+            self.trainer.write_scalar_summary('mean_score', mean)
+            self.trainer.write_scalar_summary('max_score', maximum)
+
     return TrainConfig(
         dataset=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
@@ -269,11 +277,12 @@ def get_config(romfile):
             HumanHyperParamSetter('learning_rate', 'hyper.txt'),
             HumanHyperParamSetter((dataset_train, 'exploration'), 'hyper.txt'),
             TargetNetworkUpdator(M),
-            ExpReplayController(dataset_train)
+            ExpReplayController(dataset_train),
+            PeriodicCallback(Evaluator(), 1),
         ]),
         session_config=get_default_sess_config(0.5),
         model=M,
-        step_per_epoch=10000,
+        step_per_epoch=STEP_PER_EPOCH,
         max_epoch=10000,
     )
......
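The new Evaluator callback above re-runs this same script with --task eval through subprocess and recovers the scores from the "Average Score ...; Max Score ..." line logged by eval_model_multiprocess. A minimal, standalone sketch of that parsing step, assuming the bracketed log prefix in the sample line below (the timestamp and file name are made up):

import re

# Hypothetical logger output from an eval run; only the
# "Average Score: ...; Max Score: ..." message format comes from the code above.
line = "[0321 14:05:10 train-atari.py:220] Average Score: 13.7; Max Score: 41.0"
line = line.strip()
line = line[line.find(']') + 1:]               # drop the bracketed logger prefix
mean, maximum = re.findall('[0-9\.]+', line)   # -> '13.7', '41.0'

The two values come back as strings, and that is what the callback then feeds to write_scalar_summary.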
@@ -59,12 +59,6 @@ class Inferencer(object):
     def _get_output_tensors(self):
         pass

-    def _scalar_summary(self, name, val):
-        self.trainer.summary_writer.add_summary(
-            create_summary(name, val),
-            get_global_step())
-        self.trainer.stat_holder.add_stat(name, val)

 class InferenceRunner(Callback):
     """
     A callback that runs different kinds of inferencer.
@@ -161,9 +155,7 @@ class ScalarStats(Inferencer):
         for stat, name in zip(self.stats, self.names):
             opname, _ = get_op_var_name(name)
             name = '{}_{}'.format(self.prefix, opname) if self.prefix else opname
-            self.trainer.summary_writer.add_summary(
-                create_summary(name, stat), get_global_step())
-            self.trainer.stat_holder.add_stat(name, stat)
+            self.trainer.write_scalar_summary(name, stat)

 class ClassificationError(Inferencer):
     """
@@ -197,7 +189,7 @@ class ClassificationError(Inferencer):
         self.err_stat.feed(wrong, batch_size)

     def _after_inference(self):
-        self._scalar_summary(self.summary_name, self.err_stat.accuracy)
+        self.trainer.write_scalar_summary(self.summary_name, self.err_stat.accuracy)

 class BinaryClassificationStats(Inferencer):
@@ -221,5 +213,5 @@ class BinaryClassificationStats(Inferencer):
         self.stat.feed(pred, label)

     def _after_inference(self):
-        self._scalar_summary(self.prefix + '_precision', self.stat.precision)
-        self._scalar_summary(self.prefix + '_recall', self.stat.recall)
+        self.trainer.write_scalar_summary(self.prefix + '_precision', self.stat.precision)
+        self.trainer.write_scalar_summary(self.prefix + '_recall', self.stat.recall)
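With the private _scalar_summary helper removed from Inferencer, subclasses report statistics through the trainer instead. A one-line sketch of the change as seen from a custom inferencer's _after_inference (the stat name and value holder are placeholders, not part of this commit):

    def _after_inference(self):
        # before this commit: self._scalar_summary('my_metric', self.stat.average)
        self.trainer.write_scalar_summary('my_metric', self.stat.average)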
@@ -58,8 +58,7 @@ class PredictConfig(object):
         """
         def assert_type(v, tp):
             assert isinstance(v, tp), v.__class__
-        self.session_config = kwargs.pop('session_config', get_default_sess_config())
-        assert_type(self.session_config, tf.ConfigProto)
+        self.session_config = kwargs.pop('session_config', None)
         self.session_init = kwargs.pop('session_init')
         self.model = kwargs.pop('model')
         self.input_data_mapping = kwargs.pop('input_data_mapping', None)
@@ -87,7 +86,10 @@ def get_predict_func(config):
     output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
                    for n in output_var_names]

-    sess = tf.Session(config=config.session_config)
+    if config.session_config:
+        sess = tf.Session(config=config.session_config)
+    else:
+        sess = tf.Session()
     config.session_init.init(sess)

     def run_input(dp):
@@ -116,7 +118,7 @@ class ParallelPredictWorker(multiprocessing.Process):
             os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid
         else:
             logger.info("Worker {} uses CPU".format(self.idx))
-            os.environ['CUDA_VISIBLE_DEVICES'] = ''
+            os.environ['CUDA_VISIBLE_DEVICES'] = '0'
         G = tf.Graph() # build a graph for each process, because they don't need to share anything
         with G.as_default(), tf.device('/gpu:0' if self.gpuid >= 0 else '/cpu:0'):
             if self.idx != 0:
......
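Since session_config is now optional in PredictConfig, a prediction function for evaluation can be built without constructing a tf.ConfigProto; get_predict_func falls back to a plain tf.Session(). A minimal sketch along the lines of play_model above (the checkpoint path is hypothetical):

cfg = PredictConfig(
    model=Model(),
    input_data_mapping=[0],
    session_init=SaverRestore('train_log/checkpoint'),  # hypothetical path
    output_var_names=['fct/output:0'])
predfunc = get_predict_func(cfg)
q_values = predfunc([[state]])[0][0]   # 'state' is a preprocessed frame stack from the player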
@@ -12,6 +12,7 @@ from .config import TrainConfig
 from ..utils import *
 from ..callbacks import StatHolder
 from ..tfutils import *
+from ..tfutils.summary import create_summary
 from ..tfutils.modelutils import describe_model

 __all__ = ['Trainer']
@@ -76,6 +77,12 @@ class Trainer(object):
                 self.stat_holder.add_stat(val.tag, val.simple_value)
         self.summary_writer.add_summary(summary, self.global_step)

+    def write_scalar_summary(self, name, val):
+        self.summary_writer.add_summary(
+            create_summary(name, val),
+            get_global_step())
+        self.stat_holder.add_stat(name, val)
+
     def main_loop(self):
         # some final operations that might modify the graph
         logger.info("Preparing for training...")
......