Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
f15c2181
Commit
f15c2181
authored
May 28, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
simulator, prevent stuck
parent
ff40a873
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
178 additions
and
3 deletions
+178
-3
tensorpack/RL/common.py
tensorpack/RL/common.py
+25
-3
tensorpack/RL/envbase.py
tensorpack/RL/envbase.py
+5
-0
tensorpack/RL/simulator.py
tensorpack/RL/simulator.py
+148
-0
No files found.
tensorpack/RL/common.py
View file @
f15c2181
...
...
@@ -8,9 +8,10 @@ import numpy as np
from
collections
import
deque
from
.envbase
import
ProxyPlayer
__all__
=
[
'HistoryFramePlayer'
]
__all__
=
[
'HistoryFramePlayer'
,
'PreventStuckPlayer'
]
class
HistoryFramePlayer
(
ProxyPlayer
):
""" Include history frames in state, or use black images"""
def
__init__
(
self
,
player
,
hist_len
):
super
(
HistoryFramePlayer
,
self
)
.
__init__
(
player
)
self
.
history
=
deque
(
maxlen
=
hist_len
)
...
...
@@ -38,5 +39,26 @@ class HistoryFramePlayer(ProxyPlayer):
self
.
history
.
append
(
s
)
return
(
r
,
isOver
)
class
AvoidNoOpPlayer
(
ProxyPlayer
):
pass
# TODO
class
PreventStuckPlayer
(
ProxyPlayer
):
""" Prevent the player from getting stuck (repeating a no-op)
by inserting a different action. Useful in games such as Atari Breakout
where the agent needs to press the 'start' button to start playing.
"""
# TODO hash the state as well?
def
__init__
(
self
,
player
,
nr_repeat
,
action
):
"""
:param nr_repeat: trigger the 'action' after this many of repeated action
:param action: the action to be triggered to get out of stuck
"""
super
(
PreventStuckPlayer
,
self
)
.
__init__
(
player
)
self
.
act_que
=
deque
(
maxlen
=
nr_repeat
)
self
.
trigger_action
=
action
def
action
(
self
,
act
):
self
.
act_que
.
append
(
act
)
if
self
.
act_que
.
count
(
self
.
act_que
[
0
])
==
self
.
act_que
.
maxlen
:
act
=
self
.
trigger_action
r
,
isOver
=
self
.
player
.
action
(
act
)
if
isOver
:
self
.
act_que
.
clear
()
return
(
r
,
isOver
)
tensorpack/RL/envbase.py
View file @
f15c2181
...
...
@@ -51,6 +51,7 @@ class NaiveRLEnvironment(RLEnvironment):
return
(
self
.
k
,
self
.
k
>
10
)
class
ProxyPlayer
(
RLEnvironment
):
""" Serve as a proxy another player """
def
__init__
(
self
,
player
):
self
.
player
=
player
...
...
@@ -66,3 +67,7 @@ class ProxyPlayer(RLEnvironment):
def
action
(
self
,
act
):
return
self
.
player
.
action
(
act
)
@
property
def
stats
(
self
):
return
self
.
player
.
stats
tensorpack/RL/simulator.py
0 → 100644
View file @
f15c2181
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: simulator.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
multiprocessing
import
threading
import
zmq
import
weakref
from
abc
import
abstractmethod
,
ABCMeta
from
collections
import
defaultdict
,
namedtuple
from
tensorpack.utils.serialize
import
*
from
tensorpack.utils.concurrency
import
*
__all__
=
[
'SimulatorProcess'
,
'SimulatorMaster'
]
class
SimulatorProcess
(
multiprocessing
.
Process
):
""" A process that simulates a player """
__meta__
=
ABCMeta
def
__init__
(
self
,
idx
,
server_name
):
"""
:param idx: idx of this process
:param player: An RLEnvironment
:param server_name: name of the server socket
"""
super
(
SimulatorProcess
,
self
)
.
__init__
()
self
.
idx
=
int
(
idx
)
self
.
server_name
=
server_name
def
run
(
self
):
player
=
self
.
_build_player
()
context
=
zmq
.
Context
()
socket
=
context
.
socket
(
zmq
.
REQ
)
socket
.
identity
=
'simulator-{}'
.
format
(
self
.
idx
)
socket
.
connect
(
self
.
server_name
)
while
True
:
state
=
player
.
current_state
()
socket
.
send
(
dumps
(
state
),
copy
=
False
)
action
=
loads
(
socket
.
recv
(
copy
=
False
))
reward
,
isOver
=
player
.
action
(
action
)
socket
.
send
(
dumps
((
reward
,
isOver
)),
copy
=
False
)
noop
=
socket
.
recv
(
copy
=
False
)
@
abstractmethod
def
_build_player
(
self
):
pass
class
SimulatorMaster
(
threading
.
Thread
):
""" A base thread to communicate with all simulator processes.
It should produce action for each simulator, as well as
defining callbacks when a transition or an episode is finished.
"""
__metaclass__
=
ABCMeta
def
__init__
(
self
,
server_name
):
super
(
SimulatorMaster
,
self
)
.
__init__
()
self
.
server_name
=
server_name
self
.
context
=
zmq
.
Context
()
self
.
socket
=
self
.
context
.
socket
(
zmq
.
ROUTER
)
self
.
socket
.
bind
(
self
.
server_name
)
self
.
daemon
=
True
def
clean_context
(
sok
,
context
):
sok
.
close
()
context
.
term
()
import
atexit
atexit
.
register
(
clean_context
,
self
.
socket
,
self
.
context
)
def
run
(
self
):
class
ClientState
(
object
):
def
__init__
(
self
):
self
.
protocol_state
=
0
# state in communication
self
.
memory
=
[]
# list of Experience
class
Experience
(
object
):
""" A transition of state, or experience"""
def
__init__
(
self
,
state
,
action
,
reward
):
self
.
state
=
state
self
.
action
=
action
self
.
reward
=
reward
self
.
clients
=
defaultdict
(
ClientState
)
while
True
:
ident
,
_
,
msg
=
self
.
socket
.
recv_multipart
()
client
=
self
.
clients
[
ident
]
if
client
.
protocol_state
==
0
:
# state-action
state
=
loads
(
msg
)
action
=
self
.
_get_action
(
state
)
self
.
socket
.
send_multipart
([
ident
,
_
,
dumps
(
action
)])
client
.
memory
.
append
(
Experience
(
state
,
action
,
None
))
else
:
# reward-response
reward
,
isOver
=
loads
(
msg
)
assert
isinstance
(
isOver
,
bool
)
client
.
memory
[
-
1
]
.
reward
=
reward
if
isOver
:
self
.
_on_episode_over
(
client
)
else
:
self
.
_on_datapoint
(
client
)
self
.
socket
.
send_multipart
([
ident
,
_
,
dumps
(
'Thanks'
)])
client
.
protocol_state
=
1
-
client
.
protocol_state
# flip the state
@
abstractmethod
def
_get_action
(
self
,
state
):
"""response to state"""
@
abstractmethod
def
_on_episode_over
(
self
,
client
):
""" callback when the client just finished an episode.
You may want to clear the client's memory in this callback.
"""
def
_on_datapoint
(
self
,
client
):
""" callback when the client just finished a transition
"""
def
__del__
(
self
):
self
.
socket
.
close
()
self
.
context
.
term
()
if
__name__
==
'__main__'
:
import
random
from
tensorpack.RL
import
NaiveRLEnvironment
class
NaiveSimulator
(
SimulatorProcess
):
def
_build_player
(
self
):
return
NaiveRLEnvironment
()
class
NaiveActioner
(
SimulatorActioner
):
def
_get_action
(
self
,
state
):
time
.
sleep
(
1
)
return
random
.
randint
(
1
,
12
)
def
_on_episode_over
(
self
,
client
):
#print("Over: ", client.memory)
client
.
memory
=
[]
client
.
state
=
0
name
=
'ipc://whatever'
procs
=
[
NaiveSimulator
(
k
,
name
)
for
k
in
range
(
10
)]
[
k
.
start
()
for
k
in
procs
]
th
=
NaiveActioner
(
name
)
ensure_proc_terminate
(
procs
)
th
.
start
()
import
time
time
.
sleep
(
100
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment