Commit 943b1701 authored by Yuxin Wu

separate common from DQN

parent 76cbc245
......@@ -13,20 +13,16 @@ import subprocess
import multiprocessing, threading
from collections import deque
from six.moves import queue
from tqdm import tqdm
from tensorpack import *
from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.utils.concurrency import (ensure_proc_terminate, \
subproc_call, StoppableThread)
from tensorpack.utils.stat import *
from tensorpack.predict import PredictConfig, get_predict_func, MultiProcessPredictWorker
from tensorpack.utils.concurrency import *
from tensorpack.tfutils import symbolic_functions as symbf
from tensorpack.callbacks import *
from tensorpack.RL import *
import common
from common import play_model, Evaluator, eval_model_multithread
BATCH_SIZE = 32
IMAGE_SIZE = (84, 84)
......@@ -56,16 +52,18 @@ def get_player(viz=False, train=False):
frame_skip=ACTION_REPEAT, image_shape=IMAGE_SIZE[::-1], viz=viz,
live_lost_as_eoe=train)
global NUM_ACTIONS
NUM_ACTIONS = pl.get_num_actions()
NUM_ACTIONS = pl.get_action_space().num_actions()
if not train:
pl = HistoryFramePlayer(pl, FRAME_HISTORY)
pl = PreventStuckPlayer(pl, 30, 1)
pl = LimitLengthPlayer(pl, 20000)
return pl
common.get_player = get_player # so that eval functions in common can use the player
class Model(ModelDesc):
def _get_input_vars(self):
assert NUM_ACTIONS is not None
if NUM_ACTIONS is None:
p = get_player(); del p
return [InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
InputVar(tf.int64, (None,), 'action'),
InputVar(tf.float32, (None,), 'reward'),
......@@ -141,85 +139,6 @@ class Model(ModelDesc):
def predictor(self, state):
return self.predict_value.eval(feed_dict={'state:0': [state]})[0]
def play_one_episode(player, func, verbose=False):
def f(s):
act = func([[s]])[0][0].argmax()
if random.random() < 0.01:
act = random.choice(range(NUM_ACTIONS))
if verbose:
print(act)
return act
return np.mean(player.play_one_episode(f))
def play_model(model_path):
player = get_player(viz=0.01)
cfg = PredictConfig(
model=Model(),
input_data_mapping=[0],
session_init=SaverRestore(model_path),
output_var_names=['fct/output:0'])
predfunc = get_predict_func(cfg)
while True:
score = play_one_episode(player, predfunc)
print("Total:", score)
def eval_with_funcs(predict_funcs, nr_eval=EVAL_EPISODE):
class Worker(StoppableThread):
def __init__(self, func, queue):
super(Worker, self).__init__()
self.func = func
self.q = queue
def run(self):
player = get_player()
while not self.stopped():
score = play_one_episode(player, self.func)
self.queue_put_stoppable(self.q, score)
q = queue.Queue(maxsize=2)
threads = [Worker(f, q) for f in predict_funcs]
for k in threads:
k.start()
time.sleep(0.1) # avoid simulator bugs
stat = StatCounter()
try:
for _ in tqdm(range(nr_eval)):
r = q.get()
stat.feed(r)
finally:
logger.info("Waiting for all the workers to finish the last run...")
for k in threads: k.stop()
for k in threads: k.join()
return (stat.average, stat.max)
def eval_model_multithread(model_path):
cfg = PredictConfig(
model=Model(),
input_data_mapping=[0],
session_init=SaverRestore(model_path),
output_var_names=['fct/output:0'])
p = get_player(); del p # set NUM_ACTIONS
func = get_predict_func(cfg)
NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
mean, max = eval_with_funcs([func] * NR_PROC)
logger.info("Average Score: {}; Max Score: {}".format(mean, max))
class Evaluator(Callback):
def _before_train(self):
NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
self.pred_funcs = [self.trainer.get_predict_func(
['state'], ['fct/output'])] * NR_PROC
self.eval_episode = EVAL_EPISODE
def _trigger_epoch(self):
t = time.time()
mean, max = eval_with_funcs(self.pred_funcs, nr_eval=self.eval_episode)
t = time.time() - t
if t > 8 * 60: # eval takes too long
self.eval_episode = int(self.eval_episode * 0.89)
self.trainer.write_scalar_summary('mean_score', mean)
self.trainer.write_scalar_summary('max_score', max)
def get_config():
basename = os.path.basename(__file__)
logger.set_logger_dir(
......@@ -229,10 +148,9 @@ def get_config():
dataset_train = ExpReplay(
predictor=M.predictor,
player=get_player(train=True),
num_actions=NUM_ACTIONS,
memory_size=MEMORY_SIZE,
batch_size=BATCH_SIZE,
populate_size=INIT_MEMORY_SIZE,
memory_size=MEMORY_SIZE,
init_memory_size=INIT_MEMORY_SIZE,
exploration=INIT_EXPLORATION,
end_exploration=END_EXPLORATION,
exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
......@@ -253,7 +171,7 @@ def get_config():
HumanHyperParamSetter(ObjAttrParam(dataset_train, 'exploration'), 'hyper.txt'),
RunOp(lambda: M.update_target_param()),
dataset_train,
PeriodicCallback(Evaluator(), 2),
PeriodicCallback(Evaluator(EVAL_EPISODE), 2),
]),
# save memory for multiprocess evaluator
session_config=get_default_sess_config(0.3),
......@@ -272,20 +190,15 @@ if __name__ == '__main__':
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
if args.task != 'train':
assert args.load is not None
ROM_FILE = args.rom
if args.task == 'play':
play_model(args.load)
sys.exit()
if args.task == 'eval':
eval_model_multithread(args.load)
sys.exit()
with tf.Graph().as_default():
play_model(Model(), args.load)
elif args.task == 'eval':
eval_model_multithread(Model(), args.load, EVAL_EPISODE)
else:
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
......
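A minimal sketch (assuming this repo's layout) of the injection pattern introduced above: the DQN script hands its game-specific get_player to the shared common module, so the relocated play/eval helpers can construct players without importing DQN.py.

import common
from common import play_model, Evaluator, eval_model_multithread

def get_player(viz=False, train=False):
    # build the AtariPlayer and its wrappers exactly as in DQN.py above
    ...

common.get_player = get_player  # common's eval helpers call this factory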
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import random, time
import threading, multiprocessing
import numpy as np
from tqdm import tqdm
from six.moves import queue
from tensorpack import *
from tensorpack.predict import PredictConfig, get_predict_func, MultiProcessPredictWorker
from tensorpack.utils.concurrency import *
from tensorpack.utils.stat import *
from tensorpack.callbacks import *
global get_player
def play_one_episode(player, func, verbose=False):
# 0.01-greedy evaluation
def f(s):
spc = player.get_action_space()
act = func([[s]])[0][0].argmax()
if random.random() < 0.01:
act = spc.sample()
if verbose:
print(act)
return act
return np.mean(player.play_one_episode(f))
def play_model(M, model_path):
player = get_player(viz=0.01)
cfg = PredictConfig(
model=M,
input_data_mapping=[0],
session_init=SaverRestore(model_path),
output_var_names=['fct/output:0'])
predfunc = get_predict_func(cfg)
while True:
score = play_one_episode(player, predfunc)
print("Total:", score)
def eval_with_funcs(predict_funcs, nr_eval):
class Worker(StoppableThread):
def __init__(self, func, queue):
super(Worker, self).__init__()
self.func = func
self.q = queue
def run(self):
player = get_player()
while not self.stopped():
score = play_one_episode(player, self.func)
self.queue_put_stoppable(self.q, score)
q = queue.Queue(maxsize=2)
threads = [Worker(f, q) for f in predict_funcs]
for k in threads:
k.start()
time.sleep(0.1) # avoid simulator bugs
stat = StatCounter()
try:
for _ in tqdm(range(nr_eval)):
r = q.get()
stat.feed(r)
except:
logger.exception("Eval")
finally:
logger.info("Waiting for all the workers to finish the last run...")
for k in threads: k.stop()
for k in threads: k.join()
if stat.count > 0:
return (stat.average, stat.max)
return (0, 0)
def eval_model_multithread(M, model_path, nr_eval):
cfg = PredictConfig(
model=M,
input_data_mapping=[0],
session_init=SaverRestore(model_path),
output_var_names=['fct/output:0'])
func = get_predict_func(cfg)
NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
mean, max = eval_with_funcs([func] * NR_PROC, nr_eval)
logger.info("Average Score: {}; Max Score: {}".format(mean, max))
class Evaluator(Callback):
def __init__(self, nr_eval):
self.eval_episode = nr_eval
def _before_train(self):
NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
self.pred_funcs = [self.trainer.get_predict_func(
['state'], ['fct/output'])] * NR_PROC
def _trigger_epoch(self):
t = time.time()
mean, max = eval_with_funcs(self.pred_funcs, nr_eval=self.eval_episode)
t = time.time() - t
if t > 8 * 60: # eval takes too long
self.eval_episode = int(self.eval_episode * 0.89)
self.trainer.write_scalar_summary('mean_score', mean)
self.trainer.write_scalar_summary('max_score', max)
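Usage sketch for the relocated helpers (names taken from the DQN.py diff above; the checkpoint path is a placeholder): Evaluator now receives the episode count in its constructor instead of reading a module-level constant, and the offline entry points take the model explicitly.

# during training: evaluate every 2 epochs
PeriodicCallback(Evaluator(EVAL_EPISODE), 2)

# offline, from a saved checkpoint:
eval_model_multithread(Model(), '/path/to/checkpoint', EVAL_EPISODE)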
......@@ -52,7 +52,6 @@ class Model(ModelDesc):
l = tf.nn.dropout(l, keep_prob)
l = FullyConnected('fc1', l, 512,
b_init=tf.constant_initializer(0.1))
# fc will have activation summary by default. disable for the output layer
logits = FullyConnected('linear', l, out_dim=self.cifar_classnum, nl=tf.identity)
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
......
......@@ -44,7 +44,6 @@ class Model(ModelDesc):
l = tf.nn.dropout(l, keep_prob)
l = FullyConnected('fc0', l, 512,
b_init=tf.constant_initializer(0.1))
# fc will have activation summary by default. disable for the output layer
logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity)
prob = tf.nn.softmax(logits, name='output')
......
......@@ -12,7 +12,7 @@ from six.moves import range
from ..utils import get_rng, logger, memoized
from ..utils.stat import StatCounter
from .envbase import RLEnvironment
from .envbase import RLEnvironment, DiscreteActionSpace
try:
from ale_python_interface import ALEInterface
......@@ -104,6 +104,7 @@ class AtariPlayer(RLEnvironment):
ret = np.maximum(ret, self.last_raw_screen)
if self.viz:
if isinstance(self.viz, float):
#m = cv2.resize(ret, (1920,1200))
cv2.imshow(self.windowname, ret)
time.sleep(self.viz)
ret = ret[self.height_range[0]:self.height_range[1],:]
......@@ -113,11 +114,8 @@ class AtariPlayer(RLEnvironment):
ret = np.expand_dims(ret, axis=2)
return ret
def get_num_actions(self):
"""
:returns: the number of legal actions
"""
return len(self.actions)
def get_action_space(self):
return DiscreteActionSpace(len(self.actions))
def restart_episode(self):
if self.current_episode_score.count > 0:
......@@ -170,7 +168,7 @@ if __name__ == '__main__':
def benchmark():
a = AtariPlayer(sys.argv[1], viz=False, height_range=(28,-8))
num = a.get_num_actions()
num = a.get_action_space().num_actions()
rng = get_rng(num)
start = time.time()
cnt = 0
......@@ -194,7 +192,7 @@ if __name__ == '__main__':
else:
a = AtariPlayer(sys.argv[1],
viz=0.03, height_range=(28,-8))
num = a.get_num_actions()
num = a.get_action_space().num_actions()
rng = get_rng(num)
import time
while True:
......
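Migration sketch for existing AtariPlayer callers, using only names from the benchmark code above (the ROM path is a placeholder):

a = AtariPlayer('breakout.bin', viz=False, height_range=(28, -8))
num = a.get_action_space().num_actions()   # replaces a.get_num_actions()
rng = get_rng(num)                         # unchanged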
......@@ -6,8 +6,11 @@
from abc import abstractmethod, ABCMeta
from collections import defaultdict
import random
from ..utils import get_rng
__all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer']
__all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer',
'DiscreteActionSpace']
class RLEnvironment(object):
__meta__ = ABCMeta
......@@ -33,6 +36,10 @@ class RLEnvironment(object):
""" Start a new episode, even if the current hasn't ended """
raise NotImplementedError()
def get_action_space(self):
""" return an `ActionSpace` instance"""
raise NotImplementedError()
def get_stat(self):
"""
return a dict of statistics (e.g., score) for all the episodes since last call to reset_stat
......@@ -40,7 +47,7 @@ class RLEnvironment(object):
return {}
def reset_stat(self):
""" reset the statistics counter"""
""" reset all statistics counter"""
self.stats = defaultdict(list)
def play_one_episode(self, func, stat='score'):
......@@ -57,6 +64,28 @@ class RLEnvironment(object):
self.reset_stat()
return s
class ActionSpace(object):
def __init__(self):
self.rng = get_rng(self)
@abstractmethod
def sample(self):
pass
def num_actions(self):
raise NotImplementedError()
class DiscreteActionSpace(ActionSpace):
def __init__(self, num):
super(DiscreteActionSpace, self).__init__()
self.num = num
def sample(self):
return self.rng.randint(self.num)
def num_actions(self):
return self.num
class NaiveRLEnvironment(RLEnvironment):
""" for testing only"""
def __init__(self):
......@@ -67,8 +96,6 @@ class NaiveRLEnvironment(RLEnvironment):
def action(self, act):
self.k = act
return (self.k, self.k > 10)
def restart_episode(self):
pass
class ProxyPlayer(RLEnvironment):
""" Serve as a proxy another player """
......@@ -93,3 +120,6 @@ class ProxyPlayer(RLEnvironment):
def restart_episode(self):
self.player.restart_episode()
def get_action_space(self):
return self.player.get_action_space()
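A minimal sketch of the new DiscreteActionSpace API on its own, matching the 0.01-greedy evaluation in common.play_one_episode; ProxyPlayer forwards get_action_space() to the wrapped player, so wrappers keep exposing the underlying space.

space = DiscreteActionSpace(6)   # e.g. six legal Atari actions
n = space.num_actions()          # replaces player.get_num_actions()
act = space.sample()             # uniform random action in [0, n), drawn from the internal rng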
......@@ -25,10 +25,10 @@ class ExpReplay(DataFlow, Callback):
def __init__(self,
predictor,
player,
num_actions,
memory_size=1e6,
batch_size=32,
populate_size=50000,
memory_size=1e6,
populate_size=None, # deprecated
init_memory_size=50000,
exploration=1,
end_exploration=0.1,
exploration_epoch_anneal=0.002,
......@@ -37,20 +37,27 @@ class ExpReplay(DataFlow, Callback):
history_len=1
):
"""
:param predictor: a callable calling the up-to-date network.
called with a state, return a distribution
:param player: a `RLEnvironment`
:param num_actions: int
:param predictor: a callable running the up-to-date network.
called with a state, return a distribution.
:param player: an `RLEnvironment`
:param history_len: length of history frames to concat. zero-filled initial frames
:param update_frequency: number of new transitions to add to memory
after sampling a batch of transitions for training
"""
# XXX back-compat
if populate_size is not None:
logger.warn("populate_size in ExpReplay is deprecated in favor of init_memory_size")
init_memory_size = populate_size
for k, v in locals().items():
if k != 'self':
setattr(self, k, v)
self.num_actions = player.get_action_space().num_actions()
logger.info("Number of Legal actions: {}".format(self.num_actions))
self.mem = deque(maxlen=memory_size)
self.rng = get_rng(self)
def init_memory(self):
def _init_memory(self):
logger.info("Populating replay memory...")
# fill some for the history
......@@ -60,8 +67,8 @@ class ExpReplay(DataFlow, Callback):
self._populate_exp()
self.exploration = old_exploration
with tqdm(total=self.populate_size) as pbar:
while len(self.mem) < self.populate_size:
with tqdm(total=self.init_memory_size) as pbar:
while len(self.mem) < self.init_memory_size:
self._populate_exp()
pbar.update()
......@@ -96,7 +103,7 @@ class ExpReplay(DataFlow, Callback):
def get_data(self):
# new s is considered useless if isOver==True
while True:
batch_exp = [self.sample_one() for _ in range(self.batch_size)]
batch_exp = [self._sample_one() for _ in range(self.batch_size)]
#import cv2
#def view_state(state, next_state):
......@@ -116,7 +123,7 @@ class ExpReplay(DataFlow, Callback):
for _ in range(self.update_frequency):
self._populate_exp()
def sample_one(self):
def _sample_one(self):
""" return the transition tuple for
[idx, idx+history_len] -> [idx+1, idx+1+history_len]
it's the transition from state idx+history_len-1 to state idx+history_len
......@@ -155,14 +162,14 @@ class ExpReplay(DataFlow, Callback):
return [state, action, reward, next_state, isOver]
# Callback-related:
def _before_train(self):
self.init_memory()
self._init_memory()
def _trigger_epoch(self):
if self.exploration > self.end_exploration:
self.exploration -= self.exploration_epoch_anneal
logger.info("Exploration changed to {}".format(self.exploration))
# log player statistics
stats = self.player.get_stat()
for k, v in six.iteritems(stats):
if isinstance(v, float):
......@@ -177,10 +184,10 @@ if __name__ == '__main__':
player = AtariPlayer(sys.argv[1], viz=0, frame_skip=10, height_range=(36, 204))
E = ExpReplay(predictor,
player=player,
num_actions=player.get_num_actions(),
num_actions=player.get_action_space().num_actions(),
populate_size=1001,
history_len=4)
E.init_memory()
E._init_memory()
for k in E.get_data():
import IPython as IP;
......
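A hedged construction sketch for the updated ExpReplay, assuming predictor and player are defined as in the __main__ block above: num_actions is now derived from the player's action space inside __init__, and populate_size survives only as a deprecated alias for init_memory_size.

# old call sites keep working but log a deprecation warning:
E = ExpReplay(predictor, player=player, populate_size=1001, history_len=4)

# preferred spelling after this commit:
E = ExpReplay(predictor, player=player, init_memory_size=1001, history_len=4)
E._init_memory()   # was E.init_memory(); normally triggered by _before_train()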