Commit 0b561b3b authored Mar 05, 2019 by Yuxin Wu

Make DQN support states with more dimensions

parent 87fad54b

Showing 10 changed files with 92 additions and 61 deletions (+92 / -61)
examples/A3C-Gym/train-atari.py           +7   -4
examples/DeepQNetwork/DQN.py              +15  -6
examples/DeepQNetwork/DQNModel.py         +29  -20
examples/DeepQNetwork/atari.py            +3   -3
examples/DeepQNetwork/atari_wrapper.py    +3   -11
examples/DeepQNetwork/common.py           +4   -1
examples/DeepQNetwork/expreplay.py        +20  -14
tensorpack/callbacks/base.py              +7   -1
tensorpack/input_source/input_source.py   +3   -1
tensorpack/train/trainers.py              +1   -0
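Note (not part of the diff): the thrust of the commit is a change of state layout. Previously the DQN code assumed image observations of shape (H, W, C) and stored the frame history by concatenating along the channel axis, giving (H, W, history * C). After this commit an observation may have any shape, for example (H, W) for the grayscale ALE wrapper or (H, W, 3) for gym, and the history is stacked on a new last axis, giving state_shape + (history,). A minimal NumPy sketch of the two layouts, with made-up sizes:

    import numpy as np

    H, W, C, HIST = 84, 84, 3, 4
    frames = [np.zeros((H, W, C), dtype=np.uint8) for _ in range(HIST)]

    # old layout: history merged into the channel axis
    old_state = np.concatenate(frames, axis=-1)      # shape (84, 84, 12)

    # new layout: history on a new last axis, works for any state_shape
    new_state = np.stack(frames, axis=-1)            # shape (84, 84, 3, 4)

    # a 2-D (grayscale) observation now simply yields a 3-D stacked state
    gray_frames = [np.zeros((H, W), dtype=np.uint8) for _ in range(HIST)]
    gray_state = np.stack(gray_frames, axis=-1)      # shape (84, 84, 4)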
examples/A3C-Gym/train-atari.py  (view file @ 0b561b3b)

@@ -34,8 +34,7 @@ else:
 IMAGE_SIZE = (84, 84)
 FRAME_HISTORY = 4
 GAMMA = 0.99
-CHANNEL = FRAME_HISTORY * 3
-IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
+STATE_SHAPE = IMAGE_SIZE + (3, )
 LOCAL_TIME_MAX = 5
 STEPS_PER_EPOCH = 6000

@@ -70,13 +69,17 @@ class MySimulatorWorker(SimulatorProcess):
 class Model(ModelDesc):
     def inputs(self):
         assert NUM_ACTIONS is not None
-        return [tf.placeholder(tf.uint8, (None,) + IMAGE_SHAPE3, 'state'),
+        return [tf.placeholder(tf.uint8, (None,) + STATE_SHAPE + (FRAME_HISTORY, ), 'state'),
                 tf.placeholder(tf.int64, (None,), 'action'),
                 tf.placeholder(tf.float32, (None,), 'futurereward'),
                 tf.placeholder(tf.float32, (None,), 'action_prob'),
                 ]

-    def _get_NN_prediction(self, image):
+    def _get_NN_prediction(self, state):
+        assert state.shape.rank == 5  # Batch, H, W, Channel, History
+        # swap channel & history, to be compatible with old models
+        state = tf.transpose(state, [0, 1, 2, 4, 3])
+        image = tf.reshape(state, [-1] + list(STATE_SHAPE[:2]) + [STATE_SHAPE[2] * FRAME_HISTORY])
         image = tf.cast(image, tf.float32) / 255.0
         with argscope(Conv2D, activation=tf.nn.relu):
             l = Conv2D('conv0', image, 32, 5)
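Reviewer note, not from the commit: the transpose-then-reshape added to _get_NN_prediction is what keeps checkpoints trained on the old (H, W, history * C) layout usable. A small NumPy check of that equivalence (batch axis dropped for brevity):

    import numpy as np

    H, W, C, HIST = 4, 4, 3, 2                       # tiny sizes for the check
    frames = [np.random.randint(0, 255, (H, W, C)) for _ in range(HIST)]

    old = np.concatenate(frames, axis=-1)            # (H, W, HIST * C), pre-commit layout
    new = np.stack(frames, axis=-1)                  # (H, W, C, HIST), post-commit layout

    # what _get_NN_prediction now does: swap channel & history, then merge them
    merged = np.transpose(new, (0, 1, 3, 2)).reshape(H, W, HIST * C)
    assert np.array_equal(merged, old)               # same tensor an old checkpoint expects

The [0, 1, 2, 4, 3] permutation in the diff is the batched version of the (0, 1, 3, 2) transpose above.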
examples/DeepQNetwork/DQN.py  (view file @ 0b561b3b)

@@ -19,7 +19,7 @@ from expreplay import ExpReplay
 BATCH_SIZE = 64
 IMAGE_SIZE = (84, 84)
-IMAGE_CHANNEL = None   # 3 in gym and 1 in our own wrapper
+STATE_SHAPE = None   # IMAGE_SIZE + (3,) in gym, and IMAGE_SIZE in ALE
 FRAME_HISTORY = 4
 ACTION_REPEAT = 4   # aka FRAME_SKIP
 UPDATE_FREQ = 4

@@ -39,8 +39,7 @@ METHOD = None
 def resize_keepdims(im, size):
-    # Opencv's resize remove the extra dimension for grayscale images.
-    # We add it back.
+    # Opencv's resize remove the extra dimension for grayscale images. We add it back.
     ret = cv2.resize(im, size)
     if im.ndim == 3 and ret.ndim == 2:
         ret = ret[:, :, np.newaxis]

@@ -65,10 +64,20 @@ def get_player(viz=False, train=False):
 class Model(DQNModel):
+    """
+    A DQN model for 2D/3D (image) observations.
+    """
     def __init__(self):
-        super(Model, self).__init__(IMAGE_SIZE, IMAGE_CHANNEL, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)
+        assert len(STATE_SHAPE) in [2, 3]
+        super(Model, self).__init__(STATE_SHAPE, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)

     def _get_DQN_prediction(self, image):
+        assert image.shape.rank in [4, 5], image.shape
+        # image: N, H, W, (C), Hist
+        if image.shape.rank == 5:
+            # merge C & Hist
+            image = tf.reshape(image, [-1] + list(STATE_SHAPE[:2]) + [STATE_SHAPE[2] * FRAME_HISTORY])
+
         image = image / 255.0
         with argscope(Conv2D, activation=lambda x: PReLU('prelu', x), use_bias=True):
             l = (LinearWrap(image)

@@ -102,7 +111,7 @@ def get_config():
     expreplay = ExpReplay(
         predictor_io_names=(['state'], ['Qvalue']),
         player=get_player(train=True),
-        state_shape=IMAGE_SIZE + (IMAGE_CHANNEL,),
+        state_shape=STATE_SHAPE,
         batch_size=BATCH_SIZE,
         memory_size=MEMORY_SIZE,
         init_memory_size=INIT_MEMORY_SIZE,

@@ -152,7 +161,7 @@ if __name__ == '__main__':
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
     ENV_NAME = args.env
     USE_GYM = not ENV_NAME.endswith('.bin')
-    IMAGE_CHANNEL = 3 if USE_GYM else 1
+    STATE_SHAPE = IMAGE_SIZE + (3, ) if USE_GYM else IMAGE_SIZE
     METHOD = args.algo
     # set num_actions
     NUM_ACTIONS = get_player().action_space.n
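For reference (these lines are an illustration, not repository code), the shapes the rewritten DQN.py produces under the two back ends:

    IMAGE_SIZE = (84, 84)
    FRAME_HISTORY = 4

    # gym (RGB) environments
    STATE_SHAPE = IMAGE_SIZE + (3,)                                    # (84, 84, 3)
    comb_state_shape = (None,) + STATE_SHAPE + (FRAME_HISTORY + 1,)    # (None, 84, 84, 3, 5)
    model_input_rank = 1 + len(STATE_SHAPE) + 1                        # 5 -> _get_DQN_prediction merges C & Hist

    # ALE (.bin) environments: grayscale, no channel axis
    STATE_SHAPE = IMAGE_SIZE                                           # (84, 84)
    comb_state_shape = (None,) + STATE_SHAPE + (FRAME_HISTORY + 1,)    # (None, 84, 84, 5)
    model_input_rank = 1 + len(STATE_SHAPE) + 1                        # 4 -> already (N, H, W, Hist), used directly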
examples/DeepQNetwork/DQNModel.py  (view file @ 0b561b3b)

@@ -12,16 +12,19 @@ from tensorpack.utils import logger
 class Model(ModelDesc):
     learning_rate = 1e-3

+    state_dtype = tf.uint8
+
-    def __init__(self, image_shape, channel, history, method, num_actions, gamma):
-        assert len(image_shape) == 2, image_shape
-        self.channel = channel
-        self._shape2d = tuple(image_shape)
-        self._shape3d = self._shape2d + (channel, )
-        self._shape4d_for_prediction = (-1, ) + self._shape2d + (history * channel, )
-        self._channel = channel
+    def __init__(self, state_shape, history, method, num_actions, gamma):
+        """
+        Args:
+            state_shape (tuple[int]),
+            history (int):
+        """
+        self._state_shape = tuple(state_shape)
+        self._stacked_state_shape = (-1, ) + self._state_shape + (history, )
         self.history = history
         self.method = method
         self.num_actions = num_actions

@@ -31,37 +34,43 @@ class Model(ModelDesc):
         # When we use h history frames, the current state and the next state will have (h-1) overlapping frames.
         # Therefore we use a combined state for efficiency:
         # The first h are the current state, and the last h are the next state.
-        return [tf.placeholder(tf.uint8,
-                               (None,) + self._shape2d +
-                               ((self.history + 1) * self.channel,),
+        return [tf.placeholder(self.state_dtype,
+                               (None,) + self._state_shape + (self.history + 1, ),
                                'comb_state'),
                 tf.placeholder(tf.int64, (None,), 'action'),
                 tf.placeholder(tf.float32, (None,), 'reward'),
                 tf.placeholder(tf.bool, (None,), 'isOver')]

     @abc.abstractmethod
-    def _get_DQN_prediction(self, image):
+    def _get_DQN_prediction(self, state):
+        """
+        state: N + state_shape + history
+        """
         pass

     @auto_reuse_variable_scope
-    def get_DQN_prediction(self, image):
-        """ image: [N, H, W, history * C] in [0,255]"""
-        return self._get_DQN_prediction(image)
+    def get_DQN_prediction(self, state):
+        return self._get_DQN_prediction(state)

     def build_graph(self, comb_state, action, reward, isOver):
         comb_state = tf.cast(comb_state, tf.float32)
-        comb_state = tf.reshape(
-            comb_state, [-1] + list(self._shape2d) + [self.history + 1, self.channel])
-
-        state = tf.slice(comb_state, [0, 0, 0, 0, 0], [-1, -1, -1, self.history, -1])
-        state = tf.reshape(state, self._shape4d_for_prediction, name='state')
+        input_rank = comb_state.shape.rank
+
+        state = tf.slice(
+            comb_state,
+            [0] * input_rank,
+            [-1] * (input_rank - 1) + [self.history], name='state')
+
         self.predict_value = self.get_DQN_prediction(state)
         if not get_current_tower_context().is_training:
             return

         reward = tf.clip_by_value(reward, -1, 1)
-        next_state = tf.slice(comb_state, [0, 0, 0, 1, 0], [-1, -1, -1, self.history, -1], name='next_state')
-        next_state = tf.reshape(next_state, self._shape4d_for_prediction)
+        next_state = tf.slice(
+            comb_state,
+            [0] * (input_rank - 1) + [1],
+            [-1] * (input_rank - 1) + [self.history], name='next_state')
+        next_state = tf.reshape(next_state, self._stacked_state_shape)
         action_onehot = tf.one_hot(action, self.num_actions, 1.0, 0.0)

         pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1)  # N,
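Aside, not part of the diff: the begin/size arithmetic in the new build_graph is easier to see in NumPy slicing terms. The combined state carries history + 1 frames on its last axis; the current state is the first `history` of them and the next state is the last `history`:

    import numpy as np

    state_shape, hist = (84, 84), 4                        # ALE-style 2-D states
    comb = np.zeros((32,) + state_shape + (hist + 1,))     # N + state_shape + (hist+1,)

    rank = comb.ndim                                       # what build_graph calls input_rank
    assert rank == 1 + len(state_shape) + 1

    # first `hist` frames -> current state
    state = comb[..., :hist]          # == tf.slice(comb, [0]*rank, [-1]*(rank-1) + [hist])
    # last `hist` frames -> next state
    next_state = comb[..., 1:hist + 1]  # == tf.slice(comb, [0]*(rank-1) + [1], [-1]*(rank-1) + [hist])

    assert state.shape == next_state.shape == (32, 84, 84, 4)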
examples/DeepQNetwork/atari.py  (view file @ 0b561b3b)

@@ -94,7 +94,7 @@ class AtariPlayer(gym.Env):
         self.action_space = spaces.Discrete(len(self.actions))
         self.observation_space = spaces.Box(
-            low=0, high=255, shape=(self.height, self.width, 1), dtype=np.uint8)
+            low=0, high=255, shape=(self.height, self.width), dtype=np.uint8)
         self._restart_episode()

     def get_action_meanings(self):

@@ -109,7 +109,7 @@ class AtariPlayer(gym.Env):
     def _current_state(self):
         """
-        :returns: a gray-scale (h, w, 1) uint8 image
+        :returns: a gray-scale (h, w) uint8 image
         """
         ret = self._grab_raw_image()
         # max-pooled over the last screen

@@ -120,7 +120,7 @@ class AtariPlayer(gym.Env):
             cv2.waitKey(int(self.viz * 1000))
         ret = ret.astype('float32')
         # 0.299,0.587.0.114. same as rgb2y in torch/image
-        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis]
+        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)[:, :]
         return ret.astype('uint8')  # to save some memory

     def _restart_episode(self):
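Quick check (mine, not the repo's): cv2.cvtColor already returns a 2-D array for a single-channel conversion, so dropping the np.newaxis leaves the ALE observation as a plain (h, w) image; the history axis is only added later by FrameStack / the replay memory:

    import cv2
    import numpy as np

    rgb = np.random.randint(0, 255, (210, 160, 3)).astype('float32')  # a fake ALE frame
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    assert gray.shape == (210, 160)          # no trailing channel axis any more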
examples/DeepQNetwork/atari_wrapper.py  (view file @ 0b561b3b)

@@ -4,7 +4,6 @@
 import numpy as np
 from collections import deque
 import gym
-from gym import spaces

 _v0, _v1 = gym.__version__.split('.')[:2]
 assert int(_v0) > 0 or int(_v1) >= 10, gym.__version__

@@ -27,17 +26,13 @@ class MapState(gym.ObservationWrapper):
 class FrameStack(gym.Wrapper):
     """
-    Buffer observations and stack across channels (last axis).
-    The output observation has shape (H, W, History * Channel)
+    Buffer consecutive k observations and stack them on a new last axis.
+    The output observation has shape `original_shape + (k, )`.
     """
     def __init__(self, env, k):
         gym.Wrapper.__init__(self, env)
         self.k = k
         self.frames = deque([], maxlen=k)
-        shp = env.observation_space.shape
-        chan = 1 if len(shp) == 2 else shp[2]
-        self.observation_space = spaces.Box(
-            low=0, high=255, shape=(shp[0], shp[1], chan * k), dtype=np.uint8)

     def reset(self):
         """Clear buffer and re-fill by duplicating the first observation."""

@@ -54,10 +49,7 @@ class FrameStack(gym.Wrapper):
     def observation(self):
         assert len(self.frames) == self.k
-        if self.frames[-1].ndim == 2:
-            return np.stack(self.frames, axis=-1)
-        else:
-            return np.concatenate(self.frames, axis=2)
+        return np.stack(self.frames, axis=-1)


 class _FireResetEnv(gym.Wrapper):
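To make the new FrameStack docstring concrete, a standalone sketch (not from the repo) of the stacking it now performs; the hard-coded observation_space is dropped because the output rank depends on the wrapped env:

    from collections import deque
    import numpy as np

    k = 4
    frames = deque(maxlen=k)

    for _ in range(k):
        frames.append(np.zeros((84, 84), dtype=np.uint8))        # 2-D ALE observations
    assert np.stack(frames, axis=-1).shape == (84, 84, 4)

    frames.clear()
    for _ in range(k):
        frames.append(np.zeros((84, 84, 3), dtype=np.uint8))     # 3-D gym observations
    assert np.stack(frames, axis=-1).shape == (84, 84, 3, 4)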
examples/DeepQNetwork/common.py  (view file @ 0b561b3b)

 # -*- coding: utf-8 -*-
 # File: common.py
 # Author: Yuxin Wu

 import multiprocessing
+import numpy as np
 import random
 import time
 from six.moves import queue

@@ -19,7 +21,8 @@ def play_one_episode(env, func, render=False):
         """
         Map from observation to action, with 0.01 greedy.
         """
-        act = func(s[None, :, :, :])[0][0].argmax()
+        s = np.expand_dims(s, 0)  # batch
+        act = func(s)[0][0].argmax()
         if random.random() < 0.01:
             spc = env.action_space
             act = spc.sample()
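The one-line change above matters because the observation rank is no longer fixed: np.expand_dims adds a batch axis for any rank, whereas the old indexing assumed at least three axes. A small illustration (not repository code):

    import numpy as np

    s = np.zeros((84, 84, 4))                 # ALE: state_shape + (history,)
    assert np.expand_dims(s, 0).shape == (1, 84, 84, 4)

    s = np.zeros((84, 84, 3, 4))              # gym: state_shape + (history,)
    assert np.expand_dims(s, 0).shape == (1, 84, 84, 3, 4)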
examples/DeepQNetwork/expreplay.py  (view file @ 0b561b3b)

@@ -23,19 +23,21 @@ Experience = namedtuple('Experience',
 class ReplayMemory(object):
     def __init__(self, max_size, state_shape, history_len):
+        """
+        Args:
+            state_shape (tuple[int]): shape (without history) of state
+        """
         self.max_size = int(max_size)
         self.state_shape = state_shape
-        assert len(state_shape) == 3, state_shape
-        self._channel = state_shape[2] if len(state_shape) == 3 else 1
-        self._shape3d = (state_shape[0], state_shape[1], self._channel * (history_len + 1))
+        assert len(state_shape) in [1, 2, 3], state_shape
+        # self._state_transpose = list(range(1, len(state_shape) + 1)) + [0]
+        self._output_shape = self.state_shape + (history_len + 1, )
         self.history_len = int(history_len)

-        state_shape = (self.max_size,) + state_shape
+        all_state_shape = (self.max_size,) + state_shape
         logger.info("Creating experience replay buffer of {:.1f} GB ... "
                     "use a smaller buffer if you don't have enough CPU memory.".format(
-                        np.prod(state_shape) / 1024.0**3))
-        self.state = np.zeros(state_shape, dtype='uint8')
+                        np.prod(all_state_shape) / 1024.0**3))
+        self.state = np.zeros(all_state_shape, dtype='uint8')
         self.action = np.zeros((self.max_size,), dtype='int32')
         self.reward = np.zeros((self.max_size,), dtype='float32')
         self.isOver = np.zeros((self.max_size,), dtype='bool')

@@ -70,7 +72,8 @@ class ReplayMemory(object):
     def sample(self, idx):
         """ return a tuple of (s,r,a,o),
-            where s is of shape [H, W, (hist_len+1) * channel]"""
+            where s is of shape self._output_shape, which is
+            [H, W, (hist_len+1) * channel] if input is (H, W, channel)"""
         idx = (self._curr_pos + idx) % self._curr_size
         k = self.history_len + 1
         if idx + k <= self._curr_size:

@@ -95,8 +98,8 @@ class ReplayMemory(object):
                 state = copy.deepcopy(state)
                 state[:k + 1].fill(0)
                 break
-        # move the first dim to the last
-        state = state.transpose(1, 2, 0, 3).reshape(self._shape3d)
+        # move the first dim (history) to the last
+        state = np.moveaxis(state, 0, -1)
         return (state, reward[-2], action[-2], isOver[-2])

     def _slice(self, arr, start, end):

@@ -140,13 +143,13 @@ class ExpReplay(DataFlow, Callback):
             predictor_io_names (tuple of list of str): input/output names to
                 predict Q value from state.
             player (gym.Env): the player.
-            state_shape (tuple): h, w, c
+            state_shape (tuple):
             history_len (int): length of history frames to concat. Zero-filled
                 initial frames.
             update_frequency (int): number of new transitions to add to memory
                 after sampling a batch of transitions for training.
         """
-        assert len(state_shape) == 3, state_shape
+        assert len(state_shape) in [1, 2, 3], state_shape
         init_memory_size = int(init_memory_size)

         for k, v in locals().items():

@@ -207,7 +210,7 @@ class ExpReplay(DataFlow, Callback):
             # build a history state
             history = self.mem.recent_state()
             history.append(old_s)
-            history = np.concatenate(history, axis=-1)   # H,W,HistxC
+            history = np.stack(history, axis=-1)   # state_shape + (Hist,)
             history = np.expand_dims(history, axis=0)
             # assume batched network

@@ -216,7 +219,9 @@ class ExpReplay(DataFlow, Callback):
         self._current_ob, reward, isOver, info = self.player.step(act)
         self._current_game_score.feed(reward)

         if isOver:
-            if info['ale.lives'] == 0:
+            # handle ale-specific information
+            if info.get('ale.lives', -1) == 0:
                 # only record score when a whole game is over (not when an episode is over)
                 self._player_scores.feed(self._current_game_score.sum)
                 self._current_game_score.reset()
             self.player.reset()

@@ -226,6 +231,7 @@ class ExpReplay(DataFlow, Callback):
         import cv2

         def view_state(comb_state):
+            # this function assumes comb_state is 3D
             state = comb_state[:, :, :-1]
             next_state = comb_state[:, :, 1:]
             r = np.concatenate([state[:, :, k] for k in range(self.history_len)], axis=1)
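Side note, not from the diff: sample() gathers the hist_len + 1 frames along a leading axis, and np.moveaxis then produces self._output_shape for any state rank, where the old code hard-coded a 4-D transpose plus a channel merge. Sketch:

    import numpy as np

    hist, state_shape = 4, (84, 84, 3)
    stacked = np.zeros((hist + 1,) + state_shape)      # frames gathered on the first axis

    out = np.moveaxis(stacked, 0, -1)                  # history axis moved to the end
    assert out.shape == state_shape + (hist + 1,)      # == self._output_shape

    # also fine for 2-D states:
    assert np.moveaxis(np.zeros((hist + 1, 84, 84)), 0, -1).shape == (84, 84, hist + 1)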
tensorpack/callbacks/base.py  (view file @ 0b561b3b)

@@ -44,10 +44,16 @@ class Callback(object):
     _chief_only = True

+    name_scope = ""
+    """
+    A name scope for ops created inside this callback.
+    By default to the name of the class, but can be set per-instance.
+    """
+
     def setup_graph(self, trainer):
         self.trainer = trainer
         self.graph = tf.get_default_graph()
-        scope_name = type(self).__name__
+        scope_name = self.name_scope or type(self).__name__
         scope_name = scope_name.replace('_', '')
         with tf.name_scope(scope_name):
             self._setup_graph()
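A minimal usage sketch of the new name_scope hook (assuming a tensorpack version containing this commit; the callback name and the op are made up for illustration):

    import tensorflow as tf
    from tensorpack.callbacks import Callback

    class QueueSizeEMA(Callback):                # hypothetical callback, illustration only
        name_scope = "InputSource/EMA"           # overrides the default scope (the class name)

        def _setup_graph(self):
            # everything created here lands under tf.name_scope(self.name_scope)
            self._size = tf.constant(0.0, name='queue_size')

The two changes below, in input_source.py and trainers.py, use the same hook on callback instances instead, via ret.name_scope and cb.name_scope.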
tensorpack/input_source/input_source.py  (view file @ 0b561b3b)

@@ -251,11 +251,13 @@ class QueueInput(FeedfreeInput):
         # in TF there is no API to get queue capacity, so we can only summary the size
         size = tf.cast(self.queue.size(), tf.float32, name='queue_size')
         size_ema_op = add_moving_summary(size, collection=None, decay=0.5)[0].op
-        return RunOp(
+        ret = RunOp(
             lambda: size_ema_op,
             run_before=False,
             run_as_trigger=False,
             run_step=True)
+        ret.name_scope = "InputSource/EMA"
+        return ret

     def _get_callbacks(self):
         from ..callbacks.concurrency import StartProcOrThread
tensorpack/train/trainers.py  (view file @ 0b561b3b)

@@ -194,6 +194,7 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
             run_before=True,
             run_as_trigger=self.BROADCAST_EVERY_EPOCH,
             verbose=True)
+        cb.name_scope = "SyncVariables"
         return [cb]