Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
6e1f395d
You need to sign in or sign up before continuing.
Commit
6e1f395d
authored
May 20, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
exp_replay as a callback
parent
ddf737d7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
16 deletions
+19
-16
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+4
-15
tensorpack/dataflow/RL.py
tensorpack/dataflow/RL.py
+15
-1
No files found.
examples/Atari2600/DQN.py
View file @
6e1f395d
...
...
@@ -43,12 +43,11 @@ INIT_EXPLORATION = 1
EXPLORATION_EPOCH_ANNEAL
=
0.0025
END_EXPLORATION
=
0.1
INIT_MEMORY_SIZE
=
50000
MEMORY_SIZE
=
1e6
INIT_MEMORY_SIZE
=
50000
STEP_PER_EPOCH
=
10000
EVAL_EPISODE
=
100
class
Model
(
ModelDesc
):
def
_get_input_vars
(
self
):
assert
NUM_ACTIONS
is
not
None
...
...
@@ -131,18 +130,6 @@ class TargetNetworkUpdator(Callback):
def
_trigger_epoch
(
self
):
self
.
_update
()
class
ExpReplayController
(
Callback
):
def
__init__
(
self
,
d
):
self
.
d
=
d
def
_before_train
(
self
):
self
.
d
.
init_memory
()
def
_trigger_epoch
(
self
):
if
self
.
d
.
exploration
>
END_EXPLORATION
:
self
.
d
.
exploration
-=
EXPLORATION_EPOCH_ANNEAL
logger
.
info
(
"Exploration changed to {}"
.
format
(
self
.
d
.
exploration
))
def
play_one_episode
(
player
,
func
,
verbose
=
False
):
tot_reward
=
0
que
=
deque
(
maxlen
=
30
)
...
...
@@ -251,6 +238,8 @@ def get_config(romfile):
batch_size
=
BATCH_SIZE
,
populate_size
=
INIT_MEMORY_SIZE
,
exploration
=
INIT_EXPLORATION
,
end_exploration
=
END_EXPLORATION
,
exploration_epoch_anneal
=
EXPLORATION_EPOCH_ANNEAL
,
reward_clip
=
(
-
1
,
2
))
lr
=
tf
.
Variable
(
0.0025
,
trainable
=
False
,
name
=
'learning_rate'
)
...
...
@@ -277,7 +266,7 @@ def get_config(romfile):
HumanHyperParamSetter
(
'learning_rate'
,
'hyper.txt'
),
HumanHyperParamSetter
((
dataset_train
,
'exploration'
),
'hyper.txt'
),
TargetNetworkUpdator
(
M
),
ExpReplayController
(
dataset_train
)
,
dataset_train
,
PeriodicCallback
(
Evaluator
(),
1
),
]),
session_config
=
get_default_sess_config
(
0.5
),
...
...
tensorpack/dataflow/RL.py
View file @
6e1f395d
...
...
@@ -5,6 +5,7 @@
from
.base
import
DataFlow
from
tensorpack.utils
import
*
from
tensorpack.callbacks.base
import
Callback
from
tqdm
import
tqdm
import
random
...
...
@@ -12,6 +13,7 @@ import numpy as np
import
cv2
from
collections
import
deque
,
namedtuple
"""
Implement RL-related data preprocessing
"""
...
...
@@ -28,7 +30,7 @@ def view_state(state):
cv2
.
imshow
(
"state"
,
r
)
cv2
.
waitKey
()
class
ExpReplay
(
DataFlow
):
class
ExpReplay
(
DataFlow
,
Callback
):
"""
Implement experience replay.
"""
...
...
@@ -40,6 +42,8 @@ class ExpReplay(DataFlow):
batch_size
=
32
,
populate_size
=
50000
,
exploration
=
1
,
end_exploration
=
0.1
,
exploration_epoch_anneal
=
0.002
,
reward_clip
=
None
):
"""
:param predictor: callable. called with a state, return a distribution
...
...
@@ -102,6 +106,16 @@ class ExpReplay(DataFlow):
isOver
[
idx
]
=
b
.
isOver
return
[
state
,
action
,
reward
,
next_state
,
isOver
]
# Callback-related:
def
_before_train
(
self
):
self
.
init_memory
()
def
_trigger_epoch
(
self
):
if
self
.
exploration
>
self
.
end_exploration
:
self
.
exploration
-=
self
.
exploration_epoch_anneal
logger
.
info
(
"Exploration changed to {}"
.
format
(
self
.
exploration
))
if
__name__
==
'__main__'
:
from
tensorpack.dataflow.dataset
import
AtariDriver
,
AtariPlayer
predictor
=
lambda
x
:
np
.
array
([
1
,
1
,
1
,
1
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment