Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
6e1f395d
Commit
6e1f395d
authored
May 20, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
exp_replay as a callback
parent
ddf737d7
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
16 deletions
+19
-16
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+4
-15
tensorpack/dataflow/RL.py
tensorpack/dataflow/RL.py
+15
-1
No files found.
examples/Atari2600/DQN.py
View file @
6e1f395d
...
@@ -43,12 +43,11 @@ INIT_EXPLORATION = 1
...
@@ -43,12 +43,11 @@ INIT_EXPLORATION = 1
EXPLORATION_EPOCH_ANNEAL
=
0.0025
EXPLORATION_EPOCH_ANNEAL
=
0.0025
END_EXPLORATION
=
0.1
END_EXPLORATION
=
0.1
INIT_MEMORY_SIZE
=
50000
MEMORY_SIZE
=
1e6
MEMORY_SIZE
=
1e6
INIT_MEMORY_SIZE
=
50000
STEP_PER_EPOCH
=
10000
STEP_PER_EPOCH
=
10000
EVAL_EPISODE
=
100
EVAL_EPISODE
=
100
class
Model
(
ModelDesc
):
class
Model
(
ModelDesc
):
def
_get_input_vars
(
self
):
def
_get_input_vars
(
self
):
assert
NUM_ACTIONS
is
not
None
assert
NUM_ACTIONS
is
not
None
...
@@ -131,18 +130,6 @@ class TargetNetworkUpdator(Callback):
...
@@ -131,18 +130,6 @@ class TargetNetworkUpdator(Callback):
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
self
.
_update
()
self
.
_update
()
class
ExpReplayController
(
Callback
):
def
__init__
(
self
,
d
):
self
.
d
=
d
def
_before_train
(
self
):
self
.
d
.
init_memory
()
def
_trigger_epoch
(
self
):
if
self
.
d
.
exploration
>
END_EXPLORATION
:
self
.
d
.
exploration
-=
EXPLORATION_EPOCH_ANNEAL
logger
.
info
(
"Exploration changed to {}"
.
format
(
self
.
d
.
exploration
))
def
play_one_episode
(
player
,
func
,
verbose
=
False
):
def
play_one_episode
(
player
,
func
,
verbose
=
False
):
tot_reward
=
0
tot_reward
=
0
que
=
deque
(
maxlen
=
30
)
que
=
deque
(
maxlen
=
30
)
...
@@ -251,6 +238,8 @@ def get_config(romfile):
...
@@ -251,6 +238,8 @@ def get_config(romfile):
batch_size
=
BATCH_SIZE
,
batch_size
=
BATCH_SIZE
,
populate_size
=
INIT_MEMORY_SIZE
,
populate_size
=
INIT_MEMORY_SIZE
,
exploration
=
INIT_EXPLORATION
,
exploration
=
INIT_EXPLORATION
,
end_exploration
=
END_EXPLORATION
,
exploration_epoch_anneal
=
EXPLORATION_EPOCH_ANNEAL
,
reward_clip
=
(
-
1
,
2
))
reward_clip
=
(
-
1
,
2
))
lr
=
tf
.
Variable
(
0.0025
,
trainable
=
False
,
name
=
'learning_rate'
)
lr
=
tf
.
Variable
(
0.0025
,
trainable
=
False
,
name
=
'learning_rate'
)
...
@@ -277,7 +266,7 @@ def get_config(romfile):
...
@@ -277,7 +266,7 @@ def get_config(romfile):
HumanHyperParamSetter
(
'learning_rate'
,
'hyper.txt'
),
HumanHyperParamSetter
(
'learning_rate'
,
'hyper.txt'
),
HumanHyperParamSetter
((
dataset_train
,
'exploration'
),
'hyper.txt'
),
HumanHyperParamSetter
((
dataset_train
,
'exploration'
),
'hyper.txt'
),
TargetNetworkUpdator
(
M
),
TargetNetworkUpdator
(
M
),
ExpReplayController
(
dataset_train
)
,
dataset_train
,
PeriodicCallback
(
Evaluator
(),
1
),
PeriodicCallback
(
Evaluator
(),
1
),
]),
]),
session_config
=
get_default_sess_config
(
0.5
),
session_config
=
get_default_sess_config
(
0.5
),
...
...
tensorpack/dataflow/RL.py
View file @
6e1f395d
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
from
.base
import
DataFlow
from
.base
import
DataFlow
from
tensorpack.utils
import
*
from
tensorpack.utils
import
*
from
tensorpack.callbacks.base
import
Callback
from
tqdm
import
tqdm
from
tqdm
import
tqdm
import
random
import
random
...
@@ -12,6 +13,7 @@ import numpy as np
...
@@ -12,6 +13,7 @@ import numpy as np
import
cv2
import
cv2
from
collections
import
deque
,
namedtuple
from
collections
import
deque
,
namedtuple
"""
"""
Implement RL-related data preprocessing
Implement RL-related data preprocessing
"""
"""
...
@@ -28,7 +30,7 @@ def view_state(state):
...
@@ -28,7 +30,7 @@ def view_state(state):
cv2
.
imshow
(
"state"
,
r
)
cv2
.
imshow
(
"state"
,
r
)
cv2
.
waitKey
()
cv2
.
waitKey
()
class
ExpReplay
(
DataFlow
):
class
ExpReplay
(
DataFlow
,
Callback
):
"""
"""
Implement experience replay.
Implement experience replay.
"""
"""
...
@@ -40,6 +42,8 @@ class ExpReplay(DataFlow):
...
@@ -40,6 +42,8 @@ class ExpReplay(DataFlow):
batch_size
=
32
,
batch_size
=
32
,
populate_size
=
50000
,
populate_size
=
50000
,
exploration
=
1
,
exploration
=
1
,
end_exploration
=
0.1
,
exploration_epoch_anneal
=
0.002
,
reward_clip
=
None
):
reward_clip
=
None
):
"""
"""
:param predictor: callabale. called with a state, return a distribution
:param predictor: callabale. called with a state, return a distribution
...
@@ -102,6 +106,16 @@ class ExpReplay(DataFlow):
...
@@ -102,6 +106,16 @@ class ExpReplay(DataFlow):
isOver
[
idx
]
=
b
.
isOver
isOver
[
idx
]
=
b
.
isOver
return
[
state
,
action
,
reward
,
next_state
,
isOver
]
return
[
state
,
action
,
reward
,
next_state
,
isOver
]
# Callback-related:
def
_before_train
(
self
):
self
.
init_memory
()
def
_trigger_epoch
(
self
):
if
self
.
exploration
>
self
.
end_exploration
:
self
.
exploration
-=
self
.
exploration_epoch_anneal
logger
.
info
(
"Exploration changed to {}"
.
format
(
self
.
exploration
))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
from
tensorpack.dataflow.dataset
import
AtariDriver
,
AtariPlayer
from
tensorpack.dataflow.dataset
import
AtariDriver
,
AtariPlayer
predictor
=
lambda
x
:
np
.
array
([
1
,
1
,
1
,
1
])
predictor
=
lambda
x
:
np
.
array
([
1
,
1
,
1
,
1
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment