Commit c5da59af authored by Yuxin Wu's avatar Yuxin Wu

minor fix for async

parent cc844ed4
...@@ -9,7 +9,7 @@ import os ...@@ -9,7 +9,7 @@ import os
import cv2 import cv2
from collections import deque from collections import deque
from six.moves import range from six.moves import range
from ..utils import get_rng, logger from ..utils import get_rng, logger, memoized
from ..utils.stat import StatCounter from ..utils.stat import StatCounter
from .envbase import RLEnvironment from .envbase import RLEnvironment
...@@ -21,6 +21,10 @@ except ImportError: ...@@ -21,6 +21,10 @@ except ImportError:
__all__ = ['AtariPlayer'] __all__ = ['AtariPlayer']
@memoized
def log_once():
    """Warn (at most once per process, enforced by @memoized caching the
    no-arg call) that the installed ALE build lacks ``setLoggerMode``
    because the referenced upstream PR is not merged."""
    logger.warn("https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!")
class AtariPlayer(RLEnvironment): class AtariPlayer(RLEnvironment):
""" """
A wrapper for atari emulator. A wrapper for atari emulator.
...@@ -43,10 +47,12 @@ class AtariPlayer(RLEnvironment): ...@@ -43,10 +47,12 @@ class AtariPlayer(RLEnvironment):
self.ale.setInt("random_seed", self.rng.randint(0, 10000)) self.ale.setInt("random_seed", self.rng.randint(0, 10000))
self.ale.setBool("showinfo", False) self.ale.setBool("showinfo", False)
try: try:
ALEInterface.setLoggerMode(ALEInterface.Logger.Warning) ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
except AttributeError: except AttributeError:
logger.warn("https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!") log_once()
self.ale.setInt("frame_skip", 1) self.ale.setInt("frame_skip", 1)
self.ale.setBool('color_averaging', False) self.ale.setBool('color_averaging', False)
# manual.pdf suggests otherwise. may need to check # manual.pdf suggests otherwise. may need to check
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import multiprocessing import multiprocessing
import time
import threading import threading
import weakref import weakref
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
...@@ -68,10 +69,12 @@ class SimulatorMaster(threading.Thread): ...@@ -68,10 +69,12 @@ class SimulatorMaster(threading.Thread):
class Experience(object): class Experience(object):
""" A transition of state, or experience""" """ A transition of state, or experience"""
def __init__(self, state, action, reward): def __init__(self, state, action, reward, misc=None):
""" misc: whatever other attribute you want to save"""
self.state = state self.state = state
self.action = action self.action = action
self.reward = reward self.reward = reward
self.misc = misc
def __init__(self, server_name): def __init__(self, server_name):
super(SimulatorMaster, self).__init__() super(SimulatorMaster, self).__init__()
...@@ -91,7 +94,14 @@ class SimulatorMaster(threading.Thread): ...@@ -91,7 +94,14 @@ class SimulatorMaster(threading.Thread):
def run(self): def run(self):
self.clients = defaultdict(SimulatorMaster.ClientState) self.clients = defaultdict(SimulatorMaster.ClientState)
while True: while True:
ident, _, msg = self.socket.recv_multipart() while True:
# avoid the lock being acquired here forever
try:
with self.socket_lock:
ident, _, msg = self.socket.recv_multipart(zmq.NOBLOCK)
break
except zmq.ZMQError:
time.sleep(0.01)
#assert _ == "" #assert _ == ""
client = self.clients[ident] client = self.clients[ident]
client.protocol_state = 1 - client.protocol_state # first flip the state client.protocol_state = 1 - client.protocol_state # first flip the state
......
...@@ -27,16 +27,23 @@ class GradientProcessor(object): ...@@ -27,16 +27,23 @@ class GradientProcessor(object):
def _process(self, grads): def _process(self, grads):
pass pass
_summaried_gradient = set()
class SummaryGradient(GradientProcessor): class SummaryGradient(GradientProcessor):
""" """
Summary history and RMS for each gradient variable Summary history and RMS for each gradient variable
""" """
def _process(self, grads): def _process(self, grads):
for grad, var in grads: for grad, var in grads:
tf.histogram_summary(var.op.name + '/grad', grad) name = var.op.name
if name in _summaried_gradient:
continue
_summaried_gradient.add(name)
tf.histogram_summary(name + '/grad', grad)
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, tf.add_to_collection(MOVING_SUMMARY_VARS_KEY,
tf.sqrt(tf.reduce_mean(tf.square(grad)), tf.sqrt(tf.reduce_mean(tf.square(grad)),
name=var.op.name + '/gradRMS')) name=name + '/gradRMS'))
return grads return grads
...@@ -46,7 +53,6 @@ class CheckGradient(GradientProcessor): ...@@ -46,7 +53,6 @@ class CheckGradient(GradientProcessor):
""" """
def _process(self, grads): def _process(self, grads):
for grad, var in grads: for grad, var in grads:
assert grad is not None, "Grad is None for variable {}".format(var.name)
# TODO make assert work # TODO make assert work
tf.Assert(tf.reduce_all(tf.is_finite(var)), [var]) tf.Assert(tf.reduce_all(tf.is_finite(var)), [var])
return grads return grads
......
...@@ -191,13 +191,13 @@ class QueueInputTrainer(Trainer): ...@@ -191,13 +191,13 @@ class QueueInputTrainer(Trainer):
grads = QueueInputTrainer._average_grads(grad_list) grads = QueueInputTrainer._average_grads(grad_list)
grads = self.process_grads(grads) grads = self.process_grads(grads)
else: else:
grad_list = [self.process_grads(g) for g in grad_list]
# pretend to average the grads, in order to make async and # pretend to average the grads, in order to make async and
# sync have consistent effective learning rate # sync have consistent effective learning rate
def scale(grads): def scale(grads):
return [(grad / self.config.nr_tower, var) for grad, var in grads] return [(grad / self.config.nr_tower, var) for grad, var in grads]
grad_list = map(scale, grad_list) grad_list = map(scale, grad_list)
grads = grad_list[0] # use grad from the first tower for routinely stuff grad_list = [self.process_grads(g) for g in grad_list]
grads = grad_list[0] # use grad from the first tower for the main iteration
else: else:
grads = self._single_tower_grad() grads = self._single_tower_grad()
grads = self.process_grads(grads) grads = self.process_grads(grads)
...@@ -207,6 +207,7 @@ class QueueInputTrainer(Trainer): ...@@ -207,6 +207,7 @@ class QueueInputTrainer(Trainer):
summary_moving_average()) summary_moving_average())
if self.async: if self.async:
# prepare train_op for the rest of the towers
self.threads = [] self.threads = []
for k in range(1, self.config.nr_tower): for k in range(1, self.config.nr_tower):
train_op = self.config.optimizer.apply_gradients(grad_list[k]) train_op = self.config.optimizer.apply_gradients(grad_list[k])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment