Commit f417c49f authored by Yuxin Wu

[DQN] make DQN more generic: remove some constants & globals

parent 0b561b3b
@@ -19,23 +19,17 @@ from expreplay import ExpReplay
BATCH_SIZE = 64
IMAGE_SIZE = (84, 84)
STATE_SHAPE = None # IMAGE_SIZE + (3,) in gym, and IMAGE_SIZE in ALE
FRAME_HISTORY = 4
ACTION_REPEAT = 4 # aka FRAME_SKIP
UPDATE_FREQ = 4
GAMMA = 0.99
MEMORY_SIZE = 1e6
# will consume at least 1e6 * 84 * 84 bytes == 6.6G memory.
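# (worked out: 1e6 states * 84 * 84 * 1 byte (uint8) = 7.056e9 bytes ~= 6.6 GiB;
#  gym states of shape (84, 84, 3) need roughly three times that.)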
INIT_MEMORY_SIZE = MEMORY_SIZE // 20
STEPS_PER_EPOCH = 100000 // UPDATE_FREQ # each epoch is 100k played frames
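# (i.e. 100000 played frames / UPDATE_FREQ=4 frames per update = 25000 training steps per epoch)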
EVAL_EPISODE = 50
NUM_ACTIONS = None
USE_GYM = False
ENV_NAME = None
METHOD = None
def resize_keepdims(im, size):
@@ -51,7 +45,8 @@ def get_player(viz=False, train=False):
env = gym.make(ENV_NAME)
else:
from atari import AtariPlayer
env = AtariPlayer(ENV_NAME, frame_skip=ACTION_REPEAT, viz=viz,
# frame_skip=4 is what's used in the original paper
env = AtariPlayer(ENV_NAME, frame_skip=4, viz=viz,
live_lost_as_eoe=train, max_num_frames=60000)
env = FireResetEnv(env)
env = MapState(env, lambda im: resize_keepdims(im, IMAGE_SIZE))
@@ -67,16 +62,14 @@ class Model(DQNModel):
"""
A DQN model for 2D/3D (image) observations.
"""
def __init__(self):
assert len(STATE_SHAPE) in [2, 3]
super(Model, self).__init__(STATE_SHAPE, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)
def _get_DQN_prediction(self, image):
assert image.shape.rank in [4, 5], image.shape
# image: N, H, W, (C), Hist
if image.shape.rank == 5:
# merge C & Hist
image = tf.reshape(image, [-1] + list(STATE_SHAPE[:2]) + [STATE_SHAPE[2] * FRAME_HISTORY])
image = tf.reshape(
image,
[-1] + list(self.state_shape[:2]) + [self.state_shape[2] * FRAME_HISTORY])
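# (illustrative shapes, assuming the gym case: the input arrives as
#  (N, 84, 84, 3, 4) = N, H, W, C, Hist and leaves this reshape as (N, 84, 84, 12))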
image = image / 255.0
with argscope(Conv2D, activation=lambda x: PReLU('prelu', x), use_bias=True):
@@ -107,22 +100,23 @@ class Model(DQNModel):
return tf.identity(Q, name='Qvalue')
def get_config():
def get_config(model):
expreplay = ExpReplay(
predictor_io_names=(['state'], ['Qvalue']),
player=get_player(train=True),
state_shape=STATE_SHAPE,
state_shape=model.state_shape,
batch_size=BATCH_SIZE,
memory_size=MEMORY_SIZE,
init_memory_size=INIT_MEMORY_SIZE,
init_exploration=1.0,
update_frequency=UPDATE_FREQ,
history_len=FRAME_HISTORY
history_len=FRAME_HISTORY,
state_dtype=model.state_dtype.as_numpy_dtype
)
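# note: model.state_dtype is a tf.DType (tf.uint8 here); .as_numpy_dtype yields the
# matching numpy dtype, so the replay buffer allocates its state array in the same type.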
return TrainConfig(
data=QueueInput(expreplay),
model=Model(),
model=model,
callbacks=[
ModelSaver(),
PeriodicTrigger(
@@ -130,7 +124,7 @@ def get_config():
every_k_steps=10000 // UPDATE_FREQ), # update target network every 10k steps
expreplay,
ScheduledHyperParamSetter('learning_rate',
[(60, 4e-4), (100, 2e-4), (500, 5e-5)]),
[(0, 1e-3), (60, 4e-4), (100, 2e-4), (500, 5e-5)]),
ScheduledHyperParamSetter(
ObjAttrParam(expreplay, 'exploration'),
[(0, 1), (10, 0.1), (320, 0.01)], # 1->0.1 in the first million steps
@@ -156,33 +150,35 @@ if __name__ == '__main__':
parser.add_argument('--algo', help='algorithm',
choices=['DQN', 'Double', 'Dueling'], default='Double')
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
ENV_NAME = args.env
USE_GYM = not ENV_NAME.endswith('.bin')
STATE_SHAPE = IMAGE_SIZE + (3, ) if USE_GYM else IMAGE_SIZE
METHOD = args.algo
# set num_actions
NUM_ACTIONS = get_player().action_space.n
logger.info("ENV: {}, Num Actions: {}".format(ENV_NAME, NUM_ACTIONS))
num_actions = get_player().action_space.n
logger.info("ENV: {}, Num Actions: {}".format(args.env, num_actions))
state_shape = IMAGE_SIZE + (3, ) if USE_GYM else IMAGE_SIZE
model = Model(state_shape, FRAME_HISTORY, args.algo, num_actions)
if args.task != 'train':
assert args.load is not None
pred = OfflinePredictor(PredictConfig(
model=Model(),
model=model,
session_init=get_model_loader(args.load),
input_names=['state'],
output_names=['Qvalue']))
if args.task == 'play':
play_n_episodes(get_player(viz=0.01), pred, 100)
play_n_episodes(get_player(viz=0.01), pred, 100, render=True)
elif args.task == 'eval':
eval_model_multithread(pred, EVAL_EPISODE, get_player)
else:
logger.set_logger_dir(
os.path.join('train_log', 'DQN-{}'.format(
os.path.basename(ENV_NAME).split('.')[0])))
config = get_config()
os.path.basename(args.env).split('.')[0])))
config = get_config(model)
if args.load:
config.session_init = get_model_loader(args.load)
launch_train_with_config(config, SimpleTrainer())
@@ -13,29 +13,29 @@ from tensorpack.utils import logger
class Model(ModelDesc):
learning_rate = 1e-3
state_dtype = tf.uint8
def __init__(self, state_shape, history, method, num_actions, gamma):
# reward discount factor
gamma = 0.99
def __init__(self, state_shape, history, method, num_actions):
"""
Args:
state_shape (tuple[int]):
history (int):
"""
self._state_shape = tuple(state_shape)
self._stacked_state_shape = (-1, ) + self._state_shape + (history, )
self.state_shape = tuple(state_shape)
self._stacked_state_shape = (-1, ) + self.state_shape + (history, )
self.history = history
self.method = method
self.num_actions = num_actions
self.gamma = gamma
def inputs(self):
# When we use h history frames, the current state and the next state will have (h-1) overlapping frames.
# Therefore we use a combined state for efficiency:
# The first h are the current state, and the last h are the next state.
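# (e.g. with history=4 the last axis holds 5 frames: slice [..., 0:4] is the current
#  state and [..., 1:5] the next state; the actual split happens later in the training
#  graph, outside this hunk.)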
return [tf.placeholder(self.state_dtype,
(None,) + self._state_shape + (self.history + 1, ),
(None,) + self.state_shape + (self.history + 1, ),
'comb_state'),
tf.placeholder(tf.int64, (None,), 'action'),
tf.placeholder(tf.float32, (None,), 'reward'),
@@ -101,7 +101,7 @@ class Model(ModelDesc):
return cost
def optimizer(self):
lr = tf.get_variable('learning_rate', initializer=self.learning_rate, trainable=False)
lr = tf.get_variable('learning_rate', initializer=1e-3, trainable=False)
opt = tf.train.RMSPropOptimizer(lr, epsilon=1e-5)
return optimizer.apply_grad_processors(opt, [gradproc.SummaryGradient()])
......
@@ -140,6 +140,9 @@ class AtariPlayer(gym.Env):
self._restart_episode()
return self._current_state()
def render(self, *args, **kwargs):
pass # visualization for this env is through the viz= argument when creating the player
def step(self, act):
oldlives = self.ale.lives()
r = 0
......
@@ -22,22 +22,24 @@ Experience = namedtuple('Experience',
class ReplayMemory(object):
def __init__(self, max_size, state_shape, history_len):
def __init__(self, max_size, state_shape, history_len, dtype='uint8'):
"""
Args:
state_shape (tuple[int]): shape (without history) of state
dtype: numpy dtype for the state
"""
self.max_size = int(max_size)
self.state_shape = state_shape
assert len(state_shape) in [1, 2, 3], state_shape
self._output_shape = self.state_shape + (history_len + 1, )
self.history_len = int(history_len)
self.dtype = dtype
all_state_shape = (self.max_size,) + state_shape
logger.info("Creating experience replay buffer of {:.1f} GB ... "
"use a smaller buffer if you don't have enough CPU memory.".format(
np.prod(all_state_shape) / 1024.0**3))
self.state = np.zeros(all_state_shape, dtype='uint8')
self.state = np.zeros(all_state_shape, dtype=self.dtype)
self.action = np.zeros((self.max_size,), dtype='int32')
self.reward = np.zeros((self.max_size,), dtype='float32')
self.isOver = np.zeros((self.max_size,), dtype='bool')
@@ -66,7 +68,7 @@ class ReplayMemory(object):
def recent_state(self):
""" return a list of ``hist_len-1`` elements, each of shape ``self.state_shape`` """
lst = list(self._hist)
states = [np.zeros(self.state_shape, dtype='uint8')] * (self._hist.maxlen - len(lst))
states = [np.zeros(self.state_shape, dtype=self.dtype)] * (self._hist.maxlen - len(lst))
states.extend([k.state for k in lst])
return states
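# (per the docstring: with history_len=4 this returns 3 states; right after a reset the
#  missing entries are all-zero states of the configured dtype.)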
@@ -137,7 +139,8 @@ class ExpReplay(DataFlow, Callback):
batch_size,
memory_size, init_memory_size,
init_exploration,
update_frequency, history_len):
update_frequency, history_len,
state_dtype='uint8'):
"""
Args:
predictor_io_names (tuple of list of str): input/output names to
@@ -219,11 +222,14 @@ class ExpReplay(DataFlow, Callback):
self._current_ob, reward, isOver, info = self.player.step(act)
self._current_game_score.feed(reward)
if isOver:
# handle ale-specific information
if info.get('ale.lives', -1) == 0:
if 'ale.lives' in info: # if running Atari, do something special for logging:
if info['ale.lives'] == 0:
# only record score when a whole game is over (not when an episode is over)
self._player_scores.feed(self._current_game_score.sum)
self._current_game_score.reset()
else:
self._player_scores.feed(self._current_game_score.sum)
self._current_game_score.reset()
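# (rationale: with ALE and live_lost_as_eoe=True during training, isOver fires on every
#  lost life, so the score is only flushed once all lives are gone; for gym or other
#  envs every isOver already marks a finished game.)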
self.player.reset()
self.mem.append(Experience(old_s, act, reward, isOver))
@@ -244,7 +250,7 @@ class ExpReplay(DataFlow, Callback):
view_state(sample[0])
def _process_batch(self, batch_exp):
state = np.asarray([e[0] for e in batch_exp], dtype='uint8')
state = np.asarray([e[0] for e in batch_exp], dtype=self.state_dtype)
reward = np.asarray([e[1] for e in batch_exp], dtype='float32')
action = np.asarray([e[2] for e in batch_exp], dtype='int8')
isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
......