Shashank Suhas / seminar-breakout · Commits

Commit 634369d8, authored Oct 12, 2017 by Yuxin Wu

    split atari_wrapper from common

Parent: 7e963996

Showing 8 changed files, with 120 additions and 116 deletions.
CHANGES.md                              +2    -0
examples/A3C-Gym/README.md              +0    -1
examples/A3C-Gym/atari_wrapper.py       +1    -0
examples/A3C-Gym/train-atari.py         +2    -3
examples/DeepQNetwork/DQN.py            +1    -2
examples/DeepQNetwork/README.md         +3    -1
examples/DeepQNetwork/atari_wrapper.py  +111  -0
examples/DeepQNetwork/common.py         +0    -109
CHANGES.md

@@ -8,6 +8,8 @@ so you won't need to look at here very often.
Here are a list of things that were changed, starting from an early version.
TensorFlow itself also changed APIs before 1.0 and those are not listed here.

+ [2017/10/12](https://github.com/ppwwyyxx/tensorpack/commit/7e963996f615b85f7459455596b4ee9bbd0bce8e). `tensorpack.RL` was deprecated. The RL examples are written with OpenAI gym interface instead.
+ [2017/10/10](https://github.com/ppwwyyxx/tensorpack/commit/7d40e049691d92018f50dc7d45bba5e8b140becc). `tfutils.distributions` was deprecated in favor of `tf.distributions` introduced in TF 1.3.
+ [2017/08/02](https://github.com/ppwwyyxx/tensorpack/commit/875f4d7dbb5675f54eae5675fa3a0948309a8465).
...
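The 2017/10/12 entry above replaces `tensorpack.RL` with plain OpenAI gym environments. As a point of reference (not part of this commit), a minimal rollout loop against the gym interface of that era looks roughly like the sketch below; the random policy is just a stand-in for a trained agent, and the env id is chosen arbitrarily.

```python
import gym

env = gym.make('Breakout-v0')             # any registered Atari env
ob = env.reset()
done, total_reward = False, 0.0
while not done:
    action = env.action_space.sample()    # stand-in for the learned policy
    ob, reward, done, info = env.step(action)
    total_reward += reward
print('episode reward:', total_reward)
```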
examples/A3C-Gym/README.md

@@ -57,7 +57,6 @@ Models are available for the following atari environments (click to watch videos
| [VideoPinball](https://gym.openai.com/evaluations/eval_PWwzNhVFR2CxjYvEsPfT1g) | [WizardOfWor](https://gym.openai.com/evaluations/eval_1oGQhphpQhmzEMIYRrrp0A) | [Zaxxon](https://gym.openai.com/evaluations/eval_TIQ102EwTrHrOyve2RGfg) | |

Note that atari game settings in gym (AtariGames-v0) are quite different from DeepMind papers, so the scores are not comparable. The most notable differences are:
+ Each action is randomly repeated 2~4 times.
+ Inputs are RGB instead of greyscale.
...
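These two differences can be checked directly on a gym Atari env. The sketch below is illustrative only (not part of the commit); `frameskip` as an attribute of the unwrapped Atari env reflects gym's implementation at the time and should be treated as an assumption.

```python
import gym

env = gym.make('Breakout-v0')
print(env.unwrapped.frameskip)         # (2, 5): each action repeated 2~4 times
print(env.observation_space.shape)     # (210, 160, 3): RGB frames, not greyscale
```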
examples/A3C-Gym/atari_wrapper.py (new file, mode 0 → 120000, i.e. a symbolic link whose target is stored as the file content)

../DeepQNetwork/atari_wrapper.py
\ No newline at end of file
examples/A3C-Gym/train-atari.py

@@ -29,10 +29,9 @@ from tensorpack.utils.gpu import get_nr_gpu
 import gym
 from simulator import *
 import common
 from common import (Evaluator, eval_model_multithread,
-                    play_one_episode, play_n_episodes,
-                    WarpFrame, FrameStack, FireResetEnv, LimitLength)
+                    play_one_episode, play_n_episodes)
+from atari_wrapper import WarpFrame, FrameStack, FireResetEnv, LimitLength

 if six.PY3:
     from concurrent import futures
...
examples/DeepQNetwork/DQN.py

@@ -21,9 +21,8 @@ from tensorpack.utils.concurrency import *
 import tensorflow as tf
 from DQNModel import Model as DQNModel
 import common
 from common import Evaluator, eval_model_multithread, play_n_episodes
-from common import FrameStack, WarpFrame, FireResetEnv
+from atari_wrapper import FrameStack, WarpFrame, FireResetEnv
 from expreplay import ExpReplay
 from atari import AtariPlayer
...
examples/DeepQNetwork/README.md

@@ -27,6 +27,8 @@ Double-DQN runs at 60 batches (3840 trained frames, 240 seen frames, 960 game fr
## How to use

Install [ALE](https://github.com/mgbellemare/Arcade-Learning-Environment) and gym.
Download an [atari rom](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms) to
`$TENSORPACK_DATASET/atari_rom/` (defaults to ~/tensorpack_data/atari_rom/), e.g.:
```
...

@@ -42,7 +44,7 @@ Start Training:
Watch the agent play:
```
-./DQN.py --rom breakout.bin --task play --load trained.model
+./DQN.py --rom breakout.bin --task play --load path/to/model
```
A pretrained model on breakout can be downloaded [here](https://drive.google.com/open?id=0B9IPQTvr2BBkN1Jrei1xWW0yR28).
...
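The example block after "e.g.:" is collapsed in this view. As a hedged sketch of that download step (not part of the commit; the exact raw-file URL and ROM filename are assumptions based on the atari-py repository linked in the README):

```python
import os
from six.moves.urllib.request import urlretrieve

rom_dir = os.path.expanduser('~/tensorpack_data/atari_rom')
if not os.path.isdir(rom_dir):
    os.makedirs(rom_dir)
url = ('https://github.com/openai/atari-py/raw/master/'
       'atari_py/atari_roms/breakout.bin')           # assumed raw-file path
urlretrieve(url, os.path.join(rom_dir, 'breakout.bin'))
```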
examples/DeepQNetwork/atari_wrapper.py (new file, mode 100644)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: atari_wrapper.py

import numpy as np
import cv2
from collections import deque
import gym
from gym import spaces

"""
The following wrappers are copied or modified from openai/baselines:
https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
"""


class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env, shape):
        gym.ObservationWrapper.__init__(self, env)
        self.shape = shape
        obs = env.observation_space
        assert isinstance(obs, spaces.Box)
        chan = 1 if len(obs.shape) == 2 else obs.shape[2]
        shape3d = shape if chan == 1 else shape + (chan,)
        self.observation_space = spaces.Box(low=0, high=255, shape=shape3d)

    def _observation(self, obs):
        return cv2.resize(obs, self.shape)


class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Buffer observations and stack across channels (last axis)."""
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        chan = 1 if len(shp) == 2 else shp[2]
        self._base_dim = len(shp)
        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], chan * k))

    def _reset(self):
        """Clear buffer and re-fill by duplicating the first observation."""
        ob = self.env.reset()
        for _ in range(self.k - 1):
            self.frames.append(np.zeros_like(ob))
        self.frames.append(ob)
        return self._observation()

    def _step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._observation(), reward, done, info

    def _observation(self):
        assert len(self.frames) == self.k
        if self._base_dim == 2:
            return np.stack(self.frames, axis=-1)
        else:
            return np.concatenate(self.frames, axis=2)


class _FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Take action on reset for environments that are fixed until firing."""
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def _reset(self):
        self.env.reset()
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset()
        return obs


def FireResetEnv(env):
    if isinstance(env, gym.Wrapper):
        baseenv = env.unwrapped
    else:
        baseenv = env
    if 'FIRE' in baseenv.get_action_meanings():
        return _FireResetEnv(env)
    return env


class LimitLength(gym.Wrapper):
    def __init__(self, env, k):
        gym.Wrapper.__init__(self, env)
        self.k = k

    def _reset(self):
        # This assumes that reset() will really reset the env.
        # If the underlying env tries to be smart about reset
        # (e.g. end-of-life), the assumption doesn't hold.
        ob = self.env.reset()
        self.cnt = 0
        return ob

    def _step(self, action):
        ob, r, done, info = self.env.step(action)
        self.cnt += 1
        if self.cnt == self.k:
            done = True
        return ob, r, done, info
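For context (not part of the commit), these wrappers are meant to be chained onto a gym Atari environment before training. A minimal sketch follows; the env id, target frame size and episode cap are chosen arbitrarily for illustration.

```python
import gym
from atari_wrapper import FireResetEnv, FrameStack, LimitLength, WarpFrame

env = gym.make('Breakout-v0')      # assumed env id
env = FireResetEnv(env)            # press FIRE on reset if the game requires it
env = WarpFrame(env, (84, 84))     # resize each RGB frame to 84x84
env = FrameStack(env, 4)           # stack the last 4 frames along the channel axis
env = LimitLength(env, 60000)      # force `done` after 60000 steps

ob = env.reset()
print(ob.shape)                    # (84, 84, 12): 4 stacked RGB frames
```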
examples/DeepQNetwork/common.py

@@ -4,17 +4,10 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

 import random
 import time
 import threading
 import multiprocessing
-import numpy as np
-import cv2
-from collections import deque
 from tqdm import tqdm
 from six.moves import queue
-import gym
-from gym import spaces

 from tensorpack.utils.concurrency import StoppableThread, ShareSessionThread
 from tensorpack.callbacks import Triggerable
 from tensorpack.utils import logger
...

@@ -138,105 +131,3 @@ class Evaluator(Triggerable):
        self.eval_episode = int(self.eval_episode * 0.94)
        self.trainer.monitors.put_scalar('mean_score', mean)
        self.trainer.monitors.put_scalar('max_score', max)
"""
------------------------------------------------------------------------------
The following wrappers are copied or modified from openai/baselines:
https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
"""
class
WarpFrame
(
gym
.
ObservationWrapper
):
def
__init__
(
self
,
env
,
shape
):
gym
.
ObservationWrapper
.
__init__
(
self
,
env
)
self
.
shape
=
shape
obs
=
env
.
observation_space
assert
isinstance
(
obs
,
spaces
.
Box
)
chan
=
1
if
len
(
obs
.
shape
)
==
2
else
obs
.
shape
[
2
]
shape3d
=
shape
if
chan
==
1
else
shape
+
(
chan
,)
self
.
observation_space
=
spaces
.
Box
(
low
=
0
,
high
=
255
,
shape
=
shape3d
)
def
_observation
(
self
,
obs
):
return
cv2
.
resize
(
obs
,
self
.
shape
)
class
FrameStack
(
gym
.
Wrapper
):
def
__init__
(
self
,
env
,
k
):
"""Buffer observations and stack across channels (last axis)."""
gym
.
Wrapper
.
__init__
(
self
,
env
)
self
.
k
=
k
self
.
frames
=
deque
([],
maxlen
=
k
)
shp
=
env
.
observation_space
.
shape
chan
=
1
if
len
(
shp
)
==
2
else
shp
[
2
]
self
.
_base_dim
=
len
(
shp
)
self
.
observation_space
=
spaces
.
Box
(
low
=
0
,
high
=
255
,
shape
=
(
shp
[
0
],
shp
[
1
],
chan
*
k
))
def
_reset
(
self
):
"""Clear buffer and re-fill by duplicating the first observation."""
ob
=
self
.
env
.
reset
()
for
_
in
range
(
self
.
k
-
1
):
self
.
frames
.
append
(
np
.
zeros_like
(
ob
))
self
.
frames
.
append
(
ob
)
return
self
.
_observation
()
def
_step
(
self
,
action
):
ob
,
reward
,
done
,
info
=
self
.
env
.
step
(
action
)
self
.
frames
.
append
(
ob
)
return
self
.
_observation
(),
reward
,
done
,
info
def
_observation
(
self
):
assert
len
(
self
.
frames
)
==
self
.
k
if
self
.
_base_dim
==
2
:
return
np
.
stack
(
self
.
frames
,
axis
=-
1
)
else
:
return
np
.
concatenate
(
self
.
frames
,
axis
=
2
)
class
_FireResetEnv
(
gym
.
Wrapper
):
def
__init__
(
self
,
env
):
"""Take action on reset for environments that are fixed until firing."""
gym
.
Wrapper
.
__init__
(
self
,
env
)
assert
env
.
unwrapped
.
get_action_meanings
()[
1
]
==
'FIRE'
assert
len
(
env
.
unwrapped
.
get_action_meanings
())
>=
3
def
_reset
(
self
):
self
.
env
.
reset
()
obs
,
_
,
done
,
_
=
self
.
env
.
step
(
1
)
if
done
:
self
.
env
.
reset
()
obs
,
_
,
done
,
_
=
self
.
env
.
step
(
2
)
if
done
:
self
.
env
.
reset
()
return
obs
def
FireResetEnv
(
env
):
if
isinstance
(
env
,
gym
.
Wrapper
):
baseenv
=
env
.
unwrapped
else
:
baseenv
=
env
if
'FIRE'
in
baseenv
.
get_action_meanings
():
return
_FireResetEnv
(
env
)
return
env
class
LimitLength
(
gym
.
Wrapper
):
def
__init__
(
self
,
env
,
k
):
gym
.
Wrapper
.
__init__
(
self
,
env
)
self
.
k
=
k
def
_reset
(
self
):
# This assumes that reset() will really reset the env.
# If the underlying env tries to be smart about reset
# (e.g. end-of-life), the assumption doesn't hold.
ob
=
self
.
env
.
reset
()
self
.
cnt
=
0
return
ob
def
_step
(
self
,
action
):
ob
,
r
,
done
,
info
=
self
.
env
.
step
(
action
)
self
.
cnt
+=
1
if
self
.
cnt
==
self
.
k
:
done
=
True
return
ob
,
r
,
done
,
info