Shashank Suhas / seminar-breakout / Commits

Commit f417c49f authored Mar 06, 2019 by Yuxin Wu
[DQN] make DQN more generic: remove some constants & globals
parent 0b561b3b
Showing 4 changed files with 47 additions and 42 deletions (+47 -42)
examples/DeepQNetwork/DQN.py         +22  -26
examples/DeepQNetwork/DQNModel.py     +8   -8
examples/DeepQNetwork/atari.py        +3   -0
examples/DeepQNetwork/expreplay.py   +14   -8
examples/DeepQNetwork/DQN.py  (view file @ f417c49f)

@@ -19,23 +19,17 @@ from expreplay import ExpReplay

 BATCH_SIZE = 64
 IMAGE_SIZE = (84, 84)
-STATE_SHAPE = None  # IMAGE_SIZE + (3,) in gym, and IMAGE_SIZE in ALE
 FRAME_HISTORY = 4
-ACTION_REPEAT = 4   # aka FRAME_SKIP
 UPDATE_FREQ = 4

-GAMMA = 0.99
-
 MEMORY_SIZE = 1e6
 # will consume at least 1e6 * 84 * 84 bytes == 6.6G memory.
 INIT_MEMORY_SIZE = MEMORY_SIZE // 20
 STEPS_PER_EPOCH = 100000 // UPDATE_FREQ  # each epoch is 100k played frames
 EVAL_EPISODE = 50

-NUM_ACTIONS = None
 USE_GYM = False
 ENV_NAME = None
-METHOD = None


 def resize_keepdims(im, size):

@@ -51,7 +45,8 @@ def get_player(viz=False, train=False):
         env = gym.make(ENV_NAME)
     else:
         from atari import AtariPlayer
-        env = AtariPlayer(ENV_NAME, frame_skip=ACTION_REPEAT, viz=viz,
+        # frame_skip=4 is what's used in the original paper
+        env = AtariPlayer(ENV_NAME, frame_skip=4, viz=viz,
                           live_lost_as_eoe=train, max_num_frames=60000)
         env = FireResetEnv(env)
     env = MapState(env, lambda im: resize_keepdims(im, IMAGE_SIZE))

@@ -67,16 +62,14 @@ class Model(DQNModel):
     """
     A DQN model for 2D/3D (image) observations.
     """

-    def __init__(self):
-        assert len(STATE_SHAPE) in [2, 3]
-        super(Model, self).__init__(STATE_SHAPE, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)
-
     def _get_DQN_prediction(self, image):
         assert image.shape.rank in [4, 5], image.shape  # image: N, H, W, (C), Hist
         if image.shape.rank == 5:
             # merge C & Hist
-            image = tf.reshape(image, [-1] + list(STATE_SHAPE[:2]) + [STATE_SHAPE[2] * FRAME_HISTORY])
+            image = tf.reshape(
+                image,
+                [-1] + list(self.state_shape[:2]) + [self.state_shape[2] * FRAME_HISTORY])

         image = image / 255.0
         with argscope(Conv2D, activation=lambda x: PReLU('prelu', x), use_bias=True):

@@ -107,22 +100,23 @@ class Model(DQNModel):
         return tf.identity(Q, name='Qvalue')


-def get_config():
+def get_config(model):
     expreplay = ExpReplay(
         predictor_io_names=(['state'], ['Qvalue']),
         player=get_player(train=True),
-        state_shape=STATE_SHAPE,
+        state_shape=model.state_shape,
         batch_size=BATCH_SIZE,
         memory_size=MEMORY_SIZE,
         init_memory_size=INIT_MEMORY_SIZE,
         init_exploration=1.0,
         update_frequency=UPDATE_FREQ,
-        history_len=FRAME_HISTORY
+        history_len=FRAME_HISTORY,
+        state_dtype=model.state_dtype.as_numpy_dtype
     )

     return TrainConfig(
         data=QueueInput(expreplay),
-        model=Model(),
+        model=model,
         callbacks=[
             ModelSaver(),
             PeriodicTrigger(

@@ -130,7 +124,7 @@ def get_config():
                 every_k_steps=10000 // UPDATE_FREQ),    # update target network every 10k steps
             expreplay,
             ScheduledHyperParamSetter('learning_rate',
-                                      [(60, 4e-4), (100, 2e-4), (500, 5e-5)]),
+                                      [(0, 1e-3), (60, 4e-4), (100, 2e-4), (500, 5e-5)]),
             ScheduledHyperParamSetter(
                 ObjAttrParam(expreplay, 'exploration'),
                 [(0, 1), (10, 0.1), (320, 0.01)],   # 1->0.1 in the first million steps

@@ -156,33 +150,35 @@ if __name__ == '__main__':
     parser.add_argument('--algo', help='algorithm',
                         choices=['DQN', 'Double', 'Dueling'], default='Double')
     args = parser.parse_args()

     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
     ENV_NAME = args.env
     USE_GYM = not ENV_NAME.endswith('.bin')
-    STATE_SHAPE = IMAGE_SIZE + (3, ) if USE_GYM else IMAGE_SIZE
-    METHOD = args.algo
+
     # set num_actions
-    NUM_ACTIONS = get_player().action_space.n
-    logger.info("ENV: {}, Num Actions: {}".format(ENV_NAME, NUM_ACTIONS))
+    num_actions = get_player().action_space.n
+    logger.info("ENV: {}, Num Actions: {}".format(args.env, num_actions))
+
+    state_shape = IMAGE_SIZE + (3, ) if USE_GYM else IMAGE_SIZE
+    model = Model(state_shape, FRAME_HISTORY, args.algo, num_actions)

     if args.task != 'train':
         assert args.load is not None
         pred = OfflinePredictor(PredictConfig(
-            model=Model(),
+            model=model,
             session_init=get_model_loader(args.load),
            input_names=['state'],
            output_names=['Qvalue']))
         if args.task == 'play':
-            play_n_episodes(get_player(viz=0.01), pred, 100)
+            play_n_episodes(get_player(viz=0.01), pred, 100, render=True)
         elif args.task == 'eval':
             eval_model_multithread(pred, EVAL_EPISODE, get_player)
     else:
         logger.set_logger_dir(
             os.path.join('train_log', 'DQN-{}'.format(
-                os.path.basename(ENV_NAME).split('.')[0])))
-        config = get_config()
+                os.path.basename(args.env).split('.')[0])))
+        config = get_config(model)
         if args.load:
             config.session_init = get_model_loader(args.load)
         launch_train_with_config(config, SimpleTrainer())
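Note on the DQN.py hunks above: the script no longer configures the network through module-level globals (STATE_SHAPE, METHOD, NUM_ACTIONS, GAMMA); the __main__ block builds a Model explicitly and get_config() reads the shape and dtype from it. A minimal sketch of the resulting flow, assembled from the hunks above rather than copied from the file (the concrete values are illustrative):

    # assumes tensorpack and the example's own modules are importable,
    # and that `args` comes from the argparse setup shown in the diff
    num_actions = get_player().action_space.n
    state_shape = IMAGE_SIZE + (3, ) if USE_GYM else IMAGE_SIZE    # RGB frames in gym, grayscale in ALE
    model = Model(state_shape, FRAME_HISTORY, args.algo, num_actions)

    config = get_config(model)    # get_config() now derives state_shape and state_dtype from the model
    launch_train_with_config(config, SimpleTrainer())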
examples/DeepQNetwork/DQNModel.py  (view file @ f417c49f)

@@ -13,29 +13,29 @@ from tensorpack.utils import logger

 class Model(ModelDesc):
-    learning_rate = 1e-3
     state_dtype = tf.uint8

-    def __init__(self, state_shape, history, method, num_actions, gamma):
+    # reward discount factor
+    gamma = 0.99
+
+    def __init__(self, state_shape, history, method, num_actions):
         """
         Args:
             state_shape (tuple[int]),
             history (int):
         """
-        self._state_shape = tuple(state_shape)
-        self._stacked_state_shape = (-1, ) + self._state_shape + (history, )
+        self.state_shape = tuple(state_shape)
+        self._stacked_state_shape = (-1, ) + self.state_shape + (history, )
         self.history = history
         self.method = method
         self.num_actions = num_actions
-        self.gamma = gamma

     def inputs(self):
         # When we use h history frames, the current state and the next state will have (h-1) overlapping frames.
         # Therefore we use a combined state for efficiency:
         # The first h are the current state, and the last h are the next state.
         return [tf.placeholder(self.state_dtype,
-                               (None,) + self._state_shape + (self.history + 1, ),
+                               (None,) + self.state_shape + (self.history + 1, ),
                                'comb_state'),
                 tf.placeholder(tf.int64, (None,), 'action'),
                 tf.placeholder(tf.float32, (None,), 'reward'),

@@ -101,7 +101,7 @@ class Model(ModelDesc):
         return cost

     def optimizer(self):
-        lr = tf.get_variable('learning_rate', initializer=self.learning_rate, trainable=False)
+        lr = tf.get_variable('learning_rate', initializer=1e-3, trainable=False)
         opt = tf.train.RMSPropOptimizer(lr, epsilon=1e-5)
         return optimizer.apply_grad_processors(opt, [gradproc.SummaryGradient()])
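Note on the DQNModel.py hunk: gamma is now a class attribute (default 0.99), state_dtype already is one, and the stored shape is exposed as self.state_shape. A hypothetical sketch of how a variant could change the discount without touching the constructor; the subclass name and value are illustrative, not from the repository:

    # hypothetical variant of the example's Model (defined in DQN.py)
    class MyDQN(Model):
        gamma = 0.95    # different reward discount, everything else inherited

    model = MyDQN(state_shape, FRAME_HISTORY, args.algo, num_actions)
    config = get_config(model)    # ExpReplay then receives model.state_shape and
                                  # model.state_dtype.as_numpy_dtype, per the DQN.py hunks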
examples/DeepQNetwork/atari.py  (view file @ f417c49f)

@@ -140,6 +140,9 @@ class AtariPlayer(gym.Env):
         self._restart_episode()
         return self._current_state()

+    def render(self, *args, **kwargs):
+        pass  # visualization for this env is through the viz= argument when creating the player
+
     def step(self, act):
         oldlives = self.ale.lives()
         r = 0
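Note on the atari.py hunk: render() is added as a no-op so that generic gym-style loops calling env.render() do not fail on AtariPlayer; actual on-screen display still comes from the viz= constructor argument. A hypothetical usage sketch (the ROM name and loop are placeholders, not from the repository):

    from atari import AtariPlayer

    env = AtariPlayer('breakout.bin', viz=0.03)   # viz= controls visualization
    ob = env.reset()
    for _ in range(100):
        ob, reward, isOver, info = env.step(env.action_space.sample())
        env.render()                              # no-op, kept for gym API compatibility
        if isOver:
            ob = env.reset()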
examples/DeepQNetwork/expreplay.py  (view file @ f417c49f)

@@ -22,22 +22,24 @@ Experience = namedtuple('Experience',

 class ReplayMemory(object):
-    def __init__(self, max_size, state_shape, history_len):
+    def __init__(self, max_size, state_shape, history_len, dtype='uint8'):
         """
         Args:
             state_shape (tuple[int]): shape (without history) of state
+            dtype: numpy dtype for the state
         """
         self.max_size = int(max_size)
         self.state_shape = state_shape
         assert len(state_shape) in [1, 2, 3], state_shape
         self._output_shape = self.state_shape + (history_len + 1, )
         self.history_len = int(history_len)
+        self.dtype = dtype

         all_state_shape = (self.max_size,) + state_shape
         logger.info("Creating experience replay buffer of {:.1f} GB ... "
                     "use a smaller buffer if you don't have enough CPU memory.".format(
                         np.prod(all_state_shape) / 1024.0**3))
-        self.state = np.zeros(all_state_shape, dtype='uint8')
+        self.state = np.zeros(all_state_shape, dtype=self.dtype)
         self.action = np.zeros((self.max_size,), dtype='int32')
         self.reward = np.zeros((self.max_size,), dtype='float32')
         self.isOver = np.zeros((self.max_size,), dtype='bool')

@@ -66,7 +68,7 @@ class ReplayMemory(object):
     def recent_state(self):
         """ return a list of ``hist_len-1`` elements, each of shape ``self.state_shape`` """
         lst = list(self._hist)
-        states = [np.zeros(self.state_shape, dtype='uint8')] * (self._hist.maxlen - len(lst))
+        states = [np.zeros(self.state_shape, dtype=self.dtype)] * (self._hist.maxlen - len(lst))
         states.extend([k.state for k in lst])
         return states

@@ -137,7 +139,8 @@ class ExpReplay(DataFlow, Callback):
                 batch_size,
                 memory_size, init_memory_size, init_exploration,
-                update_frequency, history_len):
+                update_frequency, history_len,
+                state_dtype='uint8'):
        """
        Args:
            predictor_io_names (tuple of list of str): input/output names to

@@ -219,11 +222,14 @@ class ExpReplay(DataFlow, Callback):
         self._current_ob, reward, isOver, info = self.player.step(act)
         self._current_game_score.feed(reward)

         if isOver:
             # handle ale-specific information
-            if info.get('ale.lives', -1) == 0:
-                # only record score when a whole game is over (not when an episode is over)
-                self._player_scores.feed(self._current_game_score.sum)
-                self._current_game_score.reset()
+            if 'ale.lives' in info:
+                # if running Atari, do something special for logging:
+                # only record score when a whole game is over (not when an episode is over)
+                if info['ale.lives'] == 0:
+                    self._player_scores.feed(self._current_game_score.sum)
+                    self._current_game_score.reset()
+            else:
+                self._player_scores.feed(self._current_game_score.sum)
+                self._current_game_score.reset()
             self.player.reset()
         self.mem.append(Experience(old_s, act, reward, isOver))

@@ -244,7 +250,7 @@ class ExpReplay(DataFlow, Callback):
             view_state(sample[0])

     def _process_batch(self, batch_exp):
-        state = np.asarray([e[0] for e in batch_exp], dtype='uint8')
+        state = np.asarray([e[0] for e in batch_exp], dtype=self.state_dtype)
         reward = np.asarray([e[1] for e in batch_exp], dtype='float32')
         action = np.asarray([e[2] for e in batch_exp], dtype='int8')
         isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
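Note on the expreplay.py hunks: ReplayMemory now takes a dtype and ExpReplay a state_dtype (both defaulting to 'uint8'), so the replay buffer is no longer hard-wired to byte-valued image states. A hypothetical sketch of a float-valued, non-image replay configuration; the parameter names come from the signatures shown above, while the concrete values, shape, and get_player() are placeholders:

    expreplay = ExpReplay(
        predictor_io_names=(['state'], ['Qvalue']),
        player=get_player(train=True),
        state_shape=(17,),            # e.g. a 1-D feature vector; ReplayMemory accepts rank 1, 2 or 3
        batch_size=64,
        memory_size=int(1e5),
        init_memory_size=int(1e4),
        init_exploration=1.0,
        update_frequency=4,
        history_len=4,
        state_dtype='float32')        # buffer allocated as np.zeros(..., dtype=state_dtype) instead of 'uint8'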