Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
c83f2d9f
Commit
c83f2d9f
authored
Jun 03, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
a different simulator framework
parent
9d3cf419
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
51 additions
and
37 deletions
+51
-37
tensorpack/RL/simulator.py
tensorpack/RL/simulator.py
+38
-34
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+5
-2
tensorpack/utils/stat.py
tensorpack/utils/stat.py
+1
-0
tensorpack/utils/timer.py
tensorpack/utils/timer.py
+7
-1
No files found.
tensorpack/RL/simulator.py
View file @
c83f2d9f
...
...
@@ -9,9 +9,11 @@ import threading
import
weakref
from
abc
import
abstractmethod
,
ABCMeta
from
collections
import
defaultdict
,
namedtuple
from
six.moves
import
queue
from
tensorpack.utils.serialize
import
*
from
tensorpack.utils.concurrency
import
*
from
..utils.timer
import
*
from
..utils.serialize
import
*
from
..utils.concurrency
import
*
__all__
=
[
'SimulatorProcess'
,
'SimulatorMaster'
]
...
...
@@ -26,30 +28,40 @@ class SimulatorProcess(multiprocessing.Process):
""" A process that simulates a player """
__metaclass__
=
ABCMeta
def
__init__
(
self
,
idx
,
server_name
):
def
__init__
(
self
,
idx
,
pipe_c2s
,
pipe_s2c
):
"""
:param idx: idx of this process
:param player: An RLEnvironment
:param server_name: name of the server socket
"""
super
(
SimulatorProcess
,
self
)
.
__init__
()
self
.
idx
=
int
(
idx
)
self
.
server_name
=
server_name
self
.
c2s
=
pipe_c2s
self
.
s2c
=
pipe_s2c
def
run
(
self
):
player
=
self
.
_build_player
()
context
=
zmq
.
Context
()
socket
=
context
.
socket
(
zmq
.
REQ
)
socket
.
identity
=
'simulator-{}'
.
format
(
self
.
idx
)
socket
.
connect
(
self
.
server_name
)
c2s_socket
=
context
.
socket
(
zmq
.
DEALER
)
c2s_socket
.
identity
=
'simulator-{}'
.
format
(
self
.
idx
)
#c2s_socket.set_hwm(2)
c2s_socket
.
connect
(
self
.
c2s
)
s2c_socket
=
context
.
socket
(
zmq
.
DEALER
)
s2c_socket
.
identity
=
'simulator-{}'
.
format
(
self
.
idx
)
#s2c_socket.set_hwm(5)
s2c_socket
.
connect
(
self
.
s2c
)
#cnt = 0
while
True
:
state
=
player
.
current_state
()
socket
.
send
(
dumps
(
state
),
copy
=
False
)
action
=
loads
(
socket
.
recv
(
copy
=
False
))
c2s_socket
.
send
(
dumps
(
state
),
copy
=
False
)
#with total_timer('client recv_action'):
data
=
s2c_socket
.
recv
(
copy
=
False
)
action
=
loads
(
data
)
reward
,
isOver
=
player
.
action
(
action
)
socket
.
send
(
dumps
((
reward
,
isOver
)),
copy
=
False
)
noop
=
socket
.
recv
(
copy
=
False
)
c2s_socket
.
send
(
dumps
((
reward
,
isOver
)),
copy
=
False
)
#cnt += 1
#if cnt % 100 == 0:
#print_total_timer()
@
abstractmethod
def
_build_player
(
self
):
...
...
@@ -76,33 +88,30 @@ class SimulatorMaster(threading.Thread):
self
.
reward
=
reward
self
.
misc
=
misc
def
__init__
(
self
,
server_name
):
def
__init__
(
self
,
pipe_c2s
,
pipe_s2c
):
super
(
SimulatorMaster
,
self
)
.
__init__
()
self
.
server_name
=
server_name
self
.
context
=
zmq
.
Context
()
self
.
socket
=
self
.
context
.
socket
(
zmq
.
ROUTER
)
self
.
socket
.
bind
(
self
.
server_name
)
self
.
c2s_socket
=
self
.
context
.
socket
(
zmq
.
ROUTER
)
self
.
c2s_socket
.
bind
(
pipe_c2s
)
self
.
s2c_socket
=
self
.
context
.
socket
(
zmq
.
ROUTER
)
self
.
s2c_socket
.
bind
(
pipe_s2c
)
self
.
socket_lock
=
threading
.
Lock
()
self
.
daemon
=
True
def
clean_context
(
sok
,
context
):
sok
.
close
()
def
clean_context
(
soks
,
context
):
for
s
in
soks
:
s
.
close
()
context
.
term
()
import
atexit
atexit
.
register
(
clean_context
,
self
.
socket
,
self
.
context
)
atexit
.
register
(
clean_context
,
[
self
.
c2s_socket
,
self
.
s2c_socket
]
,
self
.
context
)
def
run
(
self
):
self
.
clients
=
defaultdict
(
SimulatorMaster
.
ClientState
)
while
True
:
while
True
:
# avoid the lock being acquired here forever
try
:
with
self
.
socket_lock
:
ident
,
_
,
msg
=
self
.
socket
.
recv_multipart
(
zmq
.
NOBLOCK
)
break
except
zmq
.
ZMQError
:
#pass
time
.
sleep
(
0.001
)
ident
,
msg
=
self
.
c2s_socket
.
recv_multipart
()
#assert _ == ""
client
=
self
.
clients
[
ident
]
client
.
protocol_state
=
1
-
client
.
protocol_state
# first flip the state
...
...
@@ -116,11 +125,6 @@ class SimulatorMaster(threading.Thread):
self
.
_on_episode_over
(
client
)
else
:
self
.
_on_datapoint
(
client
)
self
.
send_multipart_threadsafe
([
ident
,
_
,
dumps
(
'Thanks'
)])
def
send_multipart_threadsafe
(
self
,
data
):
with
self
.
socket_lock
:
self
.
socket
.
send_multipart
(
data
)
@
abstractmethod
def
_on_state
(
self
,
state
,
ident
):
...
...
tensorpack/predict/concurrency.py
View file @
c83f2d9f
...
...
@@ -13,6 +13,7 @@ from six.moves import queue, range, zip
from
..utils.concurrency
import
DIE
from
..tfutils.modelutils
import
describe_model
from
..utils
import
logger
from
..utils.timer
import
*
from
..tfutils
import
*
from
.common
import
*
...
...
@@ -97,12 +98,14 @@ class PredictorWorkerThread(threading.Thread):
inp
,
f
=
self
.
queue
.
get
()
batched
.
append
(
inp
)
futures
.
append
(
f
)
#print "func queue:", self.queue.qsize()
#return batched, futures
while
True
:
try
:
inp
,
f
=
self
.
queue
.
get_nowait
()
batched
.
append
(
inp
)
futures
.
append
(
f
)
if
len
(
batched
)
==
128
:
if
len
(
batched
)
==
5
:
break
except
queue
.
Empty
:
break
...
...
@@ -137,7 +140,7 @@ class MultiThreadAsyncPredictor(object):
"""
:param trainer: a `QueueInputTrainer` instance.
"""
self
.
input_queue
=
queue
.
Queue
(
maxsize
=
nr_thread
*
2
)
self
.
input_queue
=
queue
.
Queue
(
maxsize
=
nr_thread
*
10
)
self
.
threads
=
[
PredictorWorkerThread
(
self
.
input_queue
,
f
,
id
)
for
id
,
f
in
enumerate
(
...
...
tensorpack/utils/stat.py
View file @
c83f2d9f
...
...
@@ -15,6 +15,7 @@ class StatCounter(object):
def
reset
(
self
):
self
.
values
=
[]
@
property
def
count
(
self
):
return
len
(
self
.
values
)
...
...
tensorpack/utils/timer.py
View file @
c83f2d9f
...
...
@@ -8,6 +8,7 @@ from contextlib import contextmanager
import
time
from
collections
import
defaultdict
import
six
import
atexit
from
.stat
import
StatCounter
from
.
import
logger
...
...
@@ -33,5 +34,10 @@ def total_timer(msg):
_TOTAL_TIMER_DATA
[
msg
]
.
feed
(
t
)
def
print_total_timer
():
if
len
(
_TOTAL_TIMER_DATA
)
==
0
:
return
for
k
,
v
in
six
.
iteritems
(
_TOTAL_TIMER_DATA
):
logger
.
info
(
"Total Time: {} -> {} sec"
.
format
(
k
,
v
.
sum
))
logger
.
info
(
"Total Time: {} -> {} sec, {} times"
.
format
(
k
,
v
.
sum
,
v
.
count
))
atexit
.
register
(
print_total_timer
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment