Commit c83f2d9f authored by Yuxin Wu's avatar Yuxin Wu

a different simulator framework

parent 9d3cf419
......@@ -9,9 +9,11 @@ import threading
import weakref
from abc import abstractmethod, ABCMeta
from collections import defaultdict, namedtuple
from six.moves import queue
from tensorpack.utils.serialize import *
from tensorpack.utils.concurrency import *
from ..utils.timer import *
from ..utils.serialize import *
from ..utils.concurrency import *
__all__ = ['SimulatorProcess', 'SimulatorMaster']
......@@ -26,30 +28,40 @@ class SimulatorProcess(multiprocessing.Process):
""" A process that simulates a player """
__metaclass__ = ABCMeta
def __init__(self, idx, server_name):
def __init__(self, idx, pipe_c2s, pipe_s2c):
"""
:param idx: idx of this process
:param player: An RLEnvironment
:param server_name: name of the server socket
"""
super(SimulatorProcess, self).__init__()
self.idx = int(idx)
self.server_name = server_name
self.c2s = pipe_c2s
self.s2c = pipe_s2c
def run(self):
player = self._build_player()
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.identity = 'simulator-{}'.format(self.idx)
socket.connect(self.server_name)
c2s_socket = context.socket(zmq.DEALER)
c2s_socket.identity = 'simulator-{}'.format(self.idx)
#c2s_socket.set_hwm(2)
c2s_socket.connect(self.c2s)
s2c_socket = context.socket(zmq.DEALER)
s2c_socket.identity = 'simulator-{}'.format(self.idx)
#s2c_socket.set_hwm(5)
s2c_socket.connect(self.s2c)
#cnt = 0
while True:
state = player.current_state()
socket.send(dumps(state), copy=False)
action = loads(socket.recv(copy=False))
c2s_socket.send(dumps(state), copy=False)
#with total_timer('client recv_action'):
data = s2c_socket.recv(copy=False)
action = loads(data)
reward, isOver = player.action(action)
socket.send(dumps((reward, isOver)), copy=False)
noop = socket.recv(copy=False)
c2s_socket.send(dumps((reward, isOver)), copy=False)
#cnt += 1
#if cnt % 100 == 0:
#print_total_timer()
@abstractmethod
def _build_player(self):
......@@ -76,33 +88,30 @@ class SimulatorMaster(threading.Thread):
self.reward = reward
self.misc = misc
def __init__(self, server_name):
def __init__(self, pipe_c2s, pipe_s2c):
super(SimulatorMaster, self).__init__()
self.server_name = server_name
self.context = zmq.Context()
self.socket = self.context.socket(zmq.ROUTER)
self.socket.bind(self.server_name)
self.c2s_socket = self.context.socket(zmq.ROUTER)
self.c2s_socket.bind(pipe_c2s)
self.s2c_socket = self.context.socket(zmq.ROUTER)
self.s2c_socket.bind(pipe_s2c)
self.socket_lock = threading.Lock()
self.daemon = True
def clean_context(sok, context):
sok.close()
def clean_context(soks, context):
for s in soks:
s.close()
context.term()
import atexit
atexit.register(clean_context, self.socket, self.context)
atexit.register(clean_context, [self.c2s_socket, self.s2c_socket], self.context)
def run(self):
self.clients = defaultdict(SimulatorMaster.ClientState)
while True:
while True:
# avoid the lock being acquired here forever
try:
with self.socket_lock:
ident, _, msg = self.socket.recv_multipart(zmq.NOBLOCK)
break
except zmq.ZMQError:
#pass
time.sleep(0.001)
ident, msg = self.c2s_socket.recv_multipart()
#assert _ == ""
client = self.clients[ident]
client.protocol_state = 1 - client.protocol_state # first flip the state
......@@ -116,11 +125,6 @@ class SimulatorMaster(threading.Thread):
self._on_episode_over(client)
else:
self._on_datapoint(client)
self.send_multipart_threadsafe([ident, _, dumps('Thanks')])
def send_multipart_threadsafe(self, data):
with self.socket_lock:
self.socket.send_multipart(data)
@abstractmethod
def _on_state(self, state, ident):
......
......@@ -13,6 +13,7 @@ from six.moves import queue, range, zip
from ..utils.concurrency import DIE
from ..tfutils.modelutils import describe_model
from ..utils import logger
from ..utils.timer import *
from ..tfutils import *
from .common import *
......@@ -97,12 +98,14 @@ class PredictorWorkerThread(threading.Thread):
inp, f = self.queue.get()
batched.append(inp)
futures.append(f)
#print "func queue:", self.queue.qsize()
#return batched, futures
while True:
try:
inp, f = self.queue.get_nowait()
batched.append(inp)
futures.append(f)
if len(batched) == 128:
if len(batched) == 5:
break
except queue.Empty:
break
......@@ -137,7 +140,7 @@ class MultiThreadAsyncPredictor(object):
"""
:param trainer: a `QueueInputTrainer` instance.
"""
self.input_queue = queue.Queue(maxsize=nr_thread*2)
self.input_queue = queue.Queue(maxsize=nr_thread*10)
self.threads = [
PredictorWorkerThread(self.input_queue, f, id)
for id, f in enumerate(
......
......@@ -15,6 +15,7 @@ class StatCounter(object):
def reset(self):
self.values = []
@property
def count(self):
return len(self.values)
......
......@@ -8,6 +8,7 @@ from contextlib import contextmanager
import time
from collections import defaultdict
import six
import atexit
from .stat import StatCounter
from . import logger
......@@ -33,5 +34,10 @@ def total_timer(msg):
_TOTAL_TIMER_DATA[msg].feed(t)
def print_total_timer():
if len(_TOTAL_TIMER_DATA) == 0:
return
for k, v in six.iteritems(_TOTAL_TIMER_DATA):
logger.info("Total Time: {} -> {} sec".format(k, v.sum))
logger.info("Total Time: {} -> {} sec, {} times".format(
k, v.sum, v.count))
atexit.register(print_total_timer)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment