Commit c83f2d9f authored by Yuxin Wu

a different simulator framework

parent 9d3cf419
...@@ -9,9 +9,11 @@ import threading ...@@ -9,9 +9,11 @@ import threading
import weakref import weakref
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
from collections import defaultdict, namedtuple from collections import defaultdict, namedtuple
from six.moves import queue
from tensorpack.utils.serialize import * from ..utils.timer import *
from tensorpack.utils.concurrency import * from ..utils.serialize import *
from ..utils.concurrency import *
__all__ = ['SimulatorProcess', 'SimulatorMaster'] __all__ = ['SimulatorProcess', 'SimulatorMaster']
...@@ -26,30 +28,40 @@ class SimulatorProcess(multiprocessing.Process): ...@@ -26,30 +28,40 @@ class SimulatorProcess(multiprocessing.Process):
""" A process that simulates a player """ """ A process that simulates a player """
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
def __init__(self, idx, server_name): def __init__(self, idx, pipe_c2s, pipe_s2c):
""" """
:param idx: idx of this process :param idx: idx of this process
:param player: An RLEnvironment
:param server_name: name of the server socket
""" """
super(SimulatorProcess, self).__init__() super(SimulatorProcess, self).__init__()
self.idx = int(idx) self.idx = int(idx)
self.server_name = server_name self.c2s = pipe_c2s
self.s2c = pipe_s2c
def run(self): def run(self):
player = self._build_player() player = self._build_player()
context = zmq.Context() context = zmq.Context()
socket = context.socket(zmq.REQ) c2s_socket = context.socket(zmq.DEALER)
socket.identity = 'simulator-{}'.format(self.idx) c2s_socket.identity = 'simulator-{}'.format(self.idx)
socket.connect(self.server_name) #c2s_socket.set_hwm(2)
c2s_socket.connect(self.c2s)
s2c_socket = context.socket(zmq.DEALER)
s2c_socket.identity = 'simulator-{}'.format(self.idx)
#s2c_socket.set_hwm(5)
s2c_socket.connect(self.s2c)
#cnt = 0
while True: while True:
state = player.current_state() state = player.current_state()
socket.send(dumps(state), copy=False) c2s_socket.send(dumps(state), copy=False)
action = loads(socket.recv(copy=False)) #with total_timer('client recv_action'):
data = s2c_socket.recv(copy=False)
action = loads(data)
reward, isOver = player.action(action) reward, isOver = player.action(action)
socket.send(dumps((reward, isOver)), copy=False) c2s_socket.send(dumps((reward, isOver)), copy=False)
noop = socket.recv(copy=False) #cnt += 1
#if cnt % 100 == 0:
#print_total_timer()
@abstractmethod @abstractmethod
def _build_player(self): def _build_player(self):
...@@ -76,33 +88,30 @@ class SimulatorMaster(threading.Thread): ...@@ -76,33 +88,30 @@ class SimulatorMaster(threading.Thread):
self.reward = reward self.reward = reward
self.misc = misc self.misc = misc
def __init__(self, server_name): def __init__(self, pipe_c2s, pipe_s2c):
super(SimulatorMaster, self).__init__() super(SimulatorMaster, self).__init__()
self.server_name = server_name
self.context = zmq.Context() self.context = zmq.Context()
self.socket = self.context.socket(zmq.ROUTER)
self.socket.bind(self.server_name) self.c2s_socket = self.context.socket(zmq.ROUTER)
self.c2s_socket.bind(pipe_c2s)
self.s2c_socket = self.context.socket(zmq.ROUTER)
self.s2c_socket.bind(pipe_s2c)
self.socket_lock = threading.Lock() self.socket_lock = threading.Lock()
self.daemon = True self.daemon = True
def clean_context(sok, context): def clean_context(soks, context):
sok.close() for s in soks:
s.close()
context.term() context.term()
import atexit import atexit
atexit.register(clean_context, self.socket, self.context) atexit.register(clean_context, [self.c2s_socket, self.s2c_socket], self.context)
def run(self): def run(self):
self.clients = defaultdict(SimulatorMaster.ClientState) self.clients = defaultdict(SimulatorMaster.ClientState)
while True: while True:
while True: ident, msg = self.c2s_socket.recv_multipart()
# avoid the lock being acquired here forever
try:
with self.socket_lock:
ident, _, msg = self.socket.recv_multipart(zmq.NOBLOCK)
break
except zmq.ZMQError:
#pass
time.sleep(0.001)
#assert _ == "" #assert _ == ""
client = self.clients[ident] client = self.clients[ident]
client.protocol_state = 1 - client.protocol_state # first flip the state client.protocol_state = 1 - client.protocol_state # first flip the state
...@@ -116,11 +125,6 @@ class SimulatorMaster(threading.Thread): ...@@ -116,11 +125,6 @@ class SimulatorMaster(threading.Thread):
self._on_episode_over(client) self._on_episode_over(client)
else: else:
self._on_datapoint(client) self._on_datapoint(client)
self.send_multipart_threadsafe([ident, _, dumps('Thanks')])
def send_multipart_threadsafe(self, data):
with self.socket_lock:
self.socket.send_multipart(data)
@abstractmethod @abstractmethod
def _on_state(self, state, ident): def _on_state(self, state, ident):
......
...@@ -13,6 +13,7 @@ from six.moves import queue, range, zip ...@@ -13,6 +13,7 @@ from six.moves import queue, range, zip
from ..utils.concurrency import DIE from ..utils.concurrency import DIE
from ..tfutils.modelutils import describe_model from ..tfutils.modelutils import describe_model
from ..utils import logger from ..utils import logger
from ..utils.timer import *
from ..tfutils import * from ..tfutils import *
from .common import * from .common import *
...@@ -97,12 +98,14 @@ class PredictorWorkerThread(threading.Thread): ...@@ -97,12 +98,14 @@ class PredictorWorkerThread(threading.Thread):
inp, f = self.queue.get() inp, f = self.queue.get()
batched.append(inp) batched.append(inp)
futures.append(f) futures.append(f)
#print "func queue:", self.queue.qsize()
#return batched, futures
while True: while True:
try: try:
inp, f = self.queue.get_nowait() inp, f = self.queue.get_nowait()
batched.append(inp) batched.append(inp)
futures.append(f) futures.append(f)
if len(batched) == 128: if len(batched) == 5:
break break
except queue.Empty: except queue.Empty:
break break
...@@ -137,7 +140,7 @@ class MultiThreadAsyncPredictor(object): ...@@ -137,7 +140,7 @@ class MultiThreadAsyncPredictor(object):
""" """
:param trainer: a `QueueInputTrainer` instance. :param trainer: a `QueueInputTrainer` instance.
""" """
self.input_queue = queue.Queue(maxsize=nr_thread*2) self.input_queue = queue.Queue(maxsize=nr_thread*10)
self.threads = [ self.threads = [
PredictorWorkerThread(self.input_queue, f, id) PredictorWorkerThread(self.input_queue, f, id)
for id, f in enumerate( for id, f in enumerate(
......
...@@ -15,6 +15,7 @@ class StatCounter(object): ...@@ -15,6 +15,7 @@ class StatCounter(object):
def reset(self): def reset(self):
self.values = [] self.values = []
@property
def count(self): def count(self):
return len(self.values) return len(self.values)
......
...@@ -8,6 +8,7 @@ from contextlib import contextmanager ...@@ -8,6 +8,7 @@ from contextlib import contextmanager
import time import time
from collections import defaultdict from collections import defaultdict
import six import six
import atexit
from .stat import StatCounter from .stat import StatCounter
from . import logger from . import logger
...@@ -33,5 +34,10 @@ def total_timer(msg): ...@@ -33,5 +34,10 @@ def total_timer(msg):
_TOTAL_TIMER_DATA[msg].feed(t) _TOTAL_TIMER_DATA[msg].feed(t)
def print_total_timer(): def print_total_timer():
if len(_TOTAL_TIMER_DATA) == 0:
return
for k, v in six.iteritems(_TOTAL_TIMER_DATA): for k, v in six.iteritems(_TOTAL_TIMER_DATA):
logger.info("Total Time: {} -> {} sec".format(k, v.sum)) logger.info("Total Time: {} -> {} sec, {} times".format(
k, v.sum, v.count))
atexit.register(print_total_timer)
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment