Commit 7e963996
Authored Oct 12, 2017 by Yuxin Wu

    deprecated tensorpack.RL and use gym for RL examples

Parent: c270a1ed
Showing 8 changed files with 267 additions and 165 deletions (+267 -165)
Changed files:

    CHANGES.md                          +2    -0
    examples/A3C-Gym/simulator.py       +10   -8
    examples/A3C-Gym/train-atari.py     +22   -19
    examples/DeepQNetwork/DQN.py        +15   -18
    examples/DeepQNetwork/atari.py      +54   -82
    examples/DeepQNetwork/common.py     +142  -24
    examples/DeepQNetwork/expreplay.py  +19   -14
    tensorpack/RL/__init__.py           +3    -0
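The change running through all of these files is an API swap: tensorpack's own player interface (`current_state()`, `action(act)` returning `(reward, isOver)`, `get_action_space().num_actions()`) is replaced by the standard gym `Env` interface (`reset()`, `step(act)` returning `(ob, reward, done, info)`, `action_space.n`). A minimal sketch of the new-style interaction loop, assuming gym with the Atari extras installed ('Breakout-v0' is simply the environment these examples use):

```python
import gym

env = gym.make('Breakout-v0')
ob = env.reset()
total = 0
for _ in range(1000):
    act = env.action_space.sample()          # old API: rng.choice(range(num_actions))
    ob, reward, done, info = env.step(act)   # old API: reward, isOver = player.action(act)
    total += reward
    if done:
        ob = env.reset()                     # old API: player.restart_episode()
print("total reward:", total)
```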
CHANGES.md  (+2 -0)

@@ -8,6 +8,8 @@ so you won't need to look at here very often.
 Here are a list of things that were changed, starting from an early version.
 TensorFlow itself also changed APIs before 1.0 and those are not listed here.
 
++ [2017/10/10](https://github.com/ppwwyyxx/tensorpack/commit/7d40e049691d92018f50dc7d45bba5e8b140becc). `tfutils.distributions` was deprecated in favor of `tf.distributions` introduced in TF 1.3.
++ [2017/08/02](https://github.com/ppwwyyxx/tensorpack/commit/875f4d7dbb5675f54eae5675fa3a0948309a8465). `Trainer.get_predictor` now takes GPU id. And `Trainer.get_predictors` was deprecated.
 + 2017/06/07. Now the library explicitly depends on msgpack-numpy>=0.3.9. The serialization protocol
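For reference, a minimal sketch of the `tf.distributions` API that the first changelog entry points to, assuming TensorFlow 1.3+ (parameter and method names as in that release series):

```python
import tensorflow as tf

# tf.distributions.Normal replaces the deprecated tfutils.distributions helpers.
dist = tf.distributions.Normal(loc=0.0, scale=1.0)
samples = dist.sample(5)      # five draws from N(0, 1)
log_p = dist.log_prob(0.0)    # log-density at x = 0

with tf.Session() as sess:
    print(sess.run([samples, log_p]))
```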
examples/A3C-Gym/simulator.py  (+10 -8)

@@ -78,18 +78,21 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
         s2c_socket = context.socket(zmq.DEALER)
         s2c_socket.setsockopt(zmq.IDENTITY, self.identity)
         # s2c_socket.set_hwm(5)
         s2c_socket.connect(self.s2c)
 
-        state = player.current_state()
+        state = player.reset()
         reward, isOver = 0, False
         while True:
+            # after taking the last action, get to this state and get this reward/isOver.
+            # If isOver, get to the next-episode state immediately.
+            # This tuple is not the same as the one put into the memory buffer
             c2s_socket.send(dumps(
                 (self.identity, state, reward, isOver)),
                 copy=False)
             action = loads(s2c_socket.recv(copy=False).bytes)
-            reward, isOver = player.action(action)
-            state = player.current_state()
+            state, reward, isOver, _ = player.step(action)
+            if isOver:
+                state = player.reset()
 
 
 # compatibility

@@ -180,17 +183,16 @@ class SimulatorMaster(threading.Thread):
 if __name__ == '__main__':
     import random
-    from tensorpack.RL import NaiveRLEnvironment
+    import gym
 
     class NaiveSimulator(SimulatorProcess):
         def _build_player(self):
-            return NaiveRLEnvironment()
+            return gym.make('Breakout-v0')
 
     class NaiveActioner(SimulatorMaster):
         def _get_action(self, state):
             time.sleep(1)
-            return random.randint(1, 12)
+            return random.randint(1, 3)
 
         def _on_episode_over(self, client):
             # print("Over: ", client.memory)
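Stripped of the ZMQ plumbing, the rewritten worker loop above is the plain gym interaction pattern. A self-contained sketch under that assumption (`choose_action` is a hypothetical stand-in for the action the master would send back over `s2c_socket`; gym with Atari extras assumed):

```python
import gym

env = gym.make('Breakout-v0')    # what _build_player() now returns

def choose_action(state):
    # stand-in for the action received from SimulatorMaster
    return env.action_space.sample()

state = env.reset()
reward, isOver = 0, False
for _ in range(1000):
    # the worker would send (identity, state, reward, isOver) here
    action = choose_action(state)
    state, reward, isOver, _ = env.step(action)
    if isOver:
        state = env.reset()
```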
examples/A3C-Gym/train-atari.py  (+22 -19)

@@ -27,11 +27,12 @@ from tensorpack.tfutils.gradproc import MapGradient, SummaryGradient
 from tensorpack.utils.gpu import get_nr_gpu
-from tensorpack.RL import *
+import gym
 from simulator import *
 import common
-from common import (play_model, Evaluator, eval_model_multithread,
-                    play_one_episode, play_n_episodes)
+from common import (Evaluator, eval_model_multithread,
+                    play_one_episode, play_n_episodes,
+                    WarpFrame, FrameStack, FireResetEnv, LimitLength)
 
 if six.PY3:
     from concurrent import futures

@@ -58,15 +59,16 @@ NUM_ACTIONS = None
 ENV_NAME = None
 
 
-def get_player(viz=False, train=False, dumpdir=None):
-    pl = GymEnv(ENV_NAME, viz=viz, dumpdir=dumpdir)
-    pl = MapPlayerState(pl, lambda img: cv2.resize(img, IMAGE_SIZE[::-1]))
-    pl = HistoryFramePlayer(pl, FRAME_HISTORY)
-    if not train:
-        pl = PreventStuckPlayer(pl, 30, 1)
-    else:
-        pl = LimitLengthPlayer(pl, 60000)
-    return pl
+def get_player(train=False, dumpdir=None):
+    env = gym.make(ENV_NAME)
+    if dumpdir:
+        env = gym.wrappers.Monitor(env, dumpdir)
+    env = FireResetEnv(env)
+    env = WarpFrame(env, IMAGE_SIZE)
+    env = FrameStack(env, 4)
+    if train:
+        env = LimitLength(env, 60000)
+    return env
 
 
 class MySimulatorWorker(SimulatorProcess):

@@ -272,7 +274,7 @@ if __name__ == '__main__':
     ENV_NAME = args.env
     logger.info("Environment Name: {}".format(ENV_NAME))
-    NUM_ACTIONS = get_player().get_action_space().num_actions()
+    NUM_ACTIONS = get_player().action_space.n
     logger.info("Number of actions: {}".format(NUM_ACTIONS))
 
     if args.gpu:

@@ -280,20 +282,21 @@ if __name__ == '__main__':
     if args.task != 'train':
         assert args.load is not None
-        cfg = PredictConfig(
+        pred = OfflinePredictor(PredictConfig(
             model=Model(),
             session_init=get_model_loader(args.load),
             input_names=['state'],
-            output_names=['policy'])
+            output_names=['policy']))
         if args.task == 'play':
-            play_model(cfg, get_player(viz=0.01))
+            play_n_episodes(get_player(train=False), pred,
+                            args.episode, render=True)
         elif args.task == 'eval':
-            eval_model_multithread(cfg, args.episode, get_player)
+            eval_model_multithread(pred, args.episode, get_player)
         elif args.task == 'gen_submit':
             play_n_episodes(
                 get_player(train=False, dumpdir=args.output),
-                OfflinePredictor(cfg), args.episode)
-            # gym.upload(output, api_key='xxx')
+                pred, args.episode)
+            # gym.upload(args.output, api_key='xxx')
     else:
         dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
         logger.set_logger_dir(dirname)
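The `dumpdir` branch of the new `get_player()` leans on gym's built-in `Monitor` wrapper to record episodes for the `gen_submit` task. A standalone sketch of just that part (the output directory is hypothetical, and the custom wrappers from `common.py` are omitted so the snippet stays self-contained):

```python
import gym

env = gym.make('Breakout-v0')
env = gym.wrappers.Monitor(env, '/tmp/breakout-monitor', force=True)

ob = env.reset()
for _ in range(500):
    ob, r, done, info = env.step(env.action_space.sample())
    if done:
        ob = env.reset()
env.close()   # flushes recorded stats/videos to the directory
```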
examples/DeepQNetwork/DQN.py  (+15 -18)

@@ -18,14 +18,14 @@ from collections import deque
 from tensorpack import *
 from tensorpack.utils.concurrency import *
-from tensorpack.RL import *
 import tensorflow as tf
 
 from DQNModel import Model as DQNModel
 import common
-from common import play_model, Evaluator, eval_model_multithread
-from atari import AtariPlayer
+from common import Evaluator, eval_model_multithread, play_n_episodes
+from common import FrameStack, WarpFrame, FireResetEnv
 from expreplay import ExpReplay
+from atari import AtariPlayer
 
 BATCH_SIZE = 64
 IMAGE_SIZE = (84, 84)

@@ -37,7 +37,7 @@ GAMMA = 0.99
 MEMORY_SIZE = 1e6
 # will consume at least 1e6 * 84 * 84 bytes == 6.6G memory.
-INIT_MEMORY_SIZE = 5e4
+INIT_MEMORY_SIZE = MEMORY_SIZE // 20
 STEPS_PER_EPOCH = 10000 // UPDATE_FREQ * 10  # each epoch is 100k played frames
 EVAL_EPISODE = 50

@@ -47,17 +47,14 @@ METHOD = None
 def get_player(viz=False, train=False):
-    pl = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT,
-                     image_shape=IMAGE_SIZE[::-1], viz=viz,
-                     live_lost_as_eoe=train)
+    env = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT, viz=viz,
+                      live_lost_as_eoe=train, max_num_frames=30000)
+    env = FireResetEnv(env)
+    env = WarpFrame(env, IMAGE_SIZE)
     if not train:
-        # create a new axis to stack history on
-        pl = MapPlayerState(pl, lambda im: im[:, :, np.newaxis])
         # in training, history is taken care of in expreplay buffer
-        pl = HistoryFramePlayer(pl, FRAME_HISTORY)
-        pl = PreventStuckPlayer(pl, 30, 1)
-    pl = LimitLengthPlayer(pl, 30000)
-    return pl
+        env = FrameStack(env, FRAME_HISTORY)
+    return env
 
 
 class Model(DQNModel):

@@ -149,20 +146,20 @@ if __name__ == '__main__':
     ROM_FILE = args.rom
     METHOD = args.algo
     # set num_actions
-    NUM_ACTIONS = AtariPlayer(ROM_FILE).get_action_space().num_actions()
+    NUM_ACTIONS = AtariPlayer(ROM_FILE).action_space.n
     logger.info("ROM: {}, Num Actions: {}".format(ROM_FILE, NUM_ACTIONS))
 
     if args.task != 'train':
         assert args.load is not None
-        cfg = PredictConfig(
+        pred = OfflinePredictor(PredictConfig(
             model=Model(),
             session_init=get_model_loader(args.load),
             input_names=['state'],
-            output_names=['Qvalue'])
+            output_names=['Qvalue']))
         if args.task == 'play':
-            play_model(cfg, get_player(viz=0.01))
+            play_n_episodes(get_player(viz=0.01), pred, 100)
         elif args.task == 'eval':
-            eval_model_multithread(cfg, EVAL_EPISODE, get_player)
+            eval_model_multithread(pred, EVAL_EPISODE, get_player)
     else:
         logger.set_logger_dir(
             os.path.join('train_log', 'DQN-{}'.format(
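Two of the constants touched above are easy to sanity-check: the memory estimate in the comment, and the fact that the new `INIT_MEMORY_SIZE = MEMORY_SIZE // 20` keeps the old value of `5e4`:

```python
MEMORY_SIZE = 1e6
bytes_per_frame = 84 * 84                        # one uint8 grayscale frame
print(MEMORY_SIZE * bytes_per_frame / 1024**3)   # ~6.57, i.e. the "6.6G" in the comment
print(MEMORY_SIZE // 20)                         # 50000.0, same as the old 5e4
```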
examples/DeepQNetwork/atari.py  (+54 -82)

@@ -7,7 +7,6 @@ import numpy as np
 import time
 import os
 import cv2
-from collections import deque
 import threading
 import six
 from six.moves import range

@@ -16,7 +15,9 @@ from tensorpack.utils.utils import get_rng, execute_only_once
 from tensorpack.utils.fs import get_dataset_path
 from tensorpack.utils.stats import StatCounter
-from tensorpack.RL.envbase import RLEnvironment, DiscreteActionSpace
 
+import gym
+from gym import spaces
+from gym.envs.atari.atari_env import ACTION_MEANING
 from ale_python_interface import ALEInterface

@@ -26,27 +27,29 @@ ROM_URL = "https://github.com/openai/atari-py/tree/master/atari_py/atari_roms"
 _ALE_LOCK = threading.Lock()
 
 
-class AtariPlayer(RLEnvironment):
+class AtariPlayer(gym.Env):
     """
-    A wrapper for atari emulator.
-    Will automatically restart when a real episode ends (isOver might be just
-    lost of lives but not game over).
+    A wrapper for ALE emulator, with configurations to mimic DeepMind DQN settings.
+
+    Info:
+        score: the accumulated reward in the current game
+        gameOver: True when the current game is Over
     """
 
-    def __init__(self, rom_file, viz=0, height_range=(None, None),
-                 frame_skip=4, image_shape=(84, 84),
-                 nullop_start=30, live_lost_as_eoe=True):
+    def __init__(self, rom_file, viz=0,
+                 frame_skip=4, nullop_start=30,
+                 live_lost_as_eoe=True, max_num_frames=0):
         """
-        :param rom_file: path to the rom
-        :param frame_skip: skip every k frames and repeat the action
-        :param image_shape: (w, h)
-        :param height_range: (h1, h2) to cut
-        :param viz: visualization to be done.
-            Set to 0 to disable.
-            Set to a positive number to be the delay between frames to show.
-            Set to a string to be a directory to store frames.
-        :param nullop_start: start with random number of null ops
-        :param live_losts_as_eoe: consider lost of lives as end of episode. useful for training.
+        Args:
+            rom_file: path to the rom
+            frame_skip: skip every k frames and repeat the action
+            viz: visualization to be done.
+                Set to 0 to disable.
+                Set to a positive number to be the delay between frames to show.
+                Set to a string to be a directory to store frames.
+            nullop_start: start with random number of null ops.
+            live_losts_as_eoe: consider lost of lives as end of episode. Useful for training.
+            max_num_frames: maximum number of frames per episode.
         """
         super(AtariPlayer, self).__init__()
         if not os.path.isfile(rom_file) and '/' not in rom_file:

@@ -65,6 +68,7 @@ class AtariPlayer(RLEnvironment):
         self.ale = ALEInterface()
         self.rng = get_rng(self)
         self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
+        self.ale.setInt(b"max_num_frames_per_episode", max_num_frames)
         self.ale.setBool(b"showinfo", False)
 
         self.ale.setInt(b"frame_skip", 1)

@@ -92,11 +96,16 @@ class AtariPlayer(RLEnvironment):
         self.live_lost_as_eoe = live_lost_as_eoe
         self.frame_skip = frame_skip
         self.nullop_start = nullop_start
-        self.height_range = height_range
-        self.image_shape = image_shape
 
         self.current_episode_score = StatCounter()
-        self.restart_episode()
+
+        self.action_space = spaces.Discrete(len(self.actions))
+        self.observation_space = spaces.Box(
+            low=0, high=255, shape=(self.height, self.width))
+        self._restart_episode()
+
+    def get_action_meanings(self):
+        return [ACTION_MEANING[i] for i in self.actions]
 
     def _grab_raw_image(self):
         """

@@ -105,7 +114,7 @@ class AtariPlayer(RLEnvironment):
         m = self.ale.getScreenRGB()
         return m.reshape((self.height, self.width, 3))
 
-    def current_state(self):
+    def _current_state(self):
         """
         :returns: a gray-scale (h, w) uint8 image
         """

@@ -116,19 +125,12 @@ class AtariPlayer(RLEnvironment):
             if isinstance(self.viz, float):
                 cv2.imshow(self.windowname, ret)
                 time.sleep(self.viz)
-        ret = ret[self.height_range[0]:self.height_range[1], :].astype('float32')
+        ret = ret.astype('float32')
         # 0.299,0.587.0.114. same as rgb2y in torch/image
         ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
-        ret = cv2.resize(ret, self.image_shape)
         return ret.astype('uint8')  # to save some memory
 
-    def get_action_space(self):
-        return DiscreteActionSpace(len(self.actions))
-
-    def finish_episode(self):
-        self.stats['score'].append(self.current_episode_score.sum)
-
-    def restart_episode(self):
+    def _restart_episode(self):
         self.current_episode_score.reset()
         with _ALE_LOCK:
             self.ale.reset_game()

@@ -141,11 +143,12 @@ class AtariPlayer(RLEnvironment):
             self.last_raw_screen = self._grab_raw_image()
             self.ale.act(0)
 
-    def action(self, act):
-        """
-        :param act: an index of the action
-        :returns: (reward, isOver)
-        """
+    def _reset(self):
+        if self.ale.game_over():
+            self._restart_episode()
+        return self._current_state()
+
+    def _step(self, act):
         oldlives = self.ale.lives()
         r = 0
         for k in range(self.frame_skip):

@@ -158,55 +161,24 @@ class AtariPlayer(RLEnvironment):
                 break
 
         self.current_episode_score.feed(r)
-        isOver = self.ale.game_over()
+        trueIsOver = isOver = self.ale.game_over()
         if self.live_lost_as_eoe:
             isOver = isOver or newlives < oldlives
-        if isOver:
-            self.finish_episode()
-        if self.ale.game_over():
-            self.restart_episode()
-        return (r, isOver)
+
+        info = {'score': self.current_episode_score.sum, 'gameOver': trueIsOver}
+        return self._current_state(), r, isOver, info
 
 
 if __name__ == '__main__':
     import sys
-
-    def benchmark():
-        a = AtariPlayer(sys.argv[1], viz=False, height_range=(28, -8))
-        num = a.get_action_space().num_actions()
-        rng = get_rng(num)
-        start = time.time()
-        cnt = 0
-        while True:
-            act = rng.choice(range(num))
-            r, o = a.action(act)
-            a.current_state()
-            cnt += 1
-            if cnt == 5000:
-                break
-        print(time.time() - start)
-
-    if len(sys.argv) == 3 and sys.argv[2] == 'benchmark':
-        import threading
-        import multiprocessing
-        for k in range(3):
-            # th = multiprocessing.Process(target=benchmark)
-            th = threading.Thread(target=benchmark)
-            th.start()
-            time.sleep(0.02)
-        benchmark()
-    else:
-        a = AtariPlayer(sys.argv[1], viz=0.03, height_range=(28, -8))
-        num = a.get_action_space().num_actions()
-        rng = get_rng(num)
-        import time
-        while True:
-            # im = a.grab_image()
-            # cv2.imshow(a.romname, im)
-            act = rng.choice(range(num))
-            print(act)
-            r, o = a.action(act)
-            a.current_state()
-            # time.sleep(0.1)
-            print(r, o)
+    a = AtariPlayer(sys.argv[1], viz=0.03)
+    num = a.action_space.n
+    rng = get_rng(num)
+    while True:
+        act = rng.choice(range(num))
+        state, reward, isOver, info = a.step(act)
+        if isOver:
+            print(info)
+            a.reset()
+        print("Reward:", reward)
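The rewritten `AtariPlayer` now describes itself with gym spaces instead of tensorpack's `DiscreteActionSpace`. A small standalone illustration of those two space types (the action count and screen size below are illustrative, not tied to a particular ROM):

```python
import numpy as np
from gym import spaces

action_space = spaces.Discrete(4)                        # e.g. NOOP/FIRE/RIGHT/LEFT
observation_space = spaces.Box(low=0, high=255, shape=(210, 160))

act = action_space.sample()                              # what the 0.001-greedy policy uses
frame = np.zeros((210, 160), dtype=np.uint8)
print(action_space.contains(act), observation_space.contains(frame))
```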
examples/DeepQNetwork/common.py  (+142 -24)

@@ -7,35 +7,56 @@ import time
 import threading
 import multiprocessing
 import numpy as np
+import cv2
+from collections import deque
 from tqdm import tqdm
 from six.moves import queue
 
-from tensorpack import *
-from tensorpack.utils.concurrency import *
-from tensorpack.utils.stats import *
+import gym
+from gym import spaces
+from tensorpack.utils.concurrency import StoppableThread, ShareSessionThread
+from tensorpack.callbacks import Triggerable
+from tensorpack.utils import logger
+from tensorpack.utils.stats import StatCounter
+from tensorpack.utils.utils import get_tqdm_kwargs
 
 
-def play_one_episode(player, func, verbose=False):
-    def f(s):
-        spc = player.get_action_space()
+def play_one_episode(env, func, render=False):
+    def predict(s):
+        """
+        Map from observation to action, with 0.001 greedy.
+        """
         act = func([[s]])[0][0].argmax()
         if random.random() < 0.001:
+            spc = env.action_space
             act = spc.sample()
-        if verbose:
-            print(act)
         return act
-    return np.mean(player.play_one_episode(f))
-
-
-def play_model(cfg, player):
-    predfunc = OfflinePredictor(cfg)
+
+    ob = env.reset()
+    sum_r = 0
     while True:
-        score = play_one_episode(player, predfunc)
-        print("Total:", score)
+        act = predict(ob)
+        ob, r, isOver, info = env.step(act)
+        if render:
+            env.render()
+        sum_r += r
+        if isOver:
+            return sum_r
+
+
+def play_n_episodes(player, predfunc, nr, render=False):
+    logger.info("Start Playing ... ")
+    for k in range(nr):
+        score = play_one_episode(player, predfunc, render=render)
+        print("{}/{}, score={}".format(k, nr, score))
 
 
 def eval_with_funcs(predictors, nr_eval, get_player_fn):
+    """
+    Args:
+        predictors ([PredictorBase])
+    """
     class Worker(StoppableThread, ShareSessionThread):
         def __init__(self, func, queue):
             super(Worker, self).__init__()

@@ -85,10 +106,14 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
                 return (0, 0)
 
 
-def eval_model_multithread(cfg, nr_eval, get_player_fn):
-    func = OfflinePredictor(cfg)
+def eval_model_multithread(pred, nr_eval, get_player_fn):
+    """
+    Args:
+        pred (OfflinePredictor): state -> Qvalue
+    """
     NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
-    mean, max = eval_with_funcs([func] * NR_PROC, nr_eval, get_player_fn)
+    with pred.sess.as_default():
+        mean, max = eval_with_funcs([pred] * NR_PROC, nr_eval, get_player_fn)
     logger.info("Average Score: {}; Max Score: {}".format(mean, max))

@@ -115,10 +140,103 @@ class Evaluator(Triggerable):
         self.trainer.monitors.put_scalar('max_score', max)
 
 
-def play_n_episodes(player, predfunc, nr):
-    logger.info("Start evaluation: ")
-    for k in range(nr):
-        if k != 0:
-            player.restart_episode()
-        score = play_one_episode(player, predfunc)
-        print("{}/{}, score={}".format(k, nr, score))
+"""
+------------------------------------------------------------------------------
+The following wrappers are copied or modified from openai/baselines:
+https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
+"""
+
+
+class WarpFrame(gym.ObservationWrapper):
+    def __init__(self, env, shape):
+        gym.ObservationWrapper.__init__(self, env)
+        self.shape = shape
+        obs = env.observation_space
+        assert isinstance(obs, spaces.Box)
+        chan = 1 if len(obs.shape) == 2 else obs.shape[2]
+        shape3d = shape if chan == 1 else shape + (chan,)
+        self.observation_space = spaces.Box(low=0, high=255, shape=shape3d)
+
+    def _observation(self, obs):
+        return cv2.resize(obs, self.shape)
+
+
+class FrameStack(gym.Wrapper):
+    def __init__(self, env, k):
+        """Buffer observations and stack across channels (last axis)."""
+        gym.Wrapper.__init__(self, env)
+        self.k = k
+        self.frames = deque([], maxlen=k)
+        shp = env.observation_space.shape
+        chan = 1 if len(shp) == 2 else shp[2]
+        self._base_dim = len(shp)
+        self.observation_space = spaces.Box(
+            low=0, high=255, shape=(shp[0], shp[1], chan * k))
+
+    def _reset(self):
+        """Clear buffer and re-fill by duplicating the first observation."""
+        ob = self.env.reset()
+        for _ in range(self.k - 1):
+            self.frames.append(np.zeros_like(ob))
+        self.frames.append(ob)
+        return self._observation()
+
+    def _step(self, action):
+        ob, reward, done, info = self.env.step(action)
+        self.frames.append(ob)
+        return self._observation(), reward, done, info
+
+    def _observation(self):
+        assert len(self.frames) == self.k
+        if self._base_dim == 2:
+            return np.stack(self.frames, axis=-1)
+        else:
+            return np.concatenate(self.frames, axis=2)
+
+
+class _FireResetEnv(gym.Wrapper):
+    def __init__(self, env):
+        """Take action on reset for environments that are fixed until firing."""
+        gym.Wrapper.__init__(self, env)
+        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
+        assert len(env.unwrapped.get_action_meanings()) >= 3
+
+    def _reset(self):
+        self.env.reset()
+        obs, _, done, _ = self.env.step(1)
+        if done:
+            self.env.reset()
+        obs, _, done, _ = self.env.step(2)
+        if done:
+            self.env.reset()
+        return obs
+
+
+def FireResetEnv(env):
+    if isinstance(env, gym.Wrapper):
+        baseenv = env.unwrapped
+    else:
+        baseenv = env
+    if 'FIRE' in baseenv.get_action_meanings():
+        return _FireResetEnv(env)
+    return env
+
+
+class LimitLength(gym.Wrapper):
+    def __init__(self, env, k):
+        gym.Wrapper.__init__(self, env)
+        self.k = k
+
+    def _reset(self):
+        # This assumes that reset() will really reset the env.
+        # If the underlying env tries to be smart about reset
+        # (e.g. end-of-life), the assumption doesn't hold.
+        ob = self.env.reset()
+        self.cnt = 0
+        return ob
+
+    def _step(self, action):
+        ob, r, done, info = self.env.step(action)
+        self.cnt += 1
+        if self.cnt == self.k:
+            done = True
+        return ob, r, done, info
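A quick standalone check of the stacking rule `FrameStack._observation()` applies to 2-D (grayscale) frames: `k` frames of shape `(84, 84)` become one `(84, 84, k)` observation, with the buffer padded by zero frames on reset as in `_reset()` above:

```python
import numpy as np
from collections import deque

k = 4
frames = deque(maxlen=k)
first_ob = np.random.randint(0, 256, size=(84, 84), dtype=np.uint8)
for _ in range(k - 1):
    frames.append(np.zeros_like(first_ob))   # the padding _reset() inserts
frames.append(first_ob)

stacked = np.stack(frames, axis=-1)          # the 2-D branch of _observation()
print(stacked.shape)                         # (84, 84, 4)
```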
examples/DeepQNetwork/expreplay.py  (+19 -14)

@@ -13,6 +13,7 @@ from six.moves import queue, range
 from tensorpack.dataflow import DataFlow
 from tensorpack.utils import logger
 from tensorpack.utils.utils import get_tqdm, get_rng
+from tensorpack.utils.stats import StatCounter
 from tensorpack.utils.concurrency import LoopThread, ShareSessionThread
 from tensorpack.callbacks.base import Callback

@@ -142,7 +143,7 @@ class ExpReplay(DataFlow, Callback):
             if k != 'self':
                 setattr(self, k, v)
         self.exploration = init_exploration
-        self.num_actions = player.get_action_space().num_actions()
+        self.num_actions = player.action_space.n
         logger.info("Number of Legal actions: {}".format(self.num_actions))
 
         self.rng = get_rng(self)

@@ -152,6 +153,8 @@ class ExpReplay(DataFlow, Callback):
         self._populate_job_queue = queue.Queue(maxsize=5)
 
         self.mem = ReplayMemory(memory_size, state_shape, history_len)
+        self._current_ob = self.player.reset()
+        self._player_scores = StatCounter()
 
     def get_simulator_thread(self):
         # spawn a separate thread to run policy

@@ -186,7 +189,7 @@ class ExpReplay(DataFlow, Callback):
     def _populate_exp(self):
         """ populate a transition by epsilon-greedy"""
-        old_s = self.player.current_state()
+        old_s = self._current_ob
         if self.rng.rand() <= self.exploration or (len(self.mem) <= self.history_len):
             act = self.rng.choice(range(self.num_actions))
         else:

@@ -198,7 +201,11 @@ class ExpReplay(DataFlow, Callback):
             # assume batched network
             q_values = self.predictor([[history]])[0][0]  # this is the bottleneck
             act = np.argmax(q_values)
-        reward, isOver = self.player.action(act)
+        self._current_ob, reward, isOver, info = self.player.step(act)
+        if isOver:
+            if info['gameOver']:  # only record score when a whole game is over (not when an episode is over)
+                self._player_scores.feed(info['score'])
+            self.player.reset()
         self.mem.append(Experience(old_s, act, reward, isOver))
 
     def _debug_sample(self, sample):

@@ -245,17 +252,15 @@ class ExpReplay(DataFlow, Callback):
         self._simulator_th = self.get_simulator_thread()
         self._simulator_th.start()
 
-    def _trigger_epoch(self):
-        # log player statistics in training
-        stats = self.player.stats
-        for k, v in six.iteritems(stats):
-            try:
-                mean, max = np.mean(v), np.max(v)
-                self.trainer.monitors.put_scalar('expreplay/mean_' + k, mean)
-                self.trainer.monitors.put_scalar('expreplay/max_' + k, max)
-            except:
-                logger.exception("Cannot log training scores.")
-        self.player.reset_stat()
+    def _trigger(self):
+        v = self._player_scores
+        try:
+            mean, max = v.average, v.max
+            self.trainer.monitors.put_scalar('expreplay/mean_score', mean)
+            self.trainer.monitors.put_scalar('expreplay/max_score', max)
+        except:
+            logger.exception("Cannot log training scores.")
+        v.reset()
 
 
 if __name__ == '__main__':
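The change to `_populate_exp()` boils down to an epsilon-greedy step against the gym API plus score bookkeeping through the `info` dict. A hedged, standalone sketch of that step (`env` and `predictor` are hypothetical stand-ins, not tensorpack objects):

```python
import random
import numpy as np

def populate_one(env, predictor, current_ob, exploration, num_actions):
    """One epsilon-greedy transition, mirroring the new _populate_exp() flow."""
    old_s = current_ob
    if random.random() <= exploration:
        act = random.randrange(num_actions)
    else:
        q_values = predictor(old_s)          # assumed to return an array of Q-values
        act = int(np.argmax(q_values))
    new_ob, reward, is_over, info = env.step(act)
    if is_over:
        # a real game over (info['gameOver']) is where the score gets recorded
        new_ob = env.reset()
    return new_ob, (old_s, act, reward, is_over)
```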
tensorpack/RL/__init__.py  (+3 -0)

@@ -3,6 +3,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 
 from pkgutil import iter_modules
+from ..utils.develop import log_deprecated
 import os
 import os.path

@@ -13,6 +14,8 @@ __all__ = []
 This module should be removed in the future.
 """
 
+log_deprecated("tensorpack.RL", "Please use gym or other APIs instead!", "2017-12-31")
+
 
 def _global_import(name):
     p = __import__(name, globals(), locals(), level=1)
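For readers without tensorpack at hand: `log_deprecated(name, text, eol_date)` logs a deprecation notice together with an end-of-support date. A rough standard-library equivalent of the call above (wording mirrors the arguments; this is not tensorpack's implementation):

```python
import warnings

warnings.warn(
    "tensorpack.RL is deprecated and will be removed after 2017-12-31. "
    "Please use gym or other APIs instead!",
    DeprecationWarning)
```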