Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
cc844ed4
Commit
cc844ed4
authored
Jun 02, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use async callback in simulator
parent
40cee0cb
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
36 additions
and
27 deletions
+36
-27
tensorpack/RL/simulator.py
tensorpack/RL/simulator.py
+25
-22
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+10
-4
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+1
-1
No files found.
tensorpack/RL/simulator.py
View file @
cc844ed4
...
...
@@ -61,12 +61,25 @@ class SimulatorMaster(threading.Thread):
"""
__metaclass__
=
ABCMeta
class
ClientState
(
object
):
def
__init__
(
self
):
self
.
protocol_state
=
0
# state in communication
self
.
memory
=
[]
# list of Experience
class
Experience
(
object
):
""" A transition of state, or experience"""
def
__init__
(
self
,
state
,
action
,
reward
):
self
.
state
=
state
self
.
action
=
action
self
.
reward
=
reward
def
__init__
(
self
,
server_name
):
super
(
SimulatorMaster
,
self
)
.
__init__
()
self
.
server_name
=
server_name
self
.
context
=
zmq
.
Context
()
self
.
socket
=
self
.
context
.
socket
(
zmq
.
ROUTER
)
self
.
socket
.
bind
(
self
.
server_name
)
self
.
socket_lock
=
threading
.
Lock
()
self
.
daemon
=
True
def
clean_context
(
sok
,
context
):
...
...
@@ -76,41 +89,31 @@ class SimulatorMaster(threading.Thread):
atexit
.
register
(
clean_context
,
self
.
socket
,
self
.
context
)
def
run
(
self
):
class
ClientState
(
object
):
def
__init__
(
self
):
self
.
protocol_state
=
0
# state in communication
self
.
memory
=
[]
# list of Experience
class
Experience
(
object
):
""" A transition of state, or experience"""
def
__init__
(
self
,
state
,
action
,
reward
):
self
.
state
=
state
self
.
action
=
action
self
.
reward
=
reward
self
.
clients
=
defaultdict
(
ClientState
)
self
.
clients
=
defaultdict
(
SimulatorMaster
.
ClientState
)
while
True
:
ident
,
_
,
msg
=
self
.
socket
.
recv_multipart
()
#assert _ == ""
client
=
self
.
clients
[
ident
]
if
client
.
protocol_state
==
0
:
# state-action
client
.
protocol_state
=
1
-
client
.
protocol_state
# first flip the state
if
not
client
.
protocol_state
==
0
:
# state-action
state
=
loads
(
msg
)
action
=
self
.
_get_action
(
state
)
self
.
socket
.
send_multipart
([
ident
,
_
,
dumps
(
action
)])
client
.
memory
.
append
(
Experience
(
state
,
action
,
None
))
self
.
_on_state
(
state
,
ident
)
else
:
# reward-response
reward
,
isOver
=
loads
(
msg
)
assert
isinstance
(
isOver
,
bool
)
client
.
memory
[
-
1
]
.
reward
=
reward
if
isOver
:
self
.
_on_episode_over
(
client
)
else
:
self
.
_on_datapoint
(
client
)
self
.
socket
.
send_multipart
([
ident
,
_
,
dumps
(
'Thanks'
)])
client
.
protocol_state
=
1
-
client
.
protocol_state
# flip the state
self
.
send_multipart_threadsafe
([
ident
,
_
,
dumps
(
'Thanks'
)])
def
send_multipart_threadsafe
(
self
,
data
):
with
self
.
socket_lock
:
self
.
socket
.
send_multipart
(
data
)
@
abstractmethod
def
_
get_action
(
self
,
state
):
"""response to state"""
def
_
on_state
(
self
,
state
,
ident
):
"""response to state
sent by ident. Preferrably an async call
"""
@
abstractmethod
def
_on_episode_over
(
self
,
client
):
...
...
tensorpack/predict/concurrency.py
View file @
cc844ed4
...
...
@@ -19,6 +19,8 @@ from .common import *
try
:
if
six
.
PY2
:
from
tornado.concurrent
import
Future
import
tornado.options
as
options
options
.
parse_command_line
([
'--logging=debug'
])
else
:
from
concurrent.futures
import
Future
except
ImportError
:
...
...
@@ -78,12 +80,13 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
else
:
self
.
outqueue
.
put
((
tid
,
self
.
func
(
dp
)))
class
P
er
dictorWorkerThread
(
threading
.
Thread
):
def
__init__
(
self
,
queue
,
pred_func
):
class
P
re
dictorWorkerThread
(
threading
.
Thread
):
def
__init__
(
self
,
queue
,
pred_func
,
id
):
super
(
PredictorWorkerThread
,
self
)
.
__init__
()
self
.
queue
=
queue
self
.
func
=
pred_func
self
.
daemon
=
True
self
.
id
=
id
def
run
(
self
):
while
True
:
...
...
@@ -101,8 +104,11 @@ class MultiThreadAsyncPredictor(object):
:param trainer: a `QueueInputTrainer` instance.
"""
self
.
input_queue
=
queue
.
Queue
(
maxsize
=
nr_thread
*
2
)
self
.
threads
=
[
PredictorWorkerThread
(
self
.
input_queue
,
f
)
for
f
in
trainer
.
get_predict_funcs
(
input_names
,
output_names
,
nr_thread
)]
self
.
threads
=
[
PredictorWorkerThread
(
self
.
input_queue
,
f
,
id
)
for
id
,
f
in
enumerate
(
trainer
.
get_predict_funcs
(
input_names
,
output_names
,
nr_thread
))]
def
run
(
self
):
for
t
in
self
.
threads
:
...
...
tensorpack/train/trainer.py
View file @
cc844ed4
...
...
@@ -265,7 +265,7 @@ class QueueInputTrainer(Trainer):
return
func
def
get_predict_funcs
(
self
,
input_names
,
output_names
,
n
):
return
[
self
.
get_predict_func
(
input_name
,
output_names
,
k
)
return
[
self
.
get_predict_func
(
input_name
s
,
output_names
,
k
)
for
k
in
range
(
n
)]
def
start_train
(
config
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment