seminar-breakout · Commit ad5321a6
authored Nov 10, 2017 by Yuxin Wu
[RL] use more general MapState instead of WarpFrame
parent 2c429763
Showing 3 changed files with 9 additions and 13 deletions (+9 -13)

examples/A3C-Gym/train-atari.py          +2 -2
examples/DeepQNetwork/DQN.py             +3 -2
examples/DeepQNetwork/atari_wrapper.py   +4 -9
examples/A3C-Gym/train-atari.py  (view file @ ad5321a6)
@@ -32,7 +32,7 @@ import gym
 from simulator import *
 from common import (Evaluator, eval_model_multithread,
                     play_one_episode, play_n_episodes)
-from atari_wrapper import WarpFrame, FrameStack, FireResetEnv, LimitLength
+from atari_wrapper import MapState, FrameStack, FireResetEnv, LimitLength

 if six.PY3:
     from concurrent import futures
@@ -64,7 +64,7 @@ def get_player(train=False, dumpdir=None):
     if dumpdir:
         env = gym.wrappers.Monitor(env, dumpdir)
     env = FireResetEnv(env)
-    env = WarpFrame(env, IMAGE_SIZE)
+    env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE))
     env = FrameStack(env, 4)
     if train:
         env = LimitLength(env, 60000)
examples/DeepQNetwork/DQN.py  (view file @ ad5321a6)
@@ -15,6 +15,7 @@ import subprocess
 import multiprocessing
 import threading
 from collections import deque
+import cv2

 os.environ['TENSORPACK_TRAIN_API'] = 'v2'   # will become default soon
 from tensorpack import *
@@ -23,7 +24,7 @@ import tensorflow as tf
 from DQNModel import Model as DQNModel
 from common import Evaluator, eval_model_multithread, play_n_episodes
-from atari_wrapper import FrameStack, WarpFrame, FireResetEnv
+from atari_wrapper import FrameStack, MapState, FireResetEnv
 from expreplay import ExpReplay
 from atari import AtariPlayer
@@ -50,7 +51,7 @@ def get_player(viz=False, train=False):
     env = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT, viz=viz,
                       live_lost_as_eoe=train, max_num_frames=30000)
     env = FireResetEnv(env)
-    env = WarpFrame(env, IMAGE_SIZE)
+    env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE))
     if not train:
         # in training, history is taken care of in expreplay buffer
         env = FrameStack(env, FRAME_HISTORY)
examples/DeepQNetwork/atari_wrapper.py  (view file @ ad5321a6)
@@ -16,18 +16,13 @@ https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.
 """


-class WarpFrame(gym.ObservationWrapper):
-    def __init__(self, env, shape):
+class MapState(gym.ObservationWrapper):
+    def __init__(self, env, map_func):
         gym.ObservationWrapper.__init__(self, env)
-        self.shape = shape
-        obs = env.observation_space
-        assert isinstance(obs, spaces.Box)
-        chan = 1 if len(obs.shape) == 2 else obs.shape[2]
-        shape3d = shape if chan == 1 else shape + (chan,)
-        self.observation_space = spaces.Box(low=0, high=255, shape=shape3d)
+        self._func = map_func

     def _observation(self, obs):
-        return cv2.resize(obs, self.shape)
+        return self._func(obs)


 class FrameStack(gym.Wrapper):
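Usage sketch (not part of this commit): because MapState takes an arbitrary map_func, the fixed-size resize that WarpFrame hard-coded becomes just one special case, and other observation transforms can be plugged in the same way. The grayscale step, the 84x84 target size and the 'Breakout-v0' environment name below are illustrative assumptions, not values taken from this repository.

    import cv2
    import gym

    from atari_wrapper import MapState, FrameStack

    def preprocess(im):
        # any per-observation transform works now, not only a resize:
        # convert to grayscale, then warp to a fixed 84x84 frame
        im = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)
        return cv2.resize(im, (84, 84))

    env = gym.make('Breakout-v0')
    env = MapState(env, preprocess)   # was: WarpFrame(env, (84, 84))
    env = FrameStack(env, 4)

Note that, unlike the removed WarpFrame, MapState does not adjust observation_space, so the wrapped environment still reports the original frame shape.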