Commit cc844ed4 authored by Yuxin Wu

use async callback in simulator

parent 40cee0cb
...@@ -61,12 +61,25 @@ class SimulatorMaster(threading.Thread): ...@@ -61,12 +61,25 @@ class SimulatorMaster(threading.Thread):
""" """
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
class ClientState(object):
def __init__(self):
self.protocol_state = 0 # state in communication
self.memory = [] # list of Experience
class Experience(object):
""" A transition of state, or experience"""
def __init__(self, state, action, reward):
self.state = state
self.action = action
self.reward = reward
def __init__(self, server_name): def __init__(self, server_name):
super(SimulatorMaster, self).__init__() super(SimulatorMaster, self).__init__()
self.server_name = server_name self.server_name = server_name
self.context = zmq.Context() self.context = zmq.Context()
self.socket = self.context.socket(zmq.ROUTER) self.socket = self.context.socket(zmq.ROUTER)
self.socket.bind(self.server_name) self.socket.bind(self.server_name)
self.socket_lock = threading.Lock()
self.daemon = True self.daemon = True
def clean_context(sok, context): def clean_context(sok, context):
...@@ -76,41 +89,31 @@ class SimulatorMaster(threading.Thread): ...@@ -76,41 +89,31 @@ class SimulatorMaster(threading.Thread):
atexit.register(clean_context, self.socket, self.context) atexit.register(clean_context, self.socket, self.context)
def run(self): def run(self):
class ClientState(object): self.clients = defaultdict(SimulatorMaster.ClientState)
def __init__(self):
self.protocol_state = 0 # state in communication
self.memory = [] # list of Experience
class Experience(object):
""" A transition of state, or experience"""
def __init__(self, state, action, reward):
self.state = state
self.action = action
self.reward = reward
self.clients = defaultdict(ClientState)
while True: while True:
ident, _, msg = self.socket.recv_multipart() ident, _, msg = self.socket.recv_multipart()
#assert _ == ""
client = self.clients[ident] client = self.clients[ident]
if client.protocol_state == 0: # state-action client.protocol_state = 1 - client.protocol_state # first flip the state
if not client.protocol_state == 0: # state-action
state = loads(msg) state = loads(msg)
action = self._get_action(state) self._on_state(state, ident)
self.socket.send_multipart([ident, _, dumps(action)])
client.memory.append(Experience(state, action, None))
else: # reward-response else: # reward-response
reward, isOver = loads(msg) reward, isOver = loads(msg)
assert isinstance(isOver, bool)
client.memory[-1].reward = reward client.memory[-1].reward = reward
if isOver: if isOver:
self._on_episode_over(client) self._on_episode_over(client)
else: else:
self._on_datapoint(client) self._on_datapoint(client)
self.socket.send_multipart([ident, _, dumps('Thanks')]) self.send_multipart_threadsafe([ident, _, dumps('Thanks')])
client.protocol_state = 1 - client.protocol_state # flip the state
def send_multipart_threadsafe(self, data):
with self.socket_lock:
self.socket.send_multipart(data)
@abstractmethod @abstractmethod
def _get_action(self, state): def _on_state(self, state, ident):
"""response to state""" """response to state sent by ident. Preferrably an async call"""
@abstractmethod @abstractmethod
def _on_episode_over(self, client): def _on_episode_over(self, client):
......
...@@ -19,6 +19,8 @@ from .common import * ...@@ -19,6 +19,8 @@ from .common import *
try: try:
if six.PY2: if six.PY2:
from tornado.concurrent import Future from tornado.concurrent import Future
import tornado.options as options
options.parse_command_line(['--logging=debug'])
else: else:
from concurrent.futures import Future from concurrent.futures import Future
except ImportError: except ImportError:
...@@ -78,12 +80,13 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker): ...@@ -78,12 +80,13 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
else: else:
self.outqueue.put((tid, self.func(dp))) self.outqueue.put((tid, self.func(dp)))
class PerdictorWorkerThread(threading.Thread): class PredictorWorkerThread(threading.Thread):
def __init__(self, queue, pred_func): def __init__(self, queue, pred_func, id):
super(PredictorWorkerThread, self).__init__() super(PredictorWorkerThread, self).__init__()
self.queue = queue self.queue = queue
self.func = pred_func self.func = pred_func
self.daemon = True self.daemon = True
self.id = id
def run(self): def run(self):
while True: while True:
...@@ -101,8 +104,11 @@ class MultiThreadAsyncPredictor(object): ...@@ -101,8 +104,11 @@ class MultiThreadAsyncPredictor(object):
:param trainer: a `QueueInputTrainer` instance. :param trainer: a `QueueInputTrainer` instance.
""" """
self.input_queue = queue.Queue(maxsize=nr_thread*2) self.input_queue = queue.Queue(maxsize=nr_thread*2)
self.threads = [PredictorWorkerThread(self.input_queue, f) self.threads = [
for f in trainer.get_predict_funcs(input_names, output_names, nr_thread)] PredictorWorkerThread(self.input_queue, f, id)
for id, f in enumerate(
trainer.get_predict_funcs(
input_names, output_names, nr_thread))]
def run(self): def run(self):
for t in self.threads: for t in self.threads:
......
...@@ -265,7 +265,7 @@ class QueueInputTrainer(Trainer): ...@@ -265,7 +265,7 @@ class QueueInputTrainer(Trainer):
return func return func
def get_predict_funcs(self, input_names, output_names, n): def get_predict_funcs(self, input_names, output_names, n):
return [self.get_predict_func(input_name, output_names, k) return [self.get_predict_func(input_names, output_names, k)
for k in range(n)] for k in range(n)]
def start_train(config): def start_train(config):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment