Commit df82c65a authored by Yuxin Wu

[A3C] code simplification

parent d451368a
...@@ -18,7 +18,6 @@ from tensorpack.utils.concurrency import LoopThread, enable_death_signal, ensure ...@@ -18,7 +18,6 @@ from tensorpack.utils.concurrency import LoopThread, enable_death_signal, ensure
from tensorpack.utils.serialize import dumps, loads from tensorpack.utils.serialize import dumps, loads
__all__ = ['SimulatorProcess', 'SimulatorMaster', __all__ = ['SimulatorProcess', 'SimulatorMaster',
'SimulatorProcessStateExchange',
'TransitionExperience'] 'TransitionExperience']
...@@ -35,19 +34,7 @@ class TransitionExperience(object): ...@@ -35,19 +34,7 @@ class TransitionExperience(object):
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class SimulatorProcessBase(mp.Process): class SimulatorProcess(mp.Process):
def __init__(self, idx):
super(SimulatorProcessBase, self).__init__()
self.idx = int(idx)
self.name = u'simulator-{}'.format(self.idx)
self.identity = self.name.encode('utf-8')
@abstractmethod
def _build_player(self):
pass
class SimulatorProcessStateExchange(SimulatorProcessBase):
""" """
A process that simulates a player and communicates to master to A process that simulates a player and communicates to master to
send states and receive the next action send states and receive the next action
...@@ -59,7 +46,11 @@ class SimulatorProcessStateExchange(SimulatorProcessBase): ...@@ -59,7 +46,11 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
idx: idx of this process idx: idx of this process
pipe_c2s, pipe_s2c (str): name of the pipe pipe_c2s, pipe_s2c (str): name of the pipe
""" """
super(SimulatorProcessStateExchange, self).__init__(idx) super(SimulatorProcess, self).__init__()
self.idx = int(idx)
self.name = u'simulator-{}'.format(self.idx)
self.identity = self.name.encode('utf-8')
self.c2s = pipe_c2s self.c2s = pipe_c2s
self.s2c = pipe_s2c self.s2c = pipe_s2c
...@@ -90,13 +81,14 @@ class SimulatorProcessStateExchange(SimulatorProcessBase): ...@@ -90,13 +81,14 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
if isOver: if isOver:
state = player.reset() state = player.reset()
@abstractmethod
# compatibility def _build_player(self):
SimulatorProcess = SimulatorProcessStateExchange pass
@six.add_metaclass(ABCMeta)
class SimulatorMaster(threading.Thread): class SimulatorMaster(threading.Thread):
""" A base thread to communicate with all StateExchangeSimulatorProcess. """ A base thread to communicate with all SimulatorProcess.
It should produce action for each simulator, as well as It should produce action for each simulator, as well as
defining callbacks when a transition or an episode is finished. defining callbacks when a transition or an episode is finished.
""" """
...@@ -106,6 +98,10 @@ class SimulatorMaster(threading.Thread): ...@@ -106,6 +98,10 @@ class SimulatorMaster(threading.Thread):
self.ident = None self.ident = None
def __init__(self, pipe_c2s, pipe_s2c): def __init__(self, pipe_c2s, pipe_s2c):
"""
Args:
pipe_c2s, pipe_s2c (str): names of pipe to be used for communication
"""
super(SimulatorMaster, self).__init__() super(SimulatorMaster, self).__init__()
assert os.name != 'nt', "Doesn't support windows!" assert os.name != 'nt', "Doesn't support windows!"
self.daemon = True self.daemon = True
...@@ -152,6 +148,10 @@ class SimulatorMaster(threading.Thread): ...@@ -152,6 +148,10 @@ class SimulatorMaster(threading.Thread):
except zmq.ContextTerminated: except zmq.ContextTerminated:
logger.info("[Simulator] Context was terminated.") logger.info("[Simulator] Context was terminated.")
@abstractmethod
def _process_msg(self, client, state, reward, isOver):
pass
def __del__(self): def __del__(self):
self.context.destroy(linger=0) self.context.destroy(linger=0)
......
...@@ -139,12 +139,16 @@ class Model(ModelDesc): ...@@ -139,12 +139,16 @@ class Model(ModelDesc):
class MySimulatorMaster(SimulatorMaster, Callback): class MySimulatorMaster(SimulatorMaster, Callback):
def __init__(self, pipe_c2s, pipe_s2c, gpus): def __init__(self, pipe_c2s, pipe_s2c, gpus):
"""
Args:
gpus (list[int]): the gpus used to run inference
"""
super(MySimulatorMaster, self).__init__(pipe_c2s, pipe_s2c) super(MySimulatorMaster, self).__init__(pipe_c2s, pipe_s2c)
self.queue = queue.Queue(maxsize=BATCH_SIZE * 8 * 2) self.queue = queue.Queue(maxsize=BATCH_SIZE * 8 * 2)
self._gpus = gpus self._gpus = gpus
def _setup_graph(self): def _setup_graph(self):
# create predictors on the available predictor GPUs. # Create predictors on the available predictor GPUs.
num_gpu = len(self._gpus) num_gpu = len(self._gpus)
predictors = [self.trainer.get_predictor( predictors = [self.trainer.get_predictor(
['state'], ['policy', 'pred_value'], ['state'], ['policy', 'pred_value'],
...@@ -155,6 +159,8 @@ class MySimulatorMaster(SimulatorMaster, Callback): ...@@ -155,6 +159,8 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def _before_train(self): def _before_train(self):
self.async_predictor.start() self.async_predictor.start()
logger.info("Starting MySimulatorMaster ...")
start_proc_mask_signal(self)
def _on_state(self, state, client): def _on_state(self, state, client):
""" """
...@@ -208,6 +214,10 @@ class MySimulatorMaster(SimulatorMaster, Callback): ...@@ -208,6 +214,10 @@ class MySimulatorMaster(SimulatorMaster, Callback):
else: else:
client.memory = [] client.memory = []
def get_training_dataflow(self):
# the queue contains batched experience
return BatchData(DataFromQueue(self.queue), BATCH_SIZE)
def train(): def train():
assert tf.test.is_gpu_available(), "Training requires GPUs!" assert tf.test.is_gpu_available(), "Training requires GPUs!"
...@@ -242,24 +252,19 @@ def train(): ...@@ -242,24 +252,19 @@ def train():
start_proc_mask_signal(procs) start_proc_mask_signal(procs)
master = MySimulatorMaster(namec2s, names2c, predict_tower) master = MySimulatorMaster(namec2s, names2c, predict_tower)
dataflow = BatchData(DataFromQueue(master.queue), BATCH_SIZE)
config = TrainConfig( config = TrainConfig(
model=Model(), model=Model(),
dataflow=dataflow, dataflow=master.get_training_dataflow(),
callbacks=[ callbacks=[
ModelSaver(), ModelSaver(),
ScheduledHyperParamSetter('learning_rate', [(20, 0.0003), (120, 0.0001)]), ScheduledHyperParamSetter('learning_rate', [(20, 0.0003), (120, 0.0001)]),
ScheduledHyperParamSetter('entropy_beta', [(80, 0.005)]), ScheduledHyperParamSetter('entropy_beta', [(80, 0.005)]),
HumanHyperParamSetter('learning_rate'),
HumanHyperParamSetter('entropy_beta'),
master, master,
StartProcOrThread(master),
PeriodicTrigger(Evaluator( PeriodicTrigger(Evaluator(
EVAL_EPISODE, ['state'], ['policy'], get_player), EVAL_EPISODE, ['state'], ['policy'], get_player),
every_k_epochs=3), every_k_epochs=3),
], ],
session_creator=sesscreate.NewSessionCreator( session_creator=sesscreate.NewSessionCreator(config=get_default_sess_config(0.5)),
config=get_default_sess_config(0.5)),
steps_per_epoch=STEPS_PER_EPOCH, steps_per_epoch=STEPS_PER_EPOCH,
session_init=get_model_loader(args.load) if args.load else None, session_init=get_model_loader(args.load) if args.load else None,
max_epoch=1000, max_epoch=1000,
......
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment