Commit 6e1f395d authored by Yuxin Wu's avatar Yuxin Wu

exp_replay as a callback

parent ddf737d7
......@@ -43,12 +43,11 @@ INIT_EXPLORATION = 1
EXPLORATION_EPOCH_ANNEAL = 0.0025
END_EXPLORATION = 0.1
INIT_MEMORY_SIZE = 50000
MEMORY_SIZE = 1e6
INIT_MEMORY_SIZE = 50000
STEP_PER_EPOCH = 10000
EVAL_EPISODE = 100
class Model(ModelDesc):
def _get_input_vars(self):
assert NUM_ACTIONS is not None
......@@ -131,18 +130,6 @@ class TargetNetworkUpdator(Callback):
def _trigger_epoch(self):
self._update()
class ExpReplayController(Callback):
def __init__(self, d):
self.d = d
def _before_train(self):
self.d.init_memory()
def _trigger_epoch(self):
if self.d.exploration > END_EXPLORATION:
self.d.exploration -= EXPLORATION_EPOCH_ANNEAL
logger.info("Exploration changed to {}".format(self.d.exploration))
def play_one_episode(player, func, verbose=False):
tot_reward = 0
que = deque(maxlen=30)
......@@ -251,6 +238,8 @@ def get_config(romfile):
batch_size=BATCH_SIZE,
populate_size=INIT_MEMORY_SIZE,
exploration=INIT_EXPLORATION,
end_exploration=END_EXPLORATION,
exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
reward_clip=(-1, 2))
lr = tf.Variable(0.0025, trainable=False, name='learning_rate')
......@@ -277,7 +266,7 @@ def get_config(romfile):
HumanHyperParamSetter('learning_rate', 'hyper.txt'),
HumanHyperParamSetter((dataset_train, 'exploration'), 'hyper.txt'),
TargetNetworkUpdator(M),
ExpReplayController(dataset_train),
dataset_train,
PeriodicCallback(Evaluator(), 1),
]),
session_config=get_default_sess_config(0.5),
......
......@@ -5,6 +5,7 @@
from .base import DataFlow
from tensorpack.utils import *
from tensorpack.callbacks.base import Callback
from tqdm import tqdm
import random
......@@ -12,6 +13,7 @@ import numpy as np
import cv2
from collections import deque, namedtuple
"""
Implement RL-related data preprocessing
"""
......@@ -28,7 +30,7 @@ def view_state(state):
cv2.imshow("state", r)
cv2.waitKey()
class ExpReplay(DataFlow):
class ExpReplay(DataFlow, Callback):
"""
Implement experience replay.
"""
......@@ -40,6 +42,8 @@ class ExpReplay(DataFlow):
batch_size=32,
populate_size=50000,
exploration=1,
end_exploration=0.1,
exploration_epoch_anneal=0.002,
reward_clip=None):
"""
:param predictor: callabale. called with a state, return a distribution
......@@ -102,6 +106,16 @@ class ExpReplay(DataFlow):
isOver[idx] = b.isOver
return [state, action, reward, next_state, isOver]
# Callback-related:
def _before_train(self):
self.init_memory()
def _trigger_epoch(self):
if self.exploration > self.end_exploration:
self.exploration -= self.exploration_epoch_anneal
logger.info("Exploration changed to {}".format(self.exploration))
if __name__ == '__main__':
from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
predictor = lambda x: np.array([1,1,1,1])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment