Commit e2194663 authored by Yuxin Wu's avatar Yuxin Wu

[A3C] Simplify simulator master

parent 29a7da44
...@@ -103,6 +103,7 @@ class SimulatorMaster(threading.Thread): ...@@ -103,6 +103,7 @@ class SimulatorMaster(threading.Thread):
class ClientState(object): class ClientState(object):
def __init__(self): def __init__(self):
self.memory = [] # list of Experience self.memory = [] # list of Experience
self.ident = None
def __init__(self, pipe_c2s, pipe_s2c): def __init__(self, pipe_c2s, pipe_s2c):
super(SimulatorMaster, self).__init__() super(SimulatorMaster, self).__init__()
...@@ -143,36 +144,14 @@ class SimulatorMaster(threading.Thread): ...@@ -143,36 +144,14 @@ class SimulatorMaster(threading.Thread):
while True: while True:
msg = loads(self.c2s_socket.recv(copy=False).bytes) msg = loads(self.c2s_socket.recv(copy=False).bytes)
ident, state, reward, isOver = msg ident, state, reward, isOver = msg
# TODO check history and warn about dead client
client = self.clients[ident] client = self.clients[ident]
if client.ident is None:
# check if reward&isOver is valid client.ident = ident
# in the first message, only state is valid # maybe check history and warn about dead client?
if len(client.memory) > 0: self._process_msg(client, state, reward, isOver)
client.memory[-1].reward = reward
if isOver:
self._on_episode_over(ident)
else:
self._on_datapoint(ident)
# feed state and return action
self._on_state(state, ident)
except zmq.ContextTerminated: except zmq.ContextTerminated:
logger.info("[Simulator] Context was terminated.") logger.info("[Simulator] Context was terminated.")
@abstractmethod
def _on_state(self, state, ident):
"""response to state sent by ident. Preferrably an async call"""
@abstractmethod
def _on_episode_over(self, client):
""" callback when the client just finished an episode.
You may want to clear the client's memory in this callback.
"""
def _on_datapoint(self, client):
""" callback when the client just finished a transition
"""
def __del__(self): def __del__(self):
self.context.destroy(linger=0) self.context.destroy(linger=0)
......
...@@ -158,32 +158,42 @@ class MySimulatorMaster(SimulatorMaster, Callback): ...@@ -158,32 +158,42 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def _before_train(self): def _before_train(self):
self.async_predictor.start() self.async_predictor.start()
def _on_state(self, state, ident): def _on_state(self, state, client):
"""
Launch forward prediction for the new state given by some client.
"""
def cb(outputs): def cb(outputs):
try: try:
distrib, value = outputs.result() distrib, value = outputs.result()
except CancelledError: except CancelledError:
logger.info("Client {} cancelled.".format(ident)) logger.info("Client {} cancelled.".format(client.ident))
return return
assert np.all(np.isfinite(distrib)), distrib assert np.all(np.isfinite(distrib)), distrib
action = np.random.choice(len(distrib), p=distrib) action = np.random.choice(len(distrib), p=distrib)
client = self.clients[ident]
client.memory.append(TransitionExperience( client.memory.append(TransitionExperience(
state, action, reward=None, value=value, prob=distrib[action])) state, action, reward=None, value=value, prob=distrib[action]))
self.send_queue.put([ident, dumps(action)]) self.send_queue.put([client.ident, dumps(action)])
self.async_predictor.put_task([state], cb) self.async_predictor.put_task([state], cb)
def _on_episode_over(self, ident): def _process_msg(self, client, state, reward, isOver):
self._parse_memory(0, ident, True) """
Process a message sent from some client.
def _on_datapoint(self, ident): """
client = self.clients[ident] # in the first message, only state is valid,
if len(client.memory) == LOCAL_TIME_MAX + 1: # reward&isOver should be discarded
R = client.memory[-1].value if len(client.memory) > 0:
self._parse_memory(R, ident, False) client.memory[-1].reward = reward
if isOver:
def _parse_memory(self, init_r, ident, isOver): # should clear client's memory and put to queue
client = self.clients[ident] self._parse_memory(0, client, True)
else:
if len(client.memory) == LOCAL_TIME_MAX + 1:
R = client.memory[-1].value
self._parse_memory(R, client, False)
# feed state and return action
self._on_state(state, client)
def _parse_memory(self, init_r, client, isOver):
mem = client.memory mem = client.memory
if not isOver: if not isOver:
last = mem[-1] last = mem[-1]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment