Commit 0dbcbac7, authored Sep 17, 2018 by Yuxin Wu
parent 3aab66f1

DQN supports gym as well.

Showing 5 changed files with 48 additions and 20 deletions (+48, -20):
examples/DeepQNetwork/DQN.py        +32 / -13
examples/DeepQNetwork/README.md     +11 / -2
examples/DeepQNetwork/atari.py      +3  / -3
examples/DeepQNetwork/expreplay.py  +1  / -1
examples/DoReFa-Net/dorefa.py       +1  / -1
examples/DeepQNetwork/DQN.py

@@ -8,18 +8,19 @@ import argparse
 import cv2
 import numpy as np
 import tensorflow as tf
+import gym

 from tensorpack import *

 from DQNModel import Model as DQNModel
 from common import Evaluator, eval_model_multithread, play_n_episodes
-from atari_wrapper import FrameStack, MapState, FireResetEnv
+from atari_wrapper import FrameStack, MapState, FireResetEnv, LimitLength
 from expreplay import ExpReplay
 from atari import AtariPlayer

 BATCH_SIZE = 64
 IMAGE_SIZE = (84, 84)
+IMAGE_CHANNEL = None   # 3 in gym and 1 in our own wrapper
 FRAME_HISTORY = 4
 ACTION_REPEAT = 4   # aka FRAME_SKIP
 UPDATE_FREQ = 4
@@ -33,24 +34,39 @@ STEPS_PER_EPOCH = 100000 // UPDATE_FREQ  # each epoch is 100k played frames
 EVAL_EPISODE = 50

 NUM_ACTIONS = None
-ROM_FILE = None
+USE_GYM = False
+ENV_NAME = None
 METHOD = None


+def resize_keepdims(im, size):
+    # Opencv's resize remove the extra dimension for grayscale images.
+    # We add it back.
+    ret = cv2.resize(im, size)
+    if im.ndim == 3 and ret.ndim == 2:
+        ret = ret[:, :, np.newaxis]
+    return ret
+
+
 def get_player(viz=False, train=False):
-    env = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT, viz=viz,
-                      live_lost_as_eoe=train, max_num_frames=60000)
+    if USE_GYM:
+        env = gym.make(ENV_NAME)
+    else:
+        env = AtariPlayer(ENV_NAME, frame_skip=ACTION_REPEAT, viz=viz,
+                          live_lost_as_eoe=train, max_num_frames=60000)
     env = FireResetEnv(env)
-    env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE)[:, :, np.newaxis])
+    env = MapState(env, lambda im: resize_keepdims(im, IMAGE_SIZE))
     if not train:
         # in training, history is taken care of in expreplay buffer
         env = FrameStack(env, FRAME_HISTORY)
+    if train and USE_GYM:
+        env = LimitLength(env, 60000)
     return env


 class Model(DQNModel):
     def __init__(self):
-        super(Model, self).__init__(IMAGE_SIZE, 1, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)
+        super(Model, self).__init__(IMAGE_SIZE, IMAGE_CHANNEL, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)

     def _get_DQN_prediction(self, image):
         image = image / 255.0
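The new `resize_keepdims` helper exists because `cv2.resize` on a single-channel `(h, w, 1)` image returns a 2-D `(h, w)` array, dropping the channel axis that the rest of the pipeline expects; the commit's own comment notes this. A minimal standalone sketch of that behavior (the frame shapes below are illustrative, not taken from the commit):

```
import cv2
import numpy as np

def resize_keepdims(im, size):
    # cv2.resize drops a size-1 channel axis on grayscale input; restore it.
    ret = cv2.resize(im, size)
    if im.ndim == 3 and ret.ndim == 2:
        ret = ret[:, :, np.newaxis]
    return ret

gray = np.zeros((210, 160, 1), dtype=np.uint8)   # ALE-style grayscale frame
rgb = np.zeros((210, 160, 3), dtype=np.uint8)    # gym-style RGB frame

print(cv2.resize(gray, (84, 84)).shape)          # (84, 84) -- channel axis lost
print(resize_keepdims(gray, (84, 84)).shape)     # (84, 84, 1)
print(resize_keepdims(rgb, (84, 84)).shape)      # (84, 84, 3)
```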
@@ -86,7 +102,7 @@ def get_config():
     expreplay = ExpReplay(
         predictor_io_names=(['state'], ['Qvalue']),
         player=get_player(train=True),
-        state_shape=IMAGE_SIZE + (1,),
+        state_shape=IMAGE_SIZE + (IMAGE_CHANNEL,),
         batch_size=BATCH_SIZE,
         memory_size=MEMORY_SIZE,
         init_memory_size=INIT_MEMORY_SIZE,
@@ -126,18 +142,21 @@ if __name__ == '__main__':
     parser.add_argument('--load', help='load model')
     parser.add_argument('--task', help='task to perform',
                         choices=['play', 'eval', 'train'], default='train')
-    parser.add_argument('--rom', help='atari rom', required=True)
+    parser.add_argument('--env', required=True,
+                        help='either an atari rom file (that ends with .bin) or a gym atari environment name')
     parser.add_argument('--algo', help='algorithm',
                         choices=['DQN', 'Double', 'Dueling'], default='Double')
     args = parser.parse_args()

     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
-    ROM_FILE = args.rom
+    ENV_NAME = args.env
+    USE_GYM = not ENV_NAME.endswith('.bin')
+    IMAGE_CHANNEL = 3 if USE_GYM else 1
     METHOD = args.algo

     # set num_actions
-    NUM_ACTIONS = AtariPlayer(ROM_FILE).action_space.n
-    logger.info("ROM: {}, Num Actions: {}".format(ROM_FILE, NUM_ACTIONS))
+    NUM_ACTIONS = get_player().action_space.n
+    logger.info("ENV: {}, Num Actions: {}".format(ENV_NAME, NUM_ACTIONS))

     if args.task != 'train':
         assert args.load is not None
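With this change the single `--env` flag drives both back ends: a value ending in `.bin` is treated as an ALE ROM file, anything else as a gym environment name, and `IMAGE_CHANNEL` follows from that choice (3 for gym, 1 for the project's own grayscale wrapper, per the comment in the first hunk). A small sketch of just that dispatch rule; the `describe_env` helper is hypothetical, only the suffix check and channel count mirror the commit:

```
def describe_env(env_arg):
    # Mirror the commit's rule: ROM files end in ".bin", everything else is a gym id.
    use_gym = not env_arg.endswith('.bin')
    image_channel = 3 if use_gym else 1
    backend = 'gym' if use_gym else 'ALE (AtariPlayer)'
    return backend, image_channel

print(describe_env('breakout.bin'))              # ('ALE (AtariPlayer)', 1)
print(describe_env('BreakoutDeterministic-v4'))  # ('gym', 3)
```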
@@ -153,7 +172,7 @@ if __name__ == '__main__':
     else:
         logger.set_logger_dir(
             os.path.join('train_log', 'DQN-{}'.format(
-                os.path.basename(ROM_FILE).split('.')[0])))
+                os.path.basename(ENV_NAME).split('.')[0])))
         config = get_config()
         if args.load:
             config.session_init = get_model_loader(args.load)
examples/DeepQNetwork/README.md

@@ -26,6 +26,7 @@ Double-DQN with nature paper setting runs at 60 batches (3840 trained frames, 24
 ## How to use

+### With ALE (paper's setting):
 Install [ALE](https://github.com/mgbellemare/Arcade-Learning-Environment) and gym.
 Download an [atari rom](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms), e.g.:
@@ -35,7 +36,7 @@ wget https://github.com/openai/atari-py/raw/master/atari_py/atari_roms/breakout.
 Start Training:
 ```
-./DQN.py --rom breakout.bin
+./DQN.py --env breakout.bin
 # use `--algo` to select other DQN algorithms. See `-h` for more options.
 ```
@@ -43,7 +44,15 @@ Watch the agent play:
 ```
 # Download pretrained models or use one you trained:
 wget http://models.tensorpack.com/DeepQNetwork/DoubleDQN-Breakout.npz
-./DQN.py --rom breakout.bin --task play --load DoubleDQN-Breakout.npz
+./DQN.py --env breakout.bin --task play --load DoubleDQN-Breakout.npz
 ```

+### With gym's Atari:
+Install gym and atari_py.
+```
+./DQN.py --env BreakoutDeterministic-v4
+```
+
 A3C code and models for Atari games in OpenAI Gym are released in [examples/A3C-Gym](../A3C-Gym)
examples/DeepQNetwork/atari.py

@@ -95,7 +95,7 @@ class AtariPlayer(gym.Env):
         self.action_space = spaces.Discrete(len(self.actions))
         self.observation_space = spaces.Box(
-            low=0, high=255, shape=(self.height, self.width), dtype=np.uint8)
+            low=0, high=255, shape=(self.height, self.width, 1), dtype=np.uint8)
         self._restart_episode()

     def get_action_meanings(self):
@@ -110,7 +110,7 @@ class AtariPlayer(gym.Env):
     def _current_state(self):
         """
-        :returns: a gray-scale (h, w) uint8 image
+        :returns: a gray-scale (h, w, 1) uint8 image
         """
         ret = self._grab_raw_image()
         # max-pooled over the last screen
@@ -121,7 +121,7 @@ class AtariPlayer(gym.Env):
             cv2.waitKey(int(self.viz * 1000))
         ret = ret.astype('float32')
         # 0.299,0.587.0.114. same as rgb2y in torch/image
-        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
+        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis]
        return ret.astype('uint8')  # to save some memory

     def _restart_episode(self):
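The `atari.py` change makes the ALE wrapper emit observations with an explicit channel axis, `(h, w, 1)` instead of `(h, w)`, so that the declared `observation_space` matches what `_current_state` returns and both back ends produce rank-3 frames. A brief sketch of the shape difference, assuming a standard Atari screen size for illustration:

```
import cv2
import numpy as np

frame = np.zeros((210, 160, 3), dtype=np.float32)  # raw RGB screen

gray_2d = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)                    # old behaviour
gray_3d = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis]  # new behaviour

print(gray_2d.shape)  # (210, 160)
print(gray_3d.shape)  # (210, 160, 1) -- matches spaces.Box(shape=(h, w, 1))
```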
examples/DeepQNetwork/expreplay.py

@@ -135,7 +135,7 @@ class ExpReplay(DataFlow, Callback):
         Args:
             predictor_io_names (tuple of list of str): input/output names to
                 predict Q value from state.
-            player (RLEnvironment): the player.
+            player (gym.Env): the player.
             state_shape (tuple): h, w, c
             history_len (int): length of history frames to concat. Zero-filled
                 initial frames.
examples/DoReFa-Net/dorefa.py

@@ -30,7 +30,7 @@ def get_dorefa(bitW, bitA, bitG):

         @tf.custom_gradient
         def _sign(x):
-            return tf.sign(x / E) * E, lambda dy: dy
+            return tf.where(tf.equal(x, 0), tf.ones_like(x), tf.sign(x / E)) * E, lambda dy: dy

         return _sign(x)
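The `dorefa.py` tweak is unrelated to gym: `tf.sign(0.)` is `0.`, so a weight that is exactly zero would be "binarized" to a third value instead of ±E; wrapping the sign in `tf.where(tf.equal(x, 0), tf.ones_like(x), ...)` maps zeros to +1 before scaling by E. A NumPy-only sketch of the same fix, assuming E = 1.0 for illustration (in dorefa.py, E is a scale computed from the weights):

```
import numpy as np

E = 1.0  # illustrative scale

def sign_old(x):
    return np.sign(x / E) * E            # sign(0) == 0 leaks a third value

def sign_new(x):
    s = np.where(x == 0, np.ones_like(x), np.sign(x / E))
    return s * E                         # zeros binarize to +E

w = np.array([-0.3, 0.0, 0.7])
print(sign_old(w))   # [-1.  0.  1.]
print(sign_new(w))   # [-1.  1.  1.]
```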