Commit e2194663 authored by Yuxin Wu's avatar Yuxin Wu

[A3C] Simplify simulator master

parent 29a7da44
......@@ -103,6 +103,7 @@ class SimulatorMaster(threading.Thread):
class ClientState(object):
def __init__(self):
self.memory = [] # list of Experience
self.ident = None
def __init__(self, pipe_c2s, pipe_s2c):
super(SimulatorMaster, self).__init__()
......@@ -143,36 +144,14 @@ class SimulatorMaster(threading.Thread):
while True:
msg = loads(self.c2s_socket.recv(copy=False).bytes)
ident, state, reward, isOver = msg
# TODO check history and warn about dead client
client = self.clients[ident]
# check if reward&isOver is valid
# in the first message, only state is valid
if len(client.memory) > 0:
client.memory[-1].reward = reward
if isOver:
self._on_episode_over(ident)
else:
self._on_datapoint(ident)
# feed state and return action
self._on_state(state, ident)
if client.ident is None:
client.ident = ident
# maybe check history and warn about dead client?
self._process_msg(client, state, reward, isOver)
except zmq.ContextTerminated:
logger.info("[Simulator] Context was terminated.")
@abstractmethod
def _on_state(self, state, ident):
"""response to state sent by ident. Preferrably an async call"""
@abstractmethod
def _on_episode_over(self, client):
""" callback when the client just finished an episode.
You may want to clear the client's memory in this callback.
"""
def _on_datapoint(self, client):
""" callback when the client just finished a transition
"""
def __del__(self):
self.context.destroy(linger=0)
......
......@@ -158,32 +158,42 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def _before_train(self):
self.async_predictor.start()
def _on_state(self, state, ident):
def _on_state(self, state, client):
"""
Launch forward prediction for the new state given by some client.
"""
def cb(outputs):
try:
distrib, value = outputs.result()
except CancelledError:
logger.info("Client {} cancelled.".format(ident))
logger.info("Client {} cancelled.".format(client.ident))
return
assert np.all(np.isfinite(distrib)), distrib
action = np.random.choice(len(distrib), p=distrib)
client = self.clients[ident]
client.memory.append(TransitionExperience(
state, action, reward=None, value=value, prob=distrib[action]))
self.send_queue.put([ident, dumps(action)])
self.send_queue.put([client.ident, dumps(action)])
self.async_predictor.put_task([state], cb)
def _on_episode_over(self, ident):
self._parse_memory(0, ident, True)
def _on_datapoint(self, ident):
client = self.clients[ident]
def _process_msg(self, client, state, reward, isOver):
"""
Process a message sent from some client.
"""
# in the first message, only state is valid,
# reward&isOver should be discarded
if len(client.memory) > 0:
client.memory[-1].reward = reward
if isOver:
# should clear client's memory and put to queue
self._parse_memory(0, client, True)
else:
if len(client.memory) == LOCAL_TIME_MAX + 1:
R = client.memory[-1].value
self._parse_memory(R, ident, False)
self._parse_memory(R, client, False)
# feed state and return action
self._on_state(state, client)
def _parse_memory(self, init_r, ident, isOver):
client = self.clients[ident]
def _parse_memory(self, init_r, client, isOver):
mem = client.memory
if not isOver:
last = mem[-1]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment