Commit e665f053 authored by Shashank Suhas

Added reward logging

parent 4665d5f1
...@@ -151,6 +151,11 @@ class MySimulatorMaster(SimulatorMaster, Callback): ...@@ -151,6 +151,11 @@ class MySimulatorMaster(SimulatorMaster, Callback):
super(MySimulatorMaster, self).__init__(pipe_c2s, pipe_s2c) super(MySimulatorMaster, self).__init__(pipe_c2s, pipe_s2c)
self.queue = queue.Queue(maxsize=BATCH_SIZE * 8 * 2) self.queue = queue.Queue(maxsize=BATCH_SIZE * 8 * 2)
self._gpus = gpus self._gpus = gpus
self.reward = 0
self.fd = open('/kaggle/working', 'w')
def __del__(self):
    """Close the reward log file when the instance is finalized.

    The constructor stores the open handle under ``fd`` while the rest of
    the class refers to it as ``f``; both names are checked so the handle
    is closed whichever attribute actually holds it.  Missing attributes
    are tolerated because ``__del__`` also runs on instances whose
    ``__init__`` raised before the file was opened.
    """
    for attr_name in ("f", "fd"):
        handle = getattr(self, attr_name, None)
        # Skip absent handles and ones that were already closed.
        if handle is not None and not getattr(handle, "closed", False):
            handle.close()
def _setup_graph(self): def _setup_graph(self):
# Create predictors on the available predictor GPUs. # Create predictors on the available predictor GPUs.
...@@ -195,8 +200,11 @@ class MySimulatorMaster(SimulatorMaster, Callback): ...@@ -195,8 +200,11 @@ class MySimulatorMaster(SimulatorMaster, Callback):
client.memory[-1].reward = reward client.memory[-1].reward = reward
if isOver: if isOver:
# should clear client's memory and put to queue # should clear client's memory and put to queue
self.f.write(str(self.reward) + '\n')
self.reward = 0
self._parse_memory(0, client, True) self._parse_memory(0, client, True)
else: else:
self.reward += reward
if len(client.memory) == LOCAL_TIME_MAX + 1: if len(client.memory) == LOCAL_TIME_MAX + 1:
R = client.memory[-1].value R = client.memory[-1].value
self._parse_memory(R, client, False) self._parse_memory(R, client, False)
...@@ -269,7 +277,7 @@ def train(): ...@@ -269,7 +277,7 @@ def train():
session_creator=sesscreate.NewSessionCreator(config=get_default_sess_config(0.5)), session_creator=sesscreate.NewSessionCreator(config=get_default_sess_config(0.5)),
steps_per_epoch=STEPS_PER_EPOCH, steps_per_epoch=STEPS_PER_EPOCH,
session_init=SmartInit(args.load), session_init=SmartInit(args.load),
max_epoch=1000, max_epoch=15,
) )
trainer = SimpleTrainer() if num_gpu == 1 else AsyncMultiGPUTrainer(train_tower) trainer = SimpleTrainer() if num_gpu == 1 else AsyncMultiGPUTrainer(train_tower)
launch_train_with_config(config, trainer) launch_train_with_config(config, trainer)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment