Shashank Suhas / seminar-breakout · Commits

Commit ff40a873, authored May 28, 2016 by Yuxin Wu
parent 40e6a223

    major refactor RL

Showing 7 changed files with 146 additions and 106 deletions (+146 −106)
examples/Atari2600/DQN.py        +2   -3
tensorpack/RL/__init__.py        +21  -0
tensorpack/RL/atari.py           +4   -3
tensorpack/RL/common.py          +42  -0
tensorpack/RL/envbase.py         +68  -0
tensorpack/RL/expreplay.py       +8   -99
tensorpack/predict/__init__.py   +1   -1
examples/Atari2600/DQN.py

@@ -18,12 +18,11 @@ from tensorpack.models import *
 from tensorpack.utils import *
 from tensorpack.utils.concurrency import ensure_proc_terminate, subproc_call
 from tensorpack.utils.stat import *
-from tensorpack.predict import PredictConfig, get_predict_func, ParallelPredictWorker
+from tensorpack.predict import PredictConfig, get_predict_func, MultiProcessPredictWorker
 from tensorpack.tfutils import symbolic_functions as symbf
 from tensorpack.callbacks import *
-from tensorpack.dataflow.dataset import AtariPlayer
-from tensorpack.dataflow.RL import ExpReplay
+from tensorpack.RL import AtariPlayer, ExpReplay

 """
 Implement DQN in:
...
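
For downstream code tracking tensorpack at this revision, this hunk implies a two-part migration: the predict worker is imported under its new name MultiProcessPredictWorker, and AtariPlayer/ExpReplay move from tensorpack.dataflow.* into the new tensorpack.RL package. A sketch of the update (only the paths visible in this diff are confirmed):

# before this commit:
#   from tensorpack.predict import PredictConfig, get_predict_func, ParallelPredictWorker
#   from tensorpack.dataflow.dataset import AtariPlayer
#   from tensorpack.dataflow.RL import ExpReplay

# after this commit:
from tensorpack.predict import PredictConfig, get_predict_func, MultiProcessPredictWorker
from tensorpack.RL import AtariPlayer, ExpReplay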
tensorpack/RL/__init__.py (new file, mode 100644)

# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

from pkgutil import walk_packages
import importlib
import os
import os.path

def _global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
    del globals()[name]
    for k in lst:
        globals()[k] = p.__dict__[k]

for _, module_name, _ in walk_packages(
        [os.path.dirname(__file__)]):
    if not module_name.startswith('_'):
        _global_import(module_name)
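
_global_import walks every submodule of tensorpack.RL and re-exports its public names (its __all__, or else everything in dir(p)) at the package level, then drops the submodule entry itself from the package namespace. The practical effect, using names added elsewhere in this commit:

# these package-level imports...
from tensorpack.RL import HistoryFramePlayer, RLEnvironment, ExpReplay

# ...are equivalent to the explicit submodule forms:
from tensorpack.RL.common import HistoryFramePlayer
from tensorpack.RL.envbase import RLEnvironment
from tensorpack.RL.expreplay import ExpReplay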
tensorpack/dataflow/dataset/atari.py → tensorpack/RL/atari.py

@@ -8,9 +8,10 @@ import time
 import os
 import cv2
 from collections import deque
-from ...utils import get_rng, logger
-from ...utils.stat import StatCounter
-from ..RL import RLEnvironment
+from ..utils import get_rng, logger
+from ..utils.stat import StatCounter
+from .envbase import RLEnvironment

 try:
     from ale_python_interface import ALEInterface
...
tensorpack/RL/common.py (new file, mode 100644)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import numpy as np
from collections import deque

from .envbase import ProxyPlayer

__all__ = ['HistoryFramePlayer']

class HistoryFramePlayer(ProxyPlayer):
    def __init__(self, player, hist_len):
        super(HistoryFramePlayer, self).__init__(player)
        self.history = deque(maxlen=hist_len)
        s = self.player.current_state()
        self.history.append(s)

    def current_state(self):
        assert len(self.history) != 0
        diff_len = self.history.maxlen - len(self.history)
        if diff_len == 0:
            return np.concatenate(self.history, axis=2)
        zeros = [np.zeros_like(self.history[0]) for k in range(diff_len)]
        for k in self.history:
            zeros.append(k)
        return np.concatenate(zeros, axis=2)

    def action(self, act):
        r, isOver = self.player.action(act)
        s = self.player.current_state()
        self.history.append(s)
        if isOver:
            # s would be a new episode
            self.history.clear()
            self.history.append(s)
        return (r, isOver)

class AvoidNoOpPlayer(ProxyPlayer):
    pass  # TODO
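
A minimal usage sketch of HistoryFramePlayer. The DummyPlayer below is hypothetical, standing in for a real environment such as AtariPlayer, and assumes per-frame states of shape (84, 84, 1):

import numpy as np
from tensorpack.RL import HistoryFramePlayer, RLEnvironment

class DummyPlayer(RLEnvironment):
    """Hypothetical stand-in for AtariPlayer: emits 84x84x1 frames."""
    def current_state(self):
        return np.zeros((84, 84, 1), dtype='float32')
    def action(self, act):
        return (0.0, False)                 # (reward, isOver)
    def get_stat(self):
        return {}

player = HistoryFramePlayer(DummyPlayer(), hist_len=4)
# Frames are concatenated along axis 2; a partially filled history is
# zero-padded at the front, so the stacked state is always (84, 84, 4):
print(player.current_state().shape)         # (84, 84, 4)
player.action(1)                             # steps the wrapped player, appends the new frame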
tensorpack/RL/envbase.py (new file, mode 100644)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: envbase.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

from abc import abstractmethod, ABCMeta
from collections import defaultdict

__all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer']

class RLEnvironment(object):
    __meta__ = ABCMeta

    def __init__(self):
        self.reset_stat()

    @abstractmethod
    def current_state(self):
        """ Observe, return a state representation """

    @abstractmethod
    def action(self, act):
        """
        Perform an action
        :params act: the action
        :returns: (reward, isOver)
        """

    @abstractmethod
    def get_stat(self):
        """ return a dict of statistics (e.g., score) after running for a while """

    def reset_stat(self):
        """ reset the statistics counter """
        self.stats = defaultdict(list)

class NaiveRLEnvironment(RLEnvironment):
    """ for testing only """
    def __init__(self):
        self.k = 0
    def current_state(self):
        self.k += 1
        return self.k
    def action(self, act):
        self.k = act
        return (self.k, self.k > 10)

class ProxyPlayer(RLEnvironment):
    def __init__(self, player):
        self.player = player

    def get_stat(self):
        return self.player.get_stat()

    def reset_stat(self):
        self.player.reset_stat()

    def current_state(self):
        return self.player.current_state()

    def action(self, act):
        return self.player.action(act)
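
To implement a new environment against this interface, subclass RLEnvironment and provide the three abstract methods. A hypothetical example (not part of the commit):

import random
from tensorpack.RL import RLEnvironment

class CoinFlipEnv(RLEnvironment):
    """Hypothetical: reward 1 on heads; an episode lasts 10 steps."""
    def __init__(self):
        self.steps = 0
        super(CoinFlipEnv, self).__init__()   # RLEnvironment.__init__ calls reset_stat()

    def current_state(self):
        return self.steps

    def action(self, act):
        self.steps += 1
        reward = int(random.random() < 0.5)
        isOver = self.steps >= 10
        if isOver:
            self.stats['score'].append(reward)
            self.steps = 0
        return (reward, isOver)

    def get_stat(self):
        return dict(self.stats)

ProxyPlayer then lets wrappers such as HistoryFramePlayer decorate any such environment without re-implementing the interface.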
tensorpack/dataflow/RL.py → tensorpack/RL/expreplay.py

 #!/usr/bin/env python2
 # -*- coding: utf-8 -*-
-# File: RL.py
+# File: expreplay.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-from abc import abstractmethod, ABCMeta
 import random
 import numpy as np
-from collections import deque, namedtuple, defaultdict
+from collections import deque, namedtuple
 from tqdm import tqdm
 import cv2
 import six

-from .base import DataFlow
-from tensorpack.utils import *
-from tensorpack.callbacks.base import Callback
+from ..dataflow import DataFlow
+from ..utils import *
+from ..callbacks.base import Callback

 """
 Implement RL-related data preprocessing
 """

-__all__ = ['ExpReplay', 'RLEnvironment', 'NaiveRLEnvironment', 'HistoryFramePlayer']
+__all__ = ['ExpReplay']

 Experience = namedtuple('Experience', ['state', 'action', 'reward', 'isOver'])

-class RLEnvironment(object):
-    ...
-class NaiveRLEnvironment(RLEnvironment):
-    ...
-class ProxyPlayer(RLEnvironment):
-    ...
-class HistoryFramePlayer(ProxyPlayer):
-    ...
 (the deleted class bodies are identical to the code added above in
  tensorpack/RL/envbase.py and tensorpack/RL/common.py)

 class ExpReplay(DataFlow, Callback):
     """
     Implement experience replay in the paper
...

@@ -182,6 +90,7 @@ class ExpReplay(DataFlow, Callback):
         while True:
             batch_exp = [self.sample_one() for _ in range(self.batch_size)]

             #import cv2
             #def view_state(state, next_state):
                 #""" for debugging state representation"""
                 #r = np.concatenate([state[:,:,k] for k in range(self.history_len)], axis=1)
...

@@ -253,7 +162,7 @@ class ExpReplay(DataFlow, Callback):
         self.player.reset_stat()

 if __name__ == '__main__':
-    from tensorpack.dataflow.dataset import AtariPlayer
+    from .atari import AtariPlayer
     import sys
     predictor = lambda x: np.array([1,1,1,1])
     predictor.initialized = False
...
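
After the refactor this module exports only ExpReplay; the environment classes now live in envbase.py and common.py. The Experience namedtuple remains the unit stored in the replay buffer. A minimal sketch (the frame shape is illustrative):

import numpy as np
from tensorpack.RL.expreplay import Experience   # not in __all__, so import it explicitly

exp = Experience(state=np.zeros((84, 84), dtype='uint8'),
                 action=3, reward=1.0, isOver=False)
print(exp.action, exp.isOver)                    # named field access: 3 False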
tensorpack/predict/__init__.py

@@ -9,9 +9,9 @@ import os.path
 def global_import(name):
     p = __import__(name, globals(), locals(), level=1)
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
+    del globals()[name]
     for k in lst:
         globals()[k] = p.__dict__[k]
-    del globals()[name]

 for _, module_name, _ in walk_packages(
         [os.path.dirname(__file__)]):
...
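
This hunk moves del globals()[name] before the re-export loop, matching the new tensorpack/RL/__init__.py. The order matters when a submodule exports a name equal to its own module name: with the del after the loop, the freshly re-exported symbol would be deleted along with the module entry. A simplified model of the two orderings (hypothetical names; a plain dict stands in for globals()):

def buggy(ns, name, exports):
    ns[name] = '<module %s>' % name
    for k in exports:
        ns[k] = '<symbol %s>' % k     # may overwrite the module entry
    del ns[name]                      # ...then deletes the re-export too

def fixed(ns, name, exports):
    ns[name] = '<module %s>' % name
    del ns[name]                      # drop the module entry first
    for k in exports:
        ns[k] = '<symbol %s>' % k

ns = {}; buggy(ns, 'dataset', ['dataset'])
print('dataset' in ns)                # False: the exported symbol was lost
ns = {}; fixed(ns, 'dataset', ['dataset'])
print('dataset' in ns)                # True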