Commit cc844ed4 authored by Yuxin Wu

use async callback in simulator

parent 40cee0cb
......@@ -61,12 +61,25 @@ class SimulatorMaster(threading.Thread):
"""
__metaclass__ = ABCMeta
class ClientState(object):
    """Bookkeeping kept for one connected client."""

    def __init__(self):
        # list of Experience
        self.memory = list()
        # state in communication
        self.protocol_state = 0
class Experience(object):
    """A single transition: the observed state, the action taken and its reward."""

    def __init__(self, state, action, reward):
        # Plain attribute record; all three fields stay mutable after creation.
        self.state, self.action, self.reward = state, action, reward
def __init__(self, server_name):
super(SimulatorMaster, self).__init__()
self.server_name = server_name
self.context = zmq.Context()
self.socket = self.context.socket(zmq.ROUTER)
self.socket.bind(self.server_name)
self.socket_lock = threading.Lock()
self.daemon = True
def clean_context(sok, context):
......@@ -76,41 +89,31 @@ class SimulatorMaster(threading.Thread):
atexit.register(clean_context, self.socket, self.context)
def run(self):
# Receive loop of the master thread: reads multipart messages from
# self.socket forever and dispatches on the sending client's protocol_state.
# NOTE(review): this span looks like a rendered diff with the +/- markers
# lost, so lines from two revisions of `run` appear side by side (two
# `self.clients` assignments, two `if` headers, two acknowledgement sends).
# Confirm which lines are current before changing any logic here.
class ClientState(object):
def __init__(self):
self.protocol_state = 0 # state in communication
self.memory = [] # list of Experience
class Experience(object):
""" A transition of state, or experience"""
def __init__(self, state, action, reward):
self.state = state
self.action = action
self.reward = reward
# Maps a client identity frame to its ClientState; defaultdict creates the
# entry on first lookup.
self.clients = defaultdict(ClientState)
self.clients = defaultdict(SimulatorMaster.ClientState)
while True:
# Each message is unpacked into three frames: identity, a middle frame, payload.
ident, _, msg = self.socket.recv_multipart()
#assert _ == ""
client = self.clients[ident]
if client.protocol_state == 0: # state-action
# 1 - x toggles protocol_state between 0 and 1.
client.protocol_state = 1 - client.protocol_state # first flip the state
if not client.protocol_state == 0: # state-action
state = loads(msg)
action = self._get_action(state)
self.socket.send_multipart([ident, _, dumps(action)])
client.memory.append(Experience(state, action, None))
self._on_state(state, ident)
else: # reward-response
# The payload deserializes to a (reward, isOver) pair.
reward, isOver = loads(msg)
assert isinstance(isOver, bool)
# Fills in the reward of the most recent memory entry; raises IndexError if
# memory is empty. NOTE(review): presumably an Experience is appended while
# handling the preceding state message -- confirm for the _on_state path.
client.memory[-1].reward = reward
if isOver:
self._on_episode_over(client)
else:
self._on_datapoint(client)
self.socket.send_multipart([ident, _, dumps('Thanks')])
client.protocol_state = 1 - client.protocol_state # flip the state
# Acknowledge the reward message through the lock-guarded send helper.
self.send_multipart_threadsafe([ident, _, dumps('Thanks')])
def send_multipart_threadsafe(self, data):
    """Send `data` with `self.socket.send_multipart` while holding `self.socket_lock`.

    :param data: the frames handed unchanged to `send_multipart`.
    """
    guard = self.socket_lock
    with guard:
        # The socket is looked up and used only while the lock is held.
        send = self.socket.send_multipart
        send(data)
@abstractmethod
def _get_action(self, state):
"""Return the response (action) to `state`; `run` serializes it with `dumps` and sends it back to the client."""
def _on_state(self, state, ident):
"""Respond to `state` sent by the client `ident`. Preferably an async call."""
@abstractmethod
def _on_episode_over(self, client):
......
......@@ -19,6 +19,8 @@ from .common import *
try:
if six.PY2:
from tornado.concurrent import Future
import tornado.options as options
options.parse_command_line(['--logging=debug'])
else:
from concurrent.futures import Future
except ImportError:
......@@ -78,12 +80,13 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
else:
self.outqueue.put((tid, self.func(dp)))
class PerdictorWorkerThread(threading.Thread):
def __init__(self, queue, pred_func):
class PredictorWorkerThread(threading.Thread):
def __init__(self, queue, pred_func, id):
    """Store the work queue, the predict function and this worker's id.

    :param queue: stored as `self.queue`.
    :param pred_func: stored as `self.func`.
    :param id: stored as `self.id`. The name shadows the builtin `id` but is
        kept because it is part of the call signature.
    """
    # Thread.__init__ must run before any thread attribute (e.g. daemon) is set.
    super(PredictorWorkerThread, self).__init__()
    self.daemon = True
    self.id = id
    self.func = pred_func
    self.queue = queue
def run(self):
while True:
......@@ -101,8 +104,11 @@ class MultiThreadAsyncPredictor(object):
:param trainer: a `QueueInputTrainer` instance.
"""
self.input_queue = queue.Queue(maxsize=nr_thread*2)
self.threads = [PredictorWorkerThread(self.input_queue, f)
for f in trainer.get_predict_funcs(input_names, output_names, nr_thread)]
self.threads = [
PredictorWorkerThread(self.input_queue, f, id)
for id, f in enumerate(
trainer.get_predict_funcs(
input_names, output_names, nr_thread))]
def run(self):
for t in self.threads:
......
......@@ -265,7 +265,7 @@ class QueueInputTrainer(Trainer):
return func
def get_predict_funcs(self, input_names, output_names, n):
    """Build `n` predict functions over the same inputs and outputs.

    Calls `self.get_predict_func(input_names, output_names, k)` once for each
    `k` in `range(n)` and collects the results in order.

    :param input_names: forwarded unchanged to `get_predict_func`.
    :param output_names: forwarded unchanged to `get_predict_func`.
    :param n: number of predict functions to create.
    :returns: a list of `n` results of `get_predict_func`; empty when `n` is 0.
    """
    # The stale `return` that referenced the undefined name `input_name`
    # (a NameError at call time) is dropped; only the corrected call remains.
    return [self.get_predict_func(input_names, output_names, k)
            for k in range(n)]
def start_train(config):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment