Commit 6e1f395d authored by Yuxin Wu

exp_replay as a callback

parent ddf737d7
@@ -43,12 +43,11 @@
 EXPLORATION_EPOCH_ANNEAL = 0.0025
 END_EXPLORATION = 0.1
-INIT_MEMORY_SIZE = 50000
 MEMORY_SIZE = 1e6
+INIT_MEMORY_SIZE = 50000
 STEP_PER_EPOCH = 10000
 EVAL_EPISODE = 100

 class Model(ModelDesc):
     def _get_input_vars(self):
         assert NUM_ACTIONS is not None
@@ -131,18 +130,6 @@ class TargetNetworkUpdator(Callback):
     def _trigger_epoch(self):
         self._update()

-class ExpReplayController(Callback):
-    def __init__(self, d):
-        self.d = d
-
-    def _before_train(self):
-        self.d.init_memory()
-
-    def _trigger_epoch(self):
-        if self.d.exploration > END_EXPLORATION:
-            self.d.exploration -= EXPLORATION_EPOCH_ANNEAL
-            logger.info("Exploration changed to {}".format(self.d.exploration))
-
 def play_one_episode(player, func, verbose=False):
     tot_reward = 0
     que = deque(maxlen=30)
@@ -251,6 +238,8 @@ def get_config(romfile):
         batch_size=BATCH_SIZE,
         populate_size=INIT_MEMORY_SIZE,
         exploration=INIT_EXPLORATION,
+        end_exploration=END_EXPLORATION,
+        exploration_epoch_anneal=EXPLORATION_EPOCH_ANNEAL,
         reward_clip=(-1, 2))

     lr = tf.Variable(0.0025, trainable=False, name='learning_rate')
@@ -277,7 +266,7 @@ def get_config(romfile):
             HumanHyperParamSetter('learning_rate', 'hyper.txt'),
             HumanHyperParamSetter((dataset_train, 'exploration'), 'hyper.txt'),
             TargetNetworkUpdator(M),
-            ExpReplayController(dataset_train),
+            dataset_train,
             PeriodicCallback(Evaluator(), 1),
         ]),
         session_config=get_default_sess_config(0.5),
...
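Why `dataset_train` can now sit directly in the callbacks list: `ExpReplay` itself implements the callback hooks (see the second file below), so the separate `ExpReplayController` wrapper is no longer needed. A minimal sketch of the dispatch pattern, assuming a `Callbacks` container that simply fans each hook out to every registered object; the hook names mirror the `_before_train`/`_trigger_epoch` methods in this commit, and this is an illustration, not tensorpack's actual implementation:

class Callback(object):
    # Base class: hooks default to no-ops, subclasses override what they need.
    def before_train(self):
        pass

    def trigger_epoch(self):
        pass

class Callbacks(Callback):
    # Container that forwards every hook to each registered callback in order.
    def __init__(self, cbs):
        self.cbs = cbs

    def before_train(self):
        for cb in self.cbs:
            cb.before_train()

    def trigger_epoch(self):
        for cb in self.cbs:
            cb.trigger_epoch()

# Since ExpReplay subclasses Callback, the replay buffer object itself can be
# registered here and will receive before_train (to populate its memory) and
# trigger_epoch (to anneal exploration) from the training loop.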
@@ -5,6 +5,7 @@
 from .base import DataFlow
 from tensorpack.utils import *
+from tensorpack.callbacks.base import Callback
 from tqdm import tqdm
 import random
@@ -12,6 +13,7 @@ import numpy as np
 import cv2
 from collections import deque, namedtuple

 """
 Implement RL-related data preprocessing
 """
@@ -28,7 +30,7 @@ def view_state(state):
     cv2.imshow("state", r)
     cv2.waitKey()

-class ExpReplay(DataFlow):
+class ExpReplay(DataFlow, Callback):
     """
     Implement experience replay.
     """
@@ -40,6 +42,8 @@ class ExpReplay(DataFlow):
                  batch_size=32,
                  populate_size=50000,
                  exploration=1,
+                 end_exploration=0.1,
+                 exploration_epoch_anneal=0.002,
                  reward_clip=None):
         """
         :param predictor: callabale. called with a state, return a distribution
@@ -102,6 +106,16 @@ class ExpReplay(DataFlow):
             isOver[idx] = b.isOver
         return [state, action, reward, next_state, isOver]

+    # Callback-related:
+    def _before_train(self):
+        self.init_memory()
+
+    def _trigger_epoch(self):
+        if self.exploration > self.end_exploration:
+            self.exploration -= self.exploration_epoch_anneal
+            logger.info("Exploration changed to {}".format(self.exploration))
+
 if __name__ == '__main__':
     from tensorpack.dataflow.dataset import AtariDriver, AtariPlayer
     predictor = lambda x: np.array([1,1,1,1])
...
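A rough, self-contained sketch of the pattern this commit introduces: one object acting both as the data source (DataFlow role, producing epsilon-greedy samples) and as a callback that anneals its own exploration rate once per epoch. `SimpleExpReplay` and its toy data are stand-ins invented for illustration, not the real class; with the constants from the DQN script above (INIT_EXPLORATION = 1, END_EXPLORATION = 0.1, EXPLORATION_EPOCH_ANNEAL = 0.0025), the exploration floor is reached after (1 - 0.1) / 0.0025 = 360 epochs:

import random

class SimpleExpReplay(object):
    # Toy stand-in that plays both roles: data source and per-epoch callback.
    def __init__(self, exploration=1.0, end_exploration=0.1,
                 exploration_epoch_anneal=0.0025):
        self.exploration = exploration
        self.end_exploration = end_exploration
        self.exploration_epoch_anneal = exploration_epoch_anneal

    # DataFlow role: yield (state, action) pairs, picking a random action
    # with probability `exploration` (epsilon-greedy).
    def get_data(self):
        for state in range(5):                    # fake states
            if random.random() < self.exploration:
                action = random.randrange(4)      # explore
            else:
                action = 0                        # stand-in for the greedy action
            yield state, action

    # Callback role: the trainer calls this once per epoch.
    def trigger_epoch(self):
        if self.exploration > self.end_exploration:
            self.exploration -= self.exploration_epoch_anneal

replay = SimpleExpReplay()
for epoch in range(400):
    batch = list(replay.get_data())   # training would consume these samples
    replay.trigger_epoch()
print(round(replay.exploration, 4))   # ~0.1: the floor is hit after 360 epochs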