seminar-breakout: commit ada058f3
authored May 19, 2017 by Yuxin Wu
update docs in DQN, remove a 4-frame hard-coded assumption (#268)
parent c3effc3c

Showing 4 changed files with 23 additions and 9 deletions:

examples/DeepQNetwork/DQN.py           +4   -2
examples/DeepQNetwork/DQNModel.py      +4   -4
examples/DeepQNetwork/expreplay.py     +14  -2
tensorpack/dataflow/imgaug/convert.py  +1   -1
examples/DeepQNetwork/DQN.py

@@ -32,7 +32,6 @@ IMAGE_SIZE = (84, 84)
 FRAME_HISTORY = 4
 ACTION_REPEAT = 4
-CHANNEL = FRAME_HISTORY
 GAMMA = 0.99
 INIT_EXPLORATION = 1

@@ -54,8 +53,11 @@ def get_player(viz=False, train=False):
     pl = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT,
                      image_shape=IMAGE_SIZE[::-1], viz=viz, live_lost_as_eoe=train)
     if not train:
+        # create a new axis to stack history on
         pl = MapPlayerState(pl, lambda im: im[:, :, np.newaxis])
+        # in training, history is taken care of in expreplay buffer
         pl = HistoryFramePlayer(pl, FRAME_HISTORY)
         pl = PreventStuckPlayer(pl, 30, 1)
     pl = LimitLengthPlayer(pl, 30000)
     return pl

@@ -63,7 +65,7 @@ def get_player(viz=False, train=False):
 class Model(DQNModel):
     def __init__(self):
-        super(Model, self).__init__(IMAGE_SIZE, CHANNEL, METHOD, NUM_ACTIONS, GAMMA)
+        super(Model, self).__init__(IMAGE_SIZE, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA)

     def _get_DQN_prediction(self, image):
         """ image: [0,255]"""
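For readers less familiar with the player wrappers referenced above: at evaluation time each grayscale frame gets an extra channel axis, and the last FRAME_HISTORY frames are stacked along it, so the network sees an (84, 84, FRAME_HISTORY) observation; this is why FRAME_HISTORY can be passed to the model directly in place of the old CHANNEL alias. A minimal NumPy sketch of that stacking, for illustration only (the MapPlayerState and HistoryFramePlayer wrappers in the diff do the real work):

import numpy as np

FRAME_HISTORY = 4
# the last FRAME_HISTORY grayscale frames, each of shape (84, 84)
frames = [np.random.randint(0, 256, (84, 84), dtype=np.uint8)
          for _ in range(FRAME_HISTORY)]

# MapPlayerState adds a channel axis: (84, 84) -> (84, 84, 1)
frames = [f[:, :, np.newaxis] for f in frames]

# HistoryFramePlayer concatenates the history along that axis: (84, 84, FRAME_HISTORY)
state = np.concatenate(frames, axis=-1)
assert state.shape == (84, 84, FRAME_HISTORY)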
examples/DeepQNetwork/DQNModel.py

@@ -21,8 +21,8 @@ class Model(ModelDesc):
         self.gamma = gamma

     def _get_inputs(self):
-        # use a combined state, where the first channels are the current state,
-        # and the last 4 channels are the next state
+        # Use a combined state for efficiency.
+        # The first h channels are the current state, and the last h channels are the next state.
         return [InputDesc(tf.uint8,
                           (None,) + self.image_shape + (self.channel + 1,),
                           'comb_state'),

@@ -37,13 +37,13 @@ class Model(ModelDesc):
     def _build_graph(self, inputs):
         comb_state, action, reward, isOver = inputs
         comb_state = tf.cast(comb_state, tf.float32)
-        state = tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, 4], name='state')
+        state = tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, self.channel], name='state')

         self.predict_value = self._get_DQN_prediction(state)
         if not get_current_tower_context().is_training:
             return

         reward = tf.clip_by_value(reward, -1, 1)
-        next_state = tf.slice(comb_state, [0, 0, 0, 1], [-1, -1, -1, 4], name='next_state')
+        next_state = tf.slice(comb_state, [0, 0, 0, 1], [-1, -1, -1, self.channel], name='next_state')
         action_onehot = tf.one_hot(action, self.num_actions, 1.0, 0.0)

         pred_action_value = tf.reduce_sum(self.predict_value * action_onehot, 1)  # N,
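The comb_state layout that the rewritten comments describe: each transition is stored as h + 1 consecutive frames stacked on the channel axis, and the two tf.slice calls above pick the overlapping windows frames[0:h] (current state) and frames[1:h+1] (next state), so both states come from a single tensor. A NumPy sketch of the same indexing for one transition (the batch dimension is dropped here; h stands for self.channel):

import numpy as np

h = 4                                        # history length, i.e. self.channel
comb_state = np.random.rand(84, 84, h + 1)   # h + 1 consecutive frames for one transition

state      = comb_state[:, :, 0:h]           # tf.slice(..., [0, 0, 0, 0], [-1, -1, -1, self.channel])
next_state = comb_state[:, :, 1:h + 1]       # tf.slice(..., [0, 0, 0, 1], [-1, -1, -1, self.channel])

# the two views overlap in h - 1 frames, so storing h + 1 frames per transition is enough
assert np.array_equal(state[:, :, 1:], next_state[:, :, :-1])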
examples/DeepQNetwork/expreplay.py

@@ -113,7 +113,7 @@ class ExpReplay(DataFlow, Callback):
    This implementation provides the interface as a :class:`DataFlow`.
    This DataFlow is __not__ fork-safe (thus doesn't support multiprocess prefetching).
-    This implementation only works with Q-learning. It assumes that state is
+    This implementation assumes that state is
    batch-able, and the network takes batched inputs.
    """

@@ -171,6 +171,18 @@ class ExpReplay(DataFlow, Callback):
                 pbar.update()
         self._init_memory_flag.set()

+    # quickly fill the memory for debug
+    def _fake_init_memory(self):
+        from copy import deepcopy
+        with get_tqdm(total=self.init_memory_size) as pbar:
+            while len(self.mem) < 5:
+                self._populate_exp()
+                pbar.update()
+            while len(self.mem) < self.init_memory_size:
+                self.mem.append(deepcopy(self.mem._hist[0]))
+                pbar.update()
+        self._init_memory_flag.set()
+
     def _populate_exp(self):
         """ populate a transition by epsilon-greedy"""
         old_s = self.player.current_state()

@@ -188,7 +200,7 @@ class ExpReplay(DataFlow, Callback):
         reward, isOver = self.player.action(act)
         self.mem.append(Experience(old_s, act, reward, isOver))

-    def debug_sample(self, sample):
+    def _debug_sample(self, sample):
         import cv2

         def view_state(comb_state):
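_populate_exp (shown as context above) fills the replay memory by acting epsilon-greedily in the emulator. As a reminder of what that policy means, here is a generic sketch with made-up names (q_values, rng); it is not the actual ExpReplay code:

import numpy as np

def epsilon_greedy_action(q_values, epsilon, rng=np.random):
    """With probability epsilon pick a uniformly random action, otherwise the greedy one."""
    if rng.rand() < epsilon:
        return rng.randint(len(q_values))
    return int(np.argmax(q_values))

# e.g. with epsilon = 1.0 (INIT_EXPLORATION in DQN.py) every action is random
act = epsilon_greedy_action(np.zeros(6), epsilon=1.0)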
tensorpack/dataflow/imgaug/convert.py

@@ -12,7 +12,7 @@ __all__ = ['ColorSpace', 'Grayscale', 'ToUint8', 'ToFloat32']
 class ColorSpace(ImageAugmentor):
     """ Convert into another colorspace. """

-    def __init__(self, mode=cv2.COLOR_BGR2GRAY, keepdims=True):
+    def __init__(self, mode, keepdims=True):
         """
         Args:
             mode: opencv colorspace conversion code (e.g., `cv2.COLOR_BGR2HSV`)
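Since mode no longer has a default value, callers must pass the OpenCV conversion code explicitly. A small hypothetical usage sketch (the import path follows the file location above; adjust it to however the augmentor is actually imported in your code):

import cv2
from tensorpack.dataflow.imgaug import ColorSpace

# the conversion code must now be given explicitly; keepdims=True keeps a trailing
# channel axis when the conversion drops it, e.g. BGR -> grayscale gives (H, W, 1)
to_gray = ColorSpace(mode=cv2.COLOR_BGR2GRAY, keepdims=True)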