Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e2194663
Commit
e2194663
authored
Dec 06, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[A3C] Simplify simulator master
parent
29a7da44
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
41 deletions
+30
-41
examples/A3C-Gym/simulator.py
examples/A3C-Gym/simulator.py
+5
-26
examples/A3C-Gym/train-atari.py
examples/A3C-Gym/train-atari.py
+25
-15
No files found.
examples/A3C-Gym/simulator.py
View file @
e2194663
...
...
@@ -103,6 +103,7 @@ class SimulatorMaster(threading.Thread):
class
ClientState
(
object
):
def
__init__
(
self
):
self
.
memory
=
[]
# list of Experience
self
.
ident
=
None
def
__init__
(
self
,
pipe_c2s
,
pipe_s2c
):
super
(
SimulatorMaster
,
self
)
.
__init__
()
...
...
@@ -143,36 +144,14 @@ class SimulatorMaster(threading.Thread):
while
True
:
msg
=
loads
(
self
.
c2s_socket
.
recv
(
copy
=
False
)
.
bytes
)
ident
,
state
,
reward
,
isOver
=
msg
# TODO check history and warn about dead client
client
=
self
.
clients
[
ident
]
# check if reward&isOver is valid
# in the first message, only state is valid
if
len
(
client
.
memory
)
>
0
:
client
.
memory
[
-
1
]
.
reward
=
reward
if
isOver
:
self
.
_on_episode_over
(
ident
)
else
:
self
.
_on_datapoint
(
ident
)
# feed state and return action
self
.
_on_state
(
state
,
ident
)
if
client
.
ident
is
None
:
client
.
ident
=
ident
# maybe check history and warn about dead client?
self
.
_process_msg
(
client
,
state
,
reward
,
isOver
)
except
zmq
.
ContextTerminated
:
logger
.
info
(
"[Simulator] Context was terminated."
)
@
abstractmethod
def
_on_state
(
self
,
state
,
ident
):
"""response to state sent by ident. Preferrably an async call"""
@
abstractmethod
def
_on_episode_over
(
self
,
client
):
""" callback when the client just finished an episode.
You may want to clear the client's memory in this callback.
"""
def
_on_datapoint
(
self
,
client
):
""" callback when the client just finished a transition
"""
def
__del__
(
self
):
self
.
context
.
destroy
(
linger
=
0
)
...
...
examples/A3C-Gym/train-atari.py
View file @
e2194663
...
...
@@ -158,32 +158,42 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def
_before_train
(
self
):
self
.
async_predictor
.
start
()
def
_on_state
(
self
,
state
,
ident
):
def
_on_state
(
self
,
state
,
client
):
"""
Launch forward prediction for the new state given by some client.
"""
def
cb
(
outputs
):
try
:
distrib
,
value
=
outputs
.
result
()
except
CancelledError
:
logger
.
info
(
"Client {} cancelled."
.
format
(
ident
))
logger
.
info
(
"Client {} cancelled."
.
format
(
client
.
ident
))
return
assert
np
.
all
(
np
.
isfinite
(
distrib
)),
distrib
action
=
np
.
random
.
choice
(
len
(
distrib
),
p
=
distrib
)
client
=
self
.
clients
[
ident
]
client
.
memory
.
append
(
TransitionExperience
(
state
,
action
,
reward
=
None
,
value
=
value
,
prob
=
distrib
[
action
]))
self
.
send_queue
.
put
([
ident
,
dumps
(
action
)])
self
.
send_queue
.
put
([
client
.
ident
,
dumps
(
action
)])
self
.
async_predictor
.
put_task
([
state
],
cb
)
def
_on_episode_over
(
self
,
ident
):
self
.
_parse_memory
(
0
,
ident
,
True
)
def
_on_datapoint
(
self
,
ident
):
client
=
self
.
clients
[
ident
]
if
len
(
client
.
memory
)
==
LOCAL_TIME_MAX
+
1
:
R
=
client
.
memory
[
-
1
]
.
value
self
.
_parse_memory
(
R
,
ident
,
False
)
def
_parse_memory
(
self
,
init_r
,
ident
,
isOver
):
client
=
self
.
clients
[
ident
]
def
_process_msg
(
self
,
client
,
state
,
reward
,
isOver
):
"""
Process a message sent from some client.
"""
# in the first message, only state is valid,
# reward&isOver should be discarded
if
len
(
client
.
memory
)
>
0
:
client
.
memory
[
-
1
]
.
reward
=
reward
if
isOver
:
# should clear client's memory and put to queue
self
.
_parse_memory
(
0
,
client
,
True
)
else
:
if
len
(
client
.
memory
)
==
LOCAL_TIME_MAX
+
1
:
R
=
client
.
memory
[
-
1
]
.
value
self
.
_parse_memory
(
R
,
client
,
False
)
# feed state and return action
self
.
_on_state
(
state
,
client
)
def
_parse_memory
(
self
,
init_r
,
client
,
isOver
):
mem
=
client
.
memory
if
not
isOver
:
last
=
mem
[
-
1
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment