Commit df82c65a authored by Yuxin Wu

[A3C] code simplification

parent d451368a
......@@ -18,7 +18,6 @@ from tensorpack.utils.concurrency import LoopThread, enable_death_signal, ensure
from tensorpack.utils.serialize import dumps, loads
__all__ = ['SimulatorProcess', 'SimulatorMaster',
'SimulatorProcessStateExchange',
'TransitionExperience']
......@@ -35,19 +34,7 @@ class TransitionExperience(object):
# Abstract base for simulator processes: an mp.Process subclass that stores an
# integer index plus a name and a byte identity derived from that index.
# ABCMeta is applied so that the abstract _build_player must be overridden.
@six.add_metaclass(ABCMeta)
class SimulatorProcessBase(mp.Process):
# idx: any value accepted by int(); kept as self.idx.
def __init__(self, idx):
super(SimulatorProcessBase, self).__init__()
self.idx = int(idx)
# Process name built from the index, e.g. u'simulator-3' for idx == 3.
self.name = u'simulator-{}'.format(self.idx)
# UTF-8 bytes of the name.
# NOTE(review): presumably used as the socket identity when talking to the master -- confirm at the use site.
self.identity = self.name.encode('utf-8')
# Subclass hook with an empty body; what it must return is not shown in this view.
@abstractmethod
def _build_player(self):
pass
class SimulatorProcessStateExchange(SimulatorProcessBase):
class SimulatorProcess(mp.Process):
"""
A process that simulates a player and communicates to master to
send states and receive the next action
......@@ -59,7 +46,11 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
idx: idx of this process
pipe_c2s, pipe_s2c (str): name of the pipe
"""
super(SimulatorProcessStateExchange, self).__init__(idx)
super(SimulatorProcess, self).__init__()
self.idx = int(idx)
self.name = u'simulator-{}'.format(self.idx)
self.identity = self.name.encode('utf-8')
self.c2s = pipe_c2s
self.s2c = pipe_s2c
......@@ -90,13 +81,14 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
if isOver:
state = player.reset()
# compatibility
SimulatorProcess = SimulatorProcessStateExchange
# Abstract hook that concrete simulator processes must override; the body is
# intentionally empty here.
# NOTE(review): presumably returns the player object that the process steps and resets -- confirm against the run loop, which is only partly visible here.
@abstractmethod
def _build_player(self):
pass
@six.add_metaclass(ABCMeta)
class SimulatorMaster(threading.Thread):
""" A base thread to communicate with all StateExchangeSimulatorProcess.
""" A base thread to communicate with all SimulatorProcess.
It should produce action for each simulator, as well as
defining callbacks when a transition or an episode is finished.
"""
......@@ -106,6 +98,10 @@ class SimulatorMaster(threading.Thread):
self.ident = None
def __init__(self, pipe_c2s, pipe_s2c):
"""
Args:
pipe_c2s, pipe_s2c (str): names of pipe to be used for communication
"""
super(SimulatorMaster, self).__init__()
assert os.name != 'nt', "Doesn't support windows!"
self.daemon = True
......@@ -152,6 +148,10 @@ class SimulatorMaster(threading.Thread):
except zmq.ContextTerminated:
logger.info("[Simulator] Context was terminated.")
# Abstract callback that subclasses must implement; it receives a client object
# together with a state, a reward and an isOver flag.
# NOTE(review): presumably invoked once per message received from a simulator -- the caller is outside this view.
@abstractmethod
def _process_msg(self, client, state, reward, isOver):
pass
# Finalizer: destroys self.context, passing linger=0, when the object is collected.
# NOTE(review): self.context is presumably the zmq context created in __init__ (not visible here) -- confirm.
def __del__(self):
self.context.destroy(linger=0)
......
......@@ -139,12 +139,16 @@ class Model(ModelDesc):
class MySimulatorMaster(SimulatorMaster, Callback):
# Constructor: forwards the two pipe names to SimulatorMaster, creates the
# bounded experience queue and remembers the inference GPUs.
def __init__(self, pipe_c2s, pipe_s2c, gpus):
"""
Args:
gpus (list[int]): the gpus used to run inference
"""
# pipe_c2s / pipe_s2c are passed through unchanged to SimulatorMaster.__init__.
super(MySimulatorMaster, self).__init__(pipe_c2s, pipe_s2c)
# Bounded queue with capacity BATCH_SIZE * 8 * 2; get_training_dataflow reads from it.
self.queue = queue.Queue(maxsize=BATCH_SIZE * 8 * 2)
# Kept for _setup_graph, which uses len(self._gpus).
self._gpus = gpus
def _setup_graph(self):
# create predictors on the available predictor GPUs.
# Create predictors on the available predictor GPUs.
num_gpu = len(self._gpus)
predictors = [self.trainer.get_predictor(
['state'], ['policy', 'pred_value'],
......@@ -155,6 +159,8 @@ class MySimulatorMaster(SimulatorMaster, Callback):
# Callback hook (by its name, run before training begins): starts the async
# predictor, logs a message, then hands this master to start_proc_mask_signal.
def _before_train(self):
self.async_predictor.start()
logger.info("Starting MySimulatorMaster ...")
# NOTE(review): start_proc_mask_signal is defined outside this view; presumably it starts `self` (a thread) with signals masked -- confirm.
start_proc_mask_signal(self)
def _on_state(self, state, client):
"""
......@@ -208,6 +214,10 @@ class MySimulatorMaster(SimulatorMaster, Callback):
else:
client.memory = []
# Returns the dataflow used for training: DataFromQueue wraps self.queue and
# BatchData groups its output using BATCH_SIZE.
def get_training_dataflow(self):
# the queue contains batched experience
return BatchData(DataFromQueue(self.queue), BATCH_SIZE)
def train():
assert tf.test.is_gpu_available(), "Training requires GPUs!"
......@@ -242,24 +252,19 @@ def train():
start_proc_mask_signal(procs)
master = MySimulatorMaster(namec2s, names2c, predict_tower)
dataflow = BatchData(DataFromQueue(master.queue), BATCH_SIZE)
config = TrainConfig(
model=Model(),
dataflow=dataflow,
dataflow=master.get_training_dataflow(),
callbacks=[
ModelSaver(),
ScheduledHyperParamSetter('learning_rate', [(20, 0.0003), (120, 0.0001)]),
ScheduledHyperParamSetter('entropy_beta', [(80, 0.005)]),
HumanHyperParamSetter('learning_rate'),
HumanHyperParamSetter('entropy_beta'),
master,
StartProcOrThread(master),
PeriodicTrigger(Evaluator(
EVAL_EPISODE, ['state'], ['policy'], get_player),
every_k_epochs=3),
],
session_creator=sesscreate.NewSessionCreator(
config=get_default_sess_config(0.5)),
session_creator=sesscreate.NewSessionCreator(config=get_default_sess_config(0.5)),
steps_per_epoch=STEPS_PER_EPOCH,
session_init=get_model_loader(args.load) if args.load else None,
max_epoch=1000,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment