Commit 943b1701 authored by Yuxin Wu

separate common from DQN

parent 76cbc245
@@ -13,20 +13,16 @@ import subprocess
import multiprocessing, threading
from collections import deque
-from six.moves import queue
-from tqdm import tqdm

from tensorpack import *
from tensorpack.models import *
from tensorpack.utils import *
-from tensorpack.utils.concurrency import (ensure_proc_terminate, \
-        subproc_call, StoppableThread)
-from tensorpack.utils.stat import *
-from tensorpack.predict import PredictConfig, get_predict_func, MultiProcessPredictWorker
+from tensorpack.utils.concurrency import *
from tensorpack.tfutils import symbolic_functions as symbf
from tensorpack.callbacks import *
from tensorpack.RL import *
+import common
+from common import play_model, Evaluator, eval_model_multithread

BATCH_SIZE = 32
IMAGE_SIZE = (84, 84)
@@ -56,16 +52,18 @@ def get_player(viz=False, train=False):
            frame_skip=ACTION_REPEAT, image_shape=IMAGE_SIZE[::-1], viz=viz,
            live_lost_as_eoe=train)
    global NUM_ACTIONS
-    NUM_ACTIONS = pl.get_num_actions()
+    NUM_ACTIONS = pl.get_action_space().num_actions()
    if not train:
        pl = HistoryFramePlayer(pl, FRAME_HISTORY)
        pl = PreventStuckPlayer(pl, 30, 1)
    pl = LimitLengthPlayer(pl, 20000)
    return pl
+common.get_player = get_player # so that eval functions in common can use the player

class Model(ModelDesc):
    def _get_input_vars(self):
-        assert NUM_ACTIONS is not None
+        if NUM_ACTIONS is None:
+            p = get_player(); del p
        return [InputVar(tf.float32, (None,) + IMAGE_SHAPE3, 'state'),
                InputVar(tf.int64, (None,), 'action'),
                InputVar(tf.float32, (None,), 'reward'),
@@ -141,85 +139,6 @@ class Model(ModelDesc):
    def predictor(self, state):
        return self.predict_value.eval(feed_dict={'state:0': [state]})[0]

-def play_one_episode(player, func, verbose=False):
-    def f(s):
-        act = func([[s]])[0][0].argmax()
-        if random.random() < 0.01:
-            act = random.choice(range(NUM_ACTIONS))
-        if verbose:
-            print(act)
-        return act
-    return np.mean(player.play_one_episode(f))
-
-def play_model(model_path):
-    player = get_player(viz=0.01)
-    cfg = PredictConfig(
-            model=Model(),
-            input_data_mapping=[0],
-            session_init=SaverRestore(model_path),
-            output_var_names=['fct/output:0'])
-    predfunc = get_predict_func(cfg)
-    while True:
-        score = play_one_episode(player, predfunc)
-        print("Total:", score)
-
-def eval_with_funcs(predict_funcs, nr_eval=EVAL_EPISODE):
-    class Worker(StoppableThread):
-        def __init__(self, func, queue):
-            super(Worker, self).__init__()
-            self.func = func
-            self.q = queue
-        def run(self):
-            player = get_player()
-            while not self.stopped():
-                score = play_one_episode(player, self.func)
-                self.queue_put_stoppable(self.q, score)
-
-    q = queue.Queue(maxsize=2)
-    threads = [Worker(f, q) for f in predict_funcs]
-    for k in threads:
-        k.start()
-        time.sleep(0.1) # avoid simulator bugs
-    stat = StatCounter()
-    try:
-        for _ in tqdm(range(nr_eval)):
-            r = q.get()
-            stat.feed(r)
-    finally:
-        logger.info("Waiting for all the workers to finish the last run...")
-        for k in threads: k.stop()
-        for k in threads: k.join()
-    return (stat.average, stat.max)
-
-def eval_model_multithread(model_path):
-    cfg = PredictConfig(
-            model=Model(),
-            input_data_mapping=[0],
-            session_init=SaverRestore(model_path),
-            output_var_names=['fct/output:0'])
-    p = get_player(); del p # set NUM_ACTIONS
-    func = get_predict_func(cfg)
-    NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
-    mean, max = eval_with_funcs([func] * NR_PROC)
-    logger.info("Average Score: {}; Max Score: {}".format(mean, max))
-
-class Evaluator(Callback):
-    def _before_train(self):
-        NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
-        self.pred_funcs = [self.trainer.get_predict_func(
-            ['state'], ['fct/output'])] * NR_PROC
-        self.eval_episode = EVAL_EPISODE
-
-    def _trigger_epoch(self):
-        t = time.time()
-        mean, max = eval_with_funcs(self.pred_funcs, nr_eval=self.eval_episode)
-        t = time.time() - t
-        if t > 8 * 60: # eval takes too long
-            self.eval_episode = int(self.eval_episode * 0.89)
-        self.trainer.write_scalar_summary('mean_score', mean)
-        self.trainer.write_scalar_summary('max_score', max)

def get_config():
    basename = os.path.basename(__file__)
    logger.set_logger_dir(
@@ -229,10 +148,9 @@ def get_config():
    dataset_train = ExpReplay(
            predictor=M.predictor,
            player=get_player(train=True),
-            num_actions=NUM_ACTIONS,
-            memory_size=MEMORY_SIZE,
            batch_size=BATCH_SIZE,
-            populate_size=INIT_MEMORY_SIZE,
+            memory_size=MEMORY_SIZE,
+            init_memory_size=INIT_MEMORY_SIZE,
            exploration=INIT_EXPLORATION,
            end_exploration=END_EXPLORATION,
            exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
@@ -253,7 +171,7 @@ def get_config():
            HumanHyperParamSetter(ObjAttrParam(dataset_train, 'exploration'), 'hyper.txt'),
            RunOp(lambda: M.update_target_param()),
            dataset_train,
-            PeriodicCallback(Evaluator(), 2),
+            PeriodicCallback(Evaluator(EVAL_EPISODE), 2),
        ]),
        # save memory for multiprocess evaluator
        session_config=get_default_sess_config(0.3),
@@ -272,20 +190,15 @@ if __name__ == '__main__':
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    if args.task != 'train':
        assert args.load is not None
    ROM_FILE = args.rom

    if args.task == 'play':
-        play_model(args.load)
-        sys.exit()
-    if args.task == 'eval':
-        eval_model_multithread(args.load)
-        sys.exit()
-
-    with tf.Graph().as_default():
+        play_model(Model(), args.load)
+    elif args.task == 'eval':
+        eval_model_multithread(Model(), args.load, EVAL_EPISODE)
+    else:
        config = get_config()
        if args.load:
            config.session_init = SaverRestore(args.load)
...
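For readers skimming the DQN diff: NUM_ACTIONS is now initialized lazily (the model builds and discards a throwaway player when needed), and common.get_player = get_player hands the player factory to the shared eval helpers. A minimal, self-contained sketch of that wiring; the dummy classes below stand in for AtariPlayer and its action space and are assumptions, not tensorpack code:

NUM_ACTIONS = None

class _DummyActionSpace:
    # stands in for the DiscreteActionSpace introduced later in this commit
    def num_actions(self):
        return 6

class _DummyPlayer:
    # stands in for the AtariPlayer chain built by the real get_player()
    def get_action_space(self):
        return _DummyActionSpace()

def get_player(viz=False, train=False):
    global NUM_ACTIONS
    pl = _DummyPlayer()
    NUM_ACTIONS = pl.get_action_space().num_actions()   # same side effect as the real factory
    return pl

def num_actions_for_model():
    # mirrors Model._get_input_vars(): build and discard a player if NUM_ACTIONS is unset
    if NUM_ACTIONS is None:
        p = get_player(); del p
    return NUM_ACTIONS

print(num_actions_for_model())   # -> 6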
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# File: common.py
+# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
+
+import random, time
+import threading, multiprocessing
+import numpy as np
+from tqdm import tqdm
+from six.moves import queue
+
+from tensorpack import *
+from tensorpack.predict import PredictConfig, get_predict_func, MultiProcessPredictWorker
+from tensorpack.utils.concurrency import *
+from tensorpack.utils.stat import *
+from tensorpack.callbacks import *
+
+global get_player
+
+def play_one_episode(player, func, verbose=False):
+    # 0.01-greedy evaluation
+    def f(s):
+        spc = player.get_action_space()
+        act = func([[s]])[0][0].argmax()
+        if random.random() < 0.01:
+            act = spc.sample()
+        if verbose:
+            print(act)
+        return act
+    return np.mean(player.play_one_episode(f))
+
+def play_model(M, model_path):
+    player = get_player(viz=0.01)
+    cfg = PredictConfig(
+            model=M,
+            input_data_mapping=[0],
+            session_init=SaverRestore(model_path),
+            output_var_names=['fct/output:0'])
+    predfunc = get_predict_func(cfg)
+    while True:
+        score = play_one_episode(player, predfunc)
+        print("Total:", score)
+
+def eval_with_funcs(predict_funcs, nr_eval):
+    class Worker(StoppableThread):
+        def __init__(self, func, queue):
+            super(Worker, self).__init__()
+            self.func = func
+            self.q = queue
+        def run(self):
+            player = get_player()
+            while not self.stopped():
+                score = play_one_episode(player, self.func)
+                self.queue_put_stoppable(self.q, score)
+
+    q = queue.Queue(maxsize=2)
+    threads = [Worker(f, q) for f in predict_funcs]
+    for k in threads:
+        k.start()
+        time.sleep(0.1) # avoid simulator bugs
+    stat = StatCounter()
+    try:
+        for _ in tqdm(range(nr_eval)):
+            r = q.get()
+            stat.feed(r)
+    except:
+        logger.exception("Eval")
+    finally:
+        logger.info("Waiting for all the workers to finish the last run...")
+        for k in threads: k.stop()
+        for k in threads: k.join()
+    if stat.count > 0:
+        return (stat.average, stat.max)
+    return (0, 0)
+
+def eval_model_multithread(M, model_path, nr_eval):
+    cfg = PredictConfig(
+            model=M,
+            input_data_mapping=[0],
+            session_init=SaverRestore(model_path),
+            output_var_names=['fct/output:0'])
+    func = get_predict_func(cfg)
+    NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
+    mean, max = eval_with_funcs([func] * NR_PROC, nr_eval)
+    logger.info("Average Score: {}; Max Score: {}".format(mean, max))
+
+class Evaluator(Callback):
+    def __init__(self, nr_eval):
+        self.eval_episode = nr_eval
+
+    def _before_train(self):
+        NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
+        self.pred_funcs = [self.trainer.get_predict_func(
+            ['state'], ['fct/output'])] * NR_PROC
+
+    def _trigger_epoch(self):
+        t = time.time()
+        mean, max = eval_with_funcs(self.pred_funcs, nr_eval=self.eval_episode)
+        t = time.time() - t
+        if t > 8 * 60: # eval takes too long
+            self.eval_episode = int(self.eval_episode * 0.89)
+        self.trainer.write_scalar_summary('mean_score', mean)
+        self.trainer.write_scalar_summary('max_score', max)
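A hedged usage sketch of the new common.py entry points. MyModel, my_get_player, and the checkpoint path below are placeholders; the one real requirement, visible in the DQN diff above, is that the calling script assigns common.get_player before evaluation so the helpers can build players:

import common
from common import play_model, eval_model_multithread, Evaluator

common.get_player = my_get_player       # placeholder: a factory returning an RLEnvironment

# one-shot, multi-threaded evaluation of a saved checkpoint over 50 episodes
eval_model_multithread(MyModel(), 'train_log/DQN/checkpoint', 50)

# watch the agent play with the same checkpoint (loops forever)
play_model(MyModel(), 'train_log/DQN/checkpoint')

# during training, the same evaluation runs periodically via the callback:
#     PeriodicCallback(Evaluator(EVAL_EPISODE), 2)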
@@ -52,7 +52,6 @@ class Model(ModelDesc):
        l = tf.nn.dropout(l, keep_prob)
        l = FullyConnected('fc1', l, 512,
                           b_init=tf.constant_initializer(0.1))
-        # fc will have activation summary by default. disable for the output layer
        logits = FullyConnected('linear', l, out_dim=self.cifar_classnum, nl=tf.identity)
        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
...
@@ -44,7 +44,6 @@ class Model(ModelDesc):
        l = tf.nn.dropout(l, keep_prob)
        l = FullyConnected('fc0', l, 512,
                           b_init=tf.constant_initializer(0.1))
-        # fc will have activation summary by default. disable for the output layer
        logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity)
        prob = tf.nn.softmax(logits, name='output')
...
@@ -12,7 +12,7 @@ from six.moves import range

from ..utils import get_rng, logger, memoized
from ..utils.stat import StatCounter
-from .envbase import RLEnvironment
+from .envbase import RLEnvironment, DiscreteActionSpace

try:
    from ale_python_interface import ALEInterface
@@ -104,6 +104,7 @@ class AtariPlayer(RLEnvironment):
        ret = np.maximum(ret, self.last_raw_screen)
        if self.viz:
            if isinstance(self.viz, float):
+                #m = cv2.resize(ret, (1920,1200))
                cv2.imshow(self.windowname, ret)
                time.sleep(self.viz)
        ret = ret[self.height_range[0]:self.height_range[1],:]
@@ -113,11 +114,8 @@ class AtariPlayer(RLEnvironment):
        ret = np.expand_dims(ret, axis=2)
        return ret

-    def get_num_actions(self):
-        """
-        :returns: the number of legal actions
-        """
-        return len(self.actions)
+    def get_action_space(self):
+        return DiscreteActionSpace(len(self.actions))

    def restart_episode(self):
        if self.current_episode_score.count > 0:
@@ -170,7 +168,7 @@ if __name__ == '__main__':
    def benchmark():
        a = AtariPlayer(sys.argv[1], viz=False, height_range=(28,-8))
-        num = a.get_num_actions()
+        num = a.get_action_space().num_actions()
        rng = get_rng(num)
        start = time.time()
        cnt = 0
@@ -194,7 +192,7 @@ if __name__ == '__main__':
    else:
        a = AtariPlayer(sys.argv[1],
                viz=0.03, height_range=(28,-8))
-        num = a.get_num_actions()
+        num = a.get_action_space().num_actions()
        rng = get_rng(num)
        import time
        while True:
...
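Call sites migrate from get_num_actions() to the new action-space API, as the two "num = ..." changes above show. A short sketch of the new pattern; the ROM path is a placeholder, and AtariPlayer is assumed to be re-exported by tensorpack.RL, as the DQN example's "from tensorpack.RL import *" suggests:

from tensorpack.RL import AtariPlayer

player = AtariPlayer('breakout.bin', viz=0)     # placeholder ROM path
spc = player.get_action_space()                 # replaces player.get_num_actions()
num = spc.num_actions()                         # number of legal actions
act = spc.sample()                              # uniformly random legal action index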
@@ -6,8 +6,11 @@

from abc import abstractmethod, ABCMeta
from collections import defaultdict
+import random
+from ..utils import get_rng

-__all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer']
+__all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer',
+           'DiscreteActionSpace']

class RLEnvironment(object):
    __meta__ = ABCMeta
@@ -33,6 +36,10 @@ class RLEnvironment(object):
        """ Start a new episode, even if the current hasn't ended """
        raise NotImplementedError()

+    def get_action_space(self):
+        """ return an `ActionSpace` instance"""
+        raise NotImplementedError()
+
    def get_stat(self):
        """
        return a dict of statistics (e.g., score) for all the episodes since last call to reset_stat
@@ -40,7 +47,7 @@ class RLEnvironment(object):
        return {}

    def reset_stat(self):
-        """ reset the statistics counter"""
+        """ reset all statistics counter"""
        self.stats = defaultdict(list)

    def play_one_episode(self, func, stat='score'):
@@ -57,6 +64,28 @@ class RLEnvironment(object):
        self.reset_stat()
        return s

+class ActionSpace(object):
+    def __init__(self):
+        self.rng = get_rng(self)
+
+    @abstractmethod
+    def sample(self):
+        pass
+
+    def num_actions(self):
+        raise NotImplementedError()
+
+class DiscreteActionSpace(ActionSpace):
+    def __init__(self, num):
+        super(DiscreteActionSpace, self).__init__()
+        self.num = num
+
+    def sample(self):
+        return self.rng.randint(self.num)
+
+    def num_actions(self):
+        return self.num
+
class NaiveRLEnvironment(RLEnvironment):
    """ for testing only"""
    def __init__(self):
@@ -67,8 +96,6 @@ class NaiveRLEnvironment(RLEnvironment):
    def action(self, act):
        self.k = act
        return (self.k, self.k > 10)
-    def restart_episode(self):
-        pass

class ProxyPlayer(RLEnvironment):
    """ Serve as a proxy another player """
@@ -93,3 +120,6 @@ class ProxyPlayer(RLEnvironment):

    def restart_episode(self):
        self.player.restart_episode()
+
+    def get_action_space(self):
+        return self.player.get_action_space()
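The DiscreteActionSpace added above is the small API the rest of this commit builds on: it owns an RNG and exposes sample() and num_actions(). A minimal sketch; the import path tensorpack.RL.envbase is an assumption, inferred from the "from .envbase import ..." change in the AtariPlayer diff and the examples' use of tensorpack.RL:

from tensorpack.RL.envbase import DiscreteActionSpace

spc = DiscreteActionSpace(6)      # e.g. six legal Atari actions
print(spc.num_actions())          # -> 6
print(spc.sample())               # random index in [0, 6), drawn from the space's own rng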
@@ -25,10 +25,10 @@ class ExpReplay(DataFlow, Callback):
    def __init__(self,
            predictor,
            player,
-            num_actions,
-            memory_size=1e6,
            batch_size=32,
-            populate_size=50000,
+            memory_size=1e6,
+            populate_size=None, # deprecated
+            init_memory_size=50000,
            exploration=1,
            end_exploration=0.1,
            exploration_epoch_anneal=0.002,
@@ -37,20 +37,27 @@ class ExpReplay(DataFlow, Callback):
            history_len=1
            ):
        """
-        :param predictor: a callabale calling the up-to-date network.
-            called with a state, return a distribution
-        :param player: a `RLEnvironment`
-        :param num_actions: int
+        :param predictor: a callabale running the up-to-date network.
+            called with a state, return a distribution.
+        :param player: an `RLEnvironment`
        :param history_len: length of history frames to concat. zero-filled initial frames
+        :param update_frequency: number of new transitions to add to memory
+            after sampling a batch of transitions for training
        """
+        # XXX back-compat
+        if populate_size is not None:
+            logger.warn("populate_size in ExpReplay is deprecated in favor of init_memory_size")
+            init_memory_size = populate_size
+
        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)
+        self.num_actions = player.get_action_space().num_actions()
        logger.info("Number of Legal actions: {}".format(self.num_actions))
        self.mem = deque(maxlen=memory_size)
        self.rng = get_rng(self)

-    def init_memory(self):
+    def _init_memory(self):
        logger.info("Populating replay memory...")
        # fill some for the history
@@ -60,8 +67,8 @@ class ExpReplay(DataFlow, Callback):
            self._populate_exp()
        self.exploration = old_exploration
-        with tqdm(total=self.populate_size) as pbar:
-            while len(self.mem) < self.populate_size:
+        with tqdm(total=self.init_memory_size) as pbar:
+            while len(self.mem) < self.init_memory_size:
                self._populate_exp()
                pbar.update()
@@ -96,7 +103,7 @@ class ExpReplay(DataFlow, Callback):
    def get_data(self):
        # new s is considered useless if isOver==True
        while True:
-            batch_exp = [self.sample_one() for _ in range(self.batch_size)]
+            batch_exp = [self._sample_one() for _ in range(self.batch_size)]

            #import cv2
            #def view_state(state, next_state):
@@ -116,7 +123,7 @@ class ExpReplay(DataFlow, Callback):
            for _ in range(self.update_frequency):
                self._populate_exp()

-    def sample_one(self):
+    def _sample_one(self):
        """ return the transition tuple for
            [idx, idx+history_len] -> [idx+1, idx+1+history_len]
            it's the transition from state idx+history_len-1 to state idx+history_len
@@ -155,14 +162,14 @@ class ExpReplay(DataFlow, Callback):
        return [state, action, reward, next_state, isOver]

    # Callback-related:

    def _before_train(self):
-        self.init_memory()
+        self._init_memory()

    def _trigger_epoch(self):
        if self.exploration > self.end_exploration:
            self.exploration -= self.exploration_epoch_anneal
            logger.info("Exploration changed to {}".format(self.exploration))
+        # log player statistics
        stats = self.player.get_stat()
        for k, v in six.iteritems(stats):
            if isinstance(v, float):
@@ -177,10 +184,10 @@ if __name__ == '__main__':
    player = AtariPlayer(sys.argv[1], viz=0, frame_skip=10, height_range=(36, 204))
    E = ExpReplay(predictor,
            player=player,
-            num_actions=player.get_num_actions(),
+            num_actions=player.get_action_space().num_actions(),
            populate_size=1001,
            history_len=4)
-    E.init_memory()
+    E._init_memory()
    for k in E.get_data():
        import IPython as IP;
...
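With the constructor change above, callers no longer pass num_actions (ExpReplay reads it from player.get_action_space()), and populate_size survives only as a deprecated alias for init_memory_size. A hedged construction sketch; the ROM path and the uniform predictor are placeholders, and ExpReplay/AtariPlayer are assumed to be re-exported by tensorpack.RL (the DQN example reaches them via "from tensorpack.RL import *"):

import numpy as np
from tensorpack.RL import AtariPlayer, ExpReplay

player = AtariPlayer('breakout.bin', viz=0, frame_skip=4, height_range=(36, 204))

def predictor(state):
    # placeholder for the trained network: a flat distribution over legal actions
    return np.ones(player.get_action_space().num_actions())

E = ExpReplay(predictor=predictor,
              player=player,
              batch_size=32,
              memory_size=int(1e6),
              init_memory_size=50000,   # new name; populate_size= still works but warns
              history_len=4)
# ExpReplay is both a DataFlow (yields training batches) and a Callback
# (fills the memory before training and anneals exploration each epoch).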