seminar-breakout (Shashank Suhas) · Commit 7e963996, authored Oct 12, 2017 by Yuxin Wu
deprecated tensorpack.RL and use gym for RL examples
parent c270a1ed
Showing 8 changed files with 267 additions and 165 deletions.
CHANGES.md                            +2    -0
examples/A3C-Gym/simulator.py         +10   -8
examples/A3C-Gym/train-atari.py       +22   -19
examples/DeepQNetwork/DQN.py          +15   -18
examples/DeepQNetwork/atari.py        +54   -82
examples/DeepQNetwork/common.py       +142  -24
examples/DeepQNetwork/expreplay.py    +19   -14
tensorpack/RL/__init__.py             +3    -0
CHANGES.md

@@ -8,6 +8,8 @@ so you won't need to look at here very often.
 Here are a list of things that were changed, starting from an early version.
 TensorFlow itself also changed APIs before 1.0 and those are not listed here.
++ [2017/10/10](https://github.com/ppwwyyxx/tensorpack/commit/7d40e049691d92018f50dc7d45bba5e8b140becc).
+  `tfutils.distributions` was deprecated in favor of `tf.distributions` introduced in TF 1.3.
 + [2017/08/02](https://github.com/ppwwyyxx/tensorpack/commit/875f4d7dbb5675f54eae5675fa3a0948309a8465).
   `Trainer.get_predictor` now takes GPU id. And `Trainer.get_predictors` was deprecated.
 + 2017/06/07. Now the library explicitly depends on msgpack-numpy>=0.3.9. The serialization protocol
examples/A3C-Gym/simulator.py

@@ -78,18 +78,21 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
         s2c_socket = context.socket(zmq.DEALER)
         s2c_socket.setsockopt(zmq.IDENTITY, self.identity)
         # s2c_socket.set_hwm(5)
         s2c_socket.connect(self.s2c)

-        state = player.current_state()
+        state = player.reset()
         reward, isOver = 0, False
         while True:
+            # after taking the last action, get to this state and get this reward/isOver.
+            # If isOver, get to the next-episode state immediately.
+            # This tuple is not the same as the one put into the memory buffer
             c2s_socket.send(dumps(
                 (self.identity, state, reward, isOver)),
                 copy=False)
             action = loads(s2c_socket.recv(copy=False).bytes)
-            reward, isOver = player.action(action)
-            state = player.current_state()
+            state, reward, isOver, _ = player.step(action)
+            if isOver:
+                state = player.reset()

     # compatibility
@@ -180,17 +183,16 @@ class SimulatorMaster(threading.Thread):
 if __name__ == '__main__':
     import random
-    from tensorpack.RL import NaiveRLEnvironment
+    import gym

     class NaiveSimulator(SimulatorProcess):
         def _build_player(self):
-            return NaiveRLEnvironment()
+            return gym.make('Breakout-v0')

     class NaiveActioner(SimulatorMaster):
         def _get_action(self, state):
             time.sleep(1)
-            return random.randint(1, 12)
+            return random.randint(1, 3)

         def _on_episode_over(self, client):
             # print("Over: ", client.memory)
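The worker now drives a standard gym environment instead of a tensorpack RL player. Below is a minimal sketch of the same interaction pattern, assuming the 2017-era gym API used throughout this commit and an installed Atari gym (Breakout-v0, as in NaiveSimulator above); the random action is only a stand-in for the policy the real worker receives over ZMQ.

import gym

env = gym.make('Breakout-v0')
state = env.reset()                     # replaces player.current_state()
done = False
total_reward = 0.0
while not done:
    action = env.action_space.sample()  # placeholder for the A3C policy
    state, reward, done, _ = env.step(action)   # replaces player.action(action)
    total_reward += reward
print("episode reward:", total_reward)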
examples/A3C-Gym/train-atari.py

@@ -27,11 +27,12 @@ from tensorpack.tfutils.gradproc import MapGradient, SummaryGradient
 from tensorpack.utils.gpu import get_nr_gpu
-from tensorpack.RL import *
+import gym
 from simulator import *
 import common
-from common import (play_model, Evaluator, eval_model_multithread,
-                    play_one_episode, play_n_episodes)
+from common import (Evaluator, eval_model_multithread,
+                    play_one_episode, play_n_episodes,
+                    WarpFrame, FrameStack, FireResetEnv, LimitLength)

 if six.PY3:
     from concurrent import futures
@@ -58,15 +59,16 @@ NUM_ACTIONS = None
 ENV_NAME = None


-def get_player(viz=False, train=False, dumpdir=None):
-    pl = GymEnv(ENV_NAME, viz=viz, dumpdir=dumpdir)
-    pl = MapPlayerState(pl, lambda img: cv2.resize(img, IMAGE_SIZE[::-1]))
-    pl = HistoryFramePlayer(pl, FRAME_HISTORY)
-    if not train:
-        pl = PreventStuckPlayer(pl, 30, 1)
-    else:
-        pl = LimitLengthPlayer(pl, 60000)
-    return pl
+def get_player(train=False, dumpdir=None):
+    env = gym.make(ENV_NAME)
+    if dumpdir:
+        env = gym.wrappers.Monitor(env, dumpdir)
+    env = FireResetEnv(env)
+    env = WarpFrame(env, IMAGE_SIZE)
+    env = FrameStack(env, 4)
+    if train:
+        env = LimitLength(env, 60000)
+    return env


 class MySimulatorWorker(SimulatorProcess):
@@ -272,7 +274,7 @@ if __name__ == '__main__':
     ENV_NAME = args.env
     logger.info("Environment Name: {}".format(ENV_NAME))
-    NUM_ACTIONS = get_player().get_action_space().num_actions()
+    NUM_ACTIONS = get_player().action_space.n
     logger.info("Number of actions: {}".format(NUM_ACTIONS))

     if args.gpu:
@@ -280,20 +282,21 @@ if __name__ == '__main__':
     if args.task != 'train':
         assert args.load is not None
-        cfg = PredictConfig(
+        pred = OfflinePredictor(PredictConfig(
             model=Model(),
             session_init=get_model_loader(args.load),
             input_names=['state'],
             output_names=['policy'])
+        )
         if args.task == 'play':
-            play_model(cfg, get_player(viz=0.01))
+            play_n_episodes(get_player(train=False), pred, args.episode, render=True)
         elif args.task == 'eval':
-            eval_model_multithread(cfg, args.episode, get_player)
+            eval_model_multithread(pred, args.episode, get_player)
         elif args.task == 'gen_submit':
             play_n_episodes(
                 get_player(train=False, dumpdir=args.output),
-                OfflinePredictor(cfg), args.episode)
-            # gym.upload(output, api_key='xxx')
+                pred, args.episode)
+            # gym.upload(args.output, api_key='xxx')
     else:
         dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
         logger.set_logger_dir(dirname)
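For reference, the new get_player() builds a plain gym environment and layers the wrappers imported from common (defined in examples/DeepQNetwork/common.py) on top of it. A hedged sketch of that chain, assuming gym with Atari support and a (84, 84) image size (the value used by the DQN example; the A3C script's own IMAGE_SIZE constant is not shown in this diff):

import gym
from common import WarpFrame, FrameStack, FireResetEnv, LimitLength

env = gym.make('Breakout-v0')
env = FireResetEnv(env)           # press FIRE on reset for games that need it
env = WarpFrame(env, (84, 84))    # resize observations
env = FrameStack(env, 4)          # stack the last 4 frames along the last axis
env = LimitLength(env, 60000)     # cap episode length (applied only when training in the diff)

print(env.action_space.n)         # replaces get_action_space().num_actions()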
examples/DeepQNetwork/DQN.py

@@ -18,14 +18,14 @@ from collections import deque
 from tensorpack import *
 from tensorpack.utils.concurrency import *
-from tensorpack.RL import *
 import tensorflow as tf

 from DQNModel import Model as DQNModel
 import common
-from common import play_model, Evaluator, eval_model_multithread
-from atari import AtariPlayer
+from common import Evaluator, eval_model_multithread, play_n_episodes
+from common import FrameStack, WarpFrame, FireResetEnv
 from expreplay import ExpReplay
+from atari import AtariPlayer

 BATCH_SIZE = 64
 IMAGE_SIZE = (84, 84)
@@ -37,7 +37,7 @@ GAMMA = 0.99
 MEMORY_SIZE = 1e6
 # will consume at least 1e6 * 84 * 84 bytes == 6.6G memory.
-INIT_MEMORY_SIZE = 5e4
+INIT_MEMORY_SIZE = MEMORY_SIZE // 20
 STEPS_PER_EPOCH = 10000 // UPDATE_FREQ * 10  # each epoch is 100k played frames
 EVAL_EPISODE = 50
@@ -47,17 +47,14 @@ METHOD = None
 def get_player(viz=False, train=False):
-    pl = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT,
-                     viz=viz, image_shape=IMAGE_SIZE[::-1],
-                     live_lost_as_eoe=train)
+    env = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT, viz=viz,
+                      live_lost_as_eoe=train, max_num_frames=30000)
+    env = FireResetEnv(env)
+    env = WarpFrame(env, IMAGE_SIZE)
     if not train:
-        # create a new axis to stack history on
-        pl = MapPlayerState(pl, lambda im: im[:, :, np.newaxis])
         # in training, history is taken care of in expreplay buffer
-        pl = HistoryFramePlayer(pl, FRAME_HISTORY)
-        pl = PreventStuckPlayer(pl, 30, 1)
-    pl = LimitLengthPlayer(pl, 30000)
-    return pl
+        env = FrameStack(env, FRAME_HISTORY)
+    return env


 class Model(DQNModel):
@@ -149,20 +146,20 @@ if __name__ == '__main__':
     ROM_FILE = args.rom
     METHOD = args.algo
     # set num_actions
-    NUM_ACTIONS = AtariPlayer(ROM_FILE).get_action_space().num_actions()
+    NUM_ACTIONS = AtariPlayer(ROM_FILE).action_space.n
     logger.info("ROM: {}, Num Actions: {}".format(ROM_FILE, NUM_ACTIONS))

     if args.task != 'train':
         assert args.load is not None
-        cfg = PredictConfig(
+        pred = OfflinePredictor(PredictConfig(
             model=Model(),
             session_init=get_model_loader(args.load),
             input_names=['state'],
             output_names=['Qvalue'])
+        )
         if args.task == 'play':
-            play_model(cfg, get_player(viz=0.01))
+            play_n_episodes(get_player(viz=0.01), pred, 100)
         elif args.task == 'eval':
-            eval_model_multithread(cfg, EVAL_EPISODE, get_player)
+            eval_model_multithread(pred, EVAL_EPISODE, get_player)
     else:
         logger.set_logger_dir(
             os.path.join('train_log', 'DQN-{}'.format(
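A hedged sketch of the new inference path in DQN.py: the model is wrapped in an OfflinePredictor that maps 'state' to 'Qvalue', and play_n_episodes drives the gym-style player with it. The sketch assumes the names Model and get_player from DQN.py are in scope; '/path/to/checkpoint' is a hypothetical stand-in for args.load.

from tensorpack import OfflinePredictor, PredictConfig, get_model_loader
from common import play_n_episodes

pred = OfflinePredictor(PredictConfig(
    model=Model(),                                          # the DQN model defined in DQN.py
    session_init=get_model_loader('/path/to/checkpoint'),   # hypothetical checkpoint path
    input_names=['state'],
    output_names=['Qvalue']))
play_n_episodes(get_player(viz=0.01), pred, 100)            # roughly what `--task play` now does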
examples/DeepQNetwork/atari.py

@@ -7,7 +7,6 @@ import numpy as np
 import time
 import os
 import cv2
-from collections import deque
 import threading
 import six
 from six.moves import range
@@ -16,7 +15,9 @@ from tensorpack.utils.utils import get_rng, execute_only_once
 from tensorpack.utils.fs import get_dataset_path
 from tensorpack.utils.stats import StatCounter
-from tensorpack.RL.envbase import RLEnvironment, DiscreteActionSpace
+import gym
+from gym import spaces
+from gym.envs.atari.atari_env import ACTION_MEANING
 from ale_python_interface import ALEInterface
@@ -26,27 +27,29 @@ ROM_URL = "https://github.com/openai/atari-py/tree/master/atari_py/atari_roms"
 _ALE_LOCK = threading.Lock()


-class AtariPlayer(RLEnvironment):
+class AtariPlayer(gym.Env):
     """
-    A wrapper for atari emulator.
-    Will automatically restart when a real episode ends (isOver might be just
-    lost of lives but not game over).
+    A wrapper for ALE emulator, with configurations to mimic DeepMind DQN settings.
+
+    Info:
+        score: the accumulated reward in the current game
+        gameOver: True when the current game is Over
     """

-    def __init__(self, rom_file, viz=0, height_range=(None, None),
-                 frame_skip=4, image_shape=(84, 84), nullop_start=30,
-                 live_lost_as_eoe=True):
+    def __init__(self, rom_file, viz=0,
+                 frame_skip=4, nullop_start=30,
+                 live_lost_as_eoe=True, max_num_frames=0):
         """
-        :param rom_file: path to the rom
-        :param frame_skip: skip every k frames and repeat the action
-        :param image_shape: (w, h)
-        :param height_range: (h1, h2) to cut
-        :param viz: visualization to be done.
-            Set to 0 to disable.
-            Set to a positive number to be the delay between frames to show.
-            Set to a string to be a directory to store frames.
-        :param nullop_start: start with random number of null ops
-        :param live_losts_as_eoe: consider lost of lives as end of episode. useful for training.
+        Args:
+            rom_file: path to the rom
+            frame_skip: skip every k frames and repeat the action
+            viz: visualization to be done.
+                Set to 0 to disable.
+                Set to a positive number to be the delay between frames to show.
+                Set to a string to be a directory to store frames.
+            nullop_start: start with random number of null ops.
+            live_losts_as_eoe: consider lost of lives as end of episode. Useful for training.
+            max_num_frames: maximum number of frames per episode.
         """
         super(AtariPlayer, self).__init__()
         if not os.path.isfile(rom_file) and '/' not in rom_file:
@@ -65,6 +68,7 @@ class AtariPlayer(RLEnvironment):
         self.ale = ALEInterface()
         self.rng = get_rng(self)
         self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
+        self.ale.setInt(b"max_num_frames_per_episode", max_num_frames)
         self.ale.setBool(b"showinfo", False)

         self.ale.setInt(b"frame_skip", 1)
@@ -92,11 +96,16 @@ class AtariPlayer(RLEnvironment):
         self.live_lost_as_eoe = live_lost_as_eoe
         self.frame_skip = frame_skip
         self.nullop_start = nullop_start
-        self.height_range = height_range
-        self.image_shape = image_shape

         self.current_episode_score = StatCounter()
-        self.restart_episode()
+
+        self.action_space = spaces.Discrete(len(self.actions))
+        self.observation_space = spaces.Box(
+            low=0, high=255, shape=(self.height, self.width))
+        self._restart_episode()
+
+    def get_action_meanings(self):
+        return [ACTION_MEANING[i] for i in self.actions]

     def _grab_raw_image(self):
         """
@@ -105,7 +114,7 @@ class AtariPlayer(RLEnvironment):
         m = self.ale.getScreenRGB()
         return m.reshape((self.height, self.width, 3))

-    def current_state(self):
+    def _current_state(self):
         """
         :returns: a gray-scale (h, w) uint8 image
         """
@@ -116,19 +125,12 @@ class AtariPlayer(RLEnvironment):
             if isinstance(self.viz, float):
                 cv2.imshow(self.windowname, ret)
                 time.sleep(self.viz)
-        ret = ret[self.height_range[0]:self.height_range[1], :].astype('float32')
+        ret = ret.astype('float32')
         # 0.299,0.587.0.114. same as rgb2y in torch/image
         ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
-        ret = cv2.resize(ret, self.image_shape)
         return ret.astype('uint8')  # to save some memory

-    def get_action_space(self):
-        return DiscreteActionSpace(len(self.actions))
-
-    def finish_episode(self):
-        self.stats['score'].append(self.current_episode_score.sum)
-
-    def restart_episode(self):
+    def _restart_episode(self):
         self.current_episode_score.reset()
         with _ALE_LOCK:
             self.ale.reset_game()
@@ -141,11 +143,12 @@ class AtariPlayer(RLEnvironment):
             self.last_raw_screen = self._grab_raw_image()
             self.ale.act(0)

-    def action(self, act):
-        """
-        :param act: an index of the action
-        :returns: (reward, isOver)
-        """
+    def _reset(self):
+        if self.ale.game_over():
+            self._restart_episode()
+        return self._current_state()
+
+    def _step(self, act):
         oldlives = self.ale.lives()
         r = 0
         for k in range(self.frame_skip):
@@ -158,55 +161,24 @@ class AtariPlayer(RLEnvironment):
                 break
         self.current_episode_score.feed(r)
-        isOver = self.ale.game_over()
+        trueIsOver = isOver = self.ale.game_over()
         if self.live_lost_as_eoe:
             isOver = isOver or newlives < oldlives
-        if isOver:
-            self.finish_episode()
-        if self.ale.game_over():
-            self.restart_episode()
-        return (r, isOver)
+
+        info = {'score': self.current_episode_score.sum, 'gameOver': trueIsOver}
+        return self._current_state(), r, isOver, info


 if __name__ == '__main__':
     import sys

-    def benchmark():
-        a = AtariPlayer(sys.argv[1], viz=False, height_range=(28, -8))
-        num = a.get_action_space().num_actions()
-        rng = get_rng(num)
-        start = time.time()
-        cnt = 0
-        while True:
-            act = rng.choice(range(num))
-            r, o = a.action(act)
-            a.current_state()
-            cnt += 1
-            if cnt == 5000:
-                break
-        print(time.time() - start)
-
-    if len(sys.argv) == 3 and sys.argv[2] == 'benchmark':
-        import threading
-        import multiprocessing
-        for k in range(3):
-            # th = multiprocessing.Process(target=benchmark)
-            th = threading.Thread(target=benchmark)
-            th.start()
-            time.sleep(0.02)
-        benchmark()
-    else:
-        a = AtariPlayer(sys.argv[1], viz=0.03, height_range=(28, -8))
-        num = a.get_action_space().num_actions()
-        rng = get_rng(num)
-        import time
-        while True:
-            # im = a.grab_image()
-            # cv2.imshow(a.romname, im)
-            act = rng.choice(range(num))
-            print(act)
-            r, o = a.action(act)
-            a.current_state()
-            # time.sleep(0.1)
-            print(r, o)
+    a = AtariPlayer(sys.argv[1], viz=0.03)
+    num = a.action_space.n
+    rng = get_rng(num)
+    while True:
+        act = rng.choice(range(num))
+        state, reward, isOver, info = a.step(act)
+        if isOver:
+            print(info)
+            a.reset()
+        print("Reward:", reward)
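With AtariPlayer now subclassing gym.Env, downstream code interacts with it through reset()/step() and reads the score from the info dict instead of calling finish_episode(). A minimal hedged sketch of that usage ('breakout.bin' is a hypothetical ROM path):

from atari import AtariPlayer

env = AtariPlayer('breakout.bin', viz=0)      # hypothetical ROM path
print(env.action_space.n)                     # replaces get_action_space().num_actions()
state = env.reset()
while True:
    state, reward, isOver, info = env.step(env.action_space.sample())
    if isOver:
        # info carries the accumulated 'score' and whether the real game ended
        print(info['score'], info['gameOver'])
        env.reset()
        break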
examples/DeepQNetwork/common.py

@@ -7,35 +7,56 @@ import time
 import threading
 import multiprocessing
 import numpy as np
+import cv2
+from collections import deque
 from tqdm import tqdm
 from six.moves import queue

-from tensorpack import *
-from tensorpack.utils.concurrency import *
-from tensorpack.utils.stats import *
+import gym
+from gym import spaces
+
+from tensorpack.utils.concurrency import StoppableThread, ShareSessionThread
+from tensorpack.callbacks import Triggerable
+from tensorpack.utils import logger
+from tensorpack.utils.stats import StatCounter
 from tensorpack.utils.utils import get_tqdm_kwargs


-def play_one_episode(player, func, verbose=False):
-    def f(s):
-        spc = player.get_action_space()
+def play_one_episode(env, func, render=False):
+    def predict(s):
+        """
+        Map from observation to action, with 0.001 greedy.
+        """
         act = func([[s]])[0][0].argmax()
         if random.random() < 0.001:
+            spc = env.action_space
             act = spc.sample()
-        if verbose:
-            print(act)
         return act
-    return np.mean(player.play_one_episode(f))

-
-def play_model(cfg, player):
-    predfunc = OfflinePredictor(cfg)
+    ob = env.reset()
+    sum_r = 0
     while True:
-        score = play_one_episode(player, predfunc)
-        print("Total:", score)
+        act = predict(ob)
+        ob, r, isOver, info = env.step(act)
+        if render:
+            env.render()
+        sum_r += r
+        if isOver:
+            return sum_r
+
+
+def play_n_episodes(player, predfunc, nr, render=False):
+    logger.info("Start Playing ... ")
+    for k in range(nr):
+        score = play_one_episode(player, predfunc, render=render)
+        print("{}/{}, score={}".format(k, nr, score))


 def eval_with_funcs(predictors, nr_eval, get_player_fn):
+    """
+    Args:
+        predictors ([PredictorBase])
+    """
     class Worker(StoppableThread, ShareSessionThread):
         def __init__(self, func, queue):
             super(Worker, self).__init__()
@@ -85,10 +106,14 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
             return (0, 0)


-def eval_model_multithread(cfg, nr_eval, get_player_fn):
-    func = OfflinePredictor(cfg)
+def eval_model_multithread(pred, nr_eval, get_player_fn):
+    """
+    Args:
+        pred (OfflinePredictor): state -> Qvalue
+    """
     NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
-    mean, max = eval_with_funcs([func] * NR_PROC, nr_eval, get_player_fn)
+    with pred.sess.as_default():
+        mean, max = eval_with_funcs([pred] * NR_PROC, nr_eval, get_player_fn)
     logger.info("Average Score: {}; Max Score: {}".format(mean, max))
@@ -115,10 +140,103 @@ class Evaluator(Triggerable):
         self.trainer.monitors.put_scalar('max_score', max)


-def play_n_episodes(player, predfunc, nr):
-    logger.info("Start evaluation: ")
-    for k in range(nr):
-        if k != 0:
-            player.restart_episode()
-        score = play_one_episode(player, predfunc)
-        print("{}/{}, score={}".format(k, nr, score))
+"""
+------------------------------------------------------------------------------
+The following wrappers are copied or modified from openai/baselines:
+https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
+"""
+
+
+class WarpFrame(gym.ObservationWrapper):
+    def __init__(self, env, shape):
+        gym.ObservationWrapper.__init__(self, env)
+        self.shape = shape
+        obs = env.observation_space
+        assert isinstance(obs, spaces.Box)
+        chan = 1 if len(obs.shape) == 2 else obs.shape[2]
+        shape3d = shape if chan == 1 else shape + (chan,)
+        self.observation_space = spaces.Box(low=0, high=255, shape=shape3d)
+
+    def _observation(self, obs):
+        return cv2.resize(obs, self.shape)
+
+
+class FrameStack(gym.Wrapper):
+    def __init__(self, env, k):
+        """Buffer observations and stack across channels (last axis)."""
+        gym.Wrapper.__init__(self, env)
+        self.k = k
+        self.frames = deque([], maxlen=k)
+        shp = env.observation_space.shape
+        chan = 1 if len(shp) == 2 else shp[2]
+        self._base_dim = len(shp)
+        self.observation_space = spaces.Box(
+            low=0, high=255, shape=(shp[0], shp[1], chan * k))
+
+    def _reset(self):
+        """Clear buffer and re-fill by duplicating the first observation."""
+        ob = self.env.reset()
+        for _ in range(self.k - 1):
+            self.frames.append(np.zeros_like(ob))
+        self.frames.append(ob)
+        return self._observation()
+
+    def _step(self, action):
+        ob, reward, done, info = self.env.step(action)
+        self.frames.append(ob)
+        return self._observation(), reward, done, info
+
+    def _observation(self):
+        assert len(self.frames) == self.k
+        if self._base_dim == 2:
+            return np.stack(self.frames, axis=-1)
+        else:
+            return np.concatenate(self.frames, axis=2)
+
+
+class _FireResetEnv(gym.Wrapper):
+    def __init__(self, env):
+        """Take action on reset for environments that are fixed until firing."""
+        gym.Wrapper.__init__(self, env)
+        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
+        assert len(env.unwrapped.get_action_meanings()) >= 3
+
+    def _reset(self):
+        self.env.reset()
+        obs, _, done, _ = self.env.step(1)
+        if done:
+            self.env.reset()
+        obs, _, done, _ = self.env.step(2)
+        if done:
+            self.env.reset()
+        return obs
+
+
+def FireResetEnv(env):
+    if isinstance(env, gym.Wrapper):
+        baseenv = env.unwrapped
+    else:
+        baseenv = env
+    if 'FIRE' in baseenv.get_action_meanings():
+        return _FireResetEnv(env)
+    return env
+
+
+class LimitLength(gym.Wrapper):
+    def __init__(self, env, k):
+        gym.Wrapper.__init__(self, env)
+        self.k = k
+
+    def _reset(self):
+        # This assumes that reset() will really reset the env.
+        # If the underlying env tries to be smart about reset
+        # (e.g. end-of-life), the assumption doesn't hold.
+        ob = self.env.reset()
+        self.cnt = 0
+        return ob
+
+    def _step(self, action):
+        ob, r, done, info = self.env.step(action)
+        self.cnt += 1
+        if self.cnt == self.k:
+            done = True
+        return ob, r, done, info
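The predict helper inside the new play_one_episode is a 0.001-greedy policy: it takes the argmax over the predicted Q-values but, with small probability, samples a random action from the environment's action space. A small standalone sketch of that rule (the name choose_action is illustrative, not part of the example code):

import random
import numpy as np

def choose_action(q_values, action_space, eps=0.001):
    act = int(np.argmax(q_values))       # greedy choice
    if random.random() < eps:            # rare random action, as in predict()
        act = action_space.sample()
    return act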
examples/DeepQNetwork/expreplay.py

@@ -13,6 +13,7 @@ from six.moves import queue, range
 from tensorpack.dataflow import DataFlow
 from tensorpack.utils import logger
 from tensorpack.utils.utils import get_tqdm, get_rng
+from tensorpack.utils.stats import StatCounter
 from tensorpack.utils.concurrency import LoopThread, ShareSessionThread
 from tensorpack.callbacks.base import Callback
@@ -142,7 +143,7 @@ class ExpReplay(DataFlow, Callback):
             if k != 'self':
                 setattr(self, k, v)
         self.exploration = init_exploration
-        self.num_actions = player.get_action_space().num_actions()
+        self.num_actions = player.action_space.n
         logger.info("Number of Legal actions: {}".format(self.num_actions))
         self.rng = get_rng(self)
@@ -152,6 +153,8 @@ class ExpReplay(DataFlow, Callback):
         self._populate_job_queue = queue.Queue(maxsize=5)

         self.mem = ReplayMemory(memory_size, state_shape, history_len)
+        self._current_ob = self.player.reset()
+        self._player_scores = StatCounter()

     def get_simulator_thread(self):
         # spawn a separate thread to run policy
@@ -186,7 +189,7 @@ class ExpReplay(DataFlow, Callback):
     def _populate_exp(self):
         """ populate a transition by epsilon-greedy"""
-        old_s = self.player.current_state()
+        old_s = self._current_ob
         if self.rng.rand() <= self.exploration or (len(self.mem) <= self.history_len):
             act = self.rng.choice(range(self.num_actions))
         else:
@@ -198,7 +201,11 @@ class ExpReplay(DataFlow, Callback):
             # assume batched network
             q_values = self.predictor([[history]])[0][0]  # this is the bottleneck
             act = np.argmax(q_values)
-        reward, isOver = self.player.action(act)
+        self._current_ob, reward, isOver, info = self.player.step(act)
+        if isOver:
+            if info['gameOver']:  # only record score when a whole game is over (not when an episode is over)
+                self._player_scores.feed(info['score'])
+            self.player.reset()
         self.mem.append(Experience(old_s, act, reward, isOver))

     def _debug_sample(self, sample):
@@ -245,17 +252,15 @@ class ExpReplay(DataFlow, Callback):
         self._simulator_th = self.get_simulator_thread()
         self._simulator_th.start()

-    def _trigger_epoch(self):
-        # log player statistics in training
-        stats = self.player.stats
-        for k, v in six.iteritems(stats):
-            try:
-                mean, max = np.mean(v), np.max(v)
-                self.trainer.monitors.put_scalar('expreplay/mean_' + k, mean)
-                self.trainer.monitors.put_scalar('expreplay/max_' + k, max)
-            except:
-                logger.exception("Cannot log training scores.")
-        self.player.reset_stat()
+    def _trigger(self):
+        v = self._player_scores
+        try:
+            mean, max = v.average, v.max
+            self.trainer.monitors.put_scalar('expreplay/mean_score', mean)
+            self.trainer.monitors.put_scalar('expreplay/max_score', max)
+        except:
+            logger.exception("Cannot log training scores.")
+        v.reset()


 if __name__ == '__main__':
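Training scores are now aggregated in a tensorpack StatCounter (fed from info['score'] whenever info['gameOver'] is true) and flushed to the monitors in _trigger. A small hedged sketch of that bookkeeping, using only the calls that appear in the diff; the numeric scores here are made-up placeholders:

from tensorpack.utils.stats import StatCounter

scores = StatCounter()
scores.feed(12.0)     # stands in for scores.feed(info['score']) on game over
scores.feed(30.0)
mean, max_score = scores.average, scores.max   # what _trigger sends via put_scalar
print(mean, max_score)
scores.reset()        # cleared after each trigger, as in the diff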
tensorpack/RL/__init__.py

@@ -3,6 +3,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 from pkgutil import iter_modules
+from ..utils.develop import log_deprecated
 import os
 import os.path
@@ -13,6 +14,8 @@ __all__ = []
 This module should be removed in the future.
 """

+log_deprecated("tensorpack.RL", "Please use gym or other APIs instead!", "2017-12-31")
+

 def _global_import(name):
     p = __import__(name, globals(), locals(), level=1)