Commit c5da59af authored by Yuxin Wu

minor fix for async

parent cc844ed4
......@@ -9,7 +9,7 @@ import os
import cv2
from collections import deque
from six.moves import range
from ..utils import get_rng, logger
from ..utils import get_rng, logger, memoized
from ..utils.stat import StatCounter
from .envbase import RLEnvironment
......@@ -21,6 +21,10 @@ except ImportError:
__all__ = ['AtariPlayer']
@memoized
def log_once():
logger.warn("https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!")
class AtariPlayer(RLEnvironment):
"""
A wrapper for atari emulator.
......@@ -43,10 +47,12 @@ class AtariPlayer(RLEnvironment):
self.ale.setInt("random_seed", self.rng.randint(0, 10000))
self.ale.setBool("showinfo", False)
try:
ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
except AttributeError:
logger.warn("https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!")
log_once()
self.ale.setInt("frame_skip", 1)
self.ale.setBool('color_averaging', False)
# manual.pdf suggests otherwise. may need to check
......
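The @memoized wrapper above turns log_once() into a warning that fires only the first time the except AttributeError branch is hit, instead of once per AtariPlayer instance. A minimal standalone sketch of the idea; memoized_once below is a hypothetical stand-in, not the actual memoized utility imported from ..utils:

import functools

def memoized_once(func):
    """ Hypothetical stand-in for a memoizing decorator: cache the result per argument tuple. """
    cache = {}
    @functools.wraps(func)
    def wrapper(*args):
        if args not in cache:
            cache[args] = func(*args)
        return cache[args]
    return wrapper

@memoized_once
def log_once():
    # the body runs only on the first call; later calls return the cached (None) result
    print("ALE warning, printed exactly once")

log_once()
log_once()   # cached, prints nothing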
......@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import multiprocessing
import time
import threading
import weakref
from abc import abstractmethod, ABCMeta
......@@ -68,10 +69,12 @@ class SimulatorMaster(threading.Thread):
class Experience(object):
""" A transition of state, or experience"""
def __init__(self, state, action, reward):
def __init__(self, state, action, reward, misc=None):
""" misc: whatever other attribute you want to save"""
self.state = state
self.action = action
self.reward = reward
self.misc = misc
def __init__(self, server_name):
super(SimulatorMaster, self).__init__()
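The new misc argument lets a simulator attach arbitrary extra data to each transition, for example the policy distribution that produced the action. A hypothetical usage sketch (the Experience class is repeated locally so the snippet runs on its own; the dictionary key and shapes are made up):

import numpy as np

class Experience(object):
    """ Local copy of the class above, only so this snippet is self-contained. """
    def __init__(self, state, action, reward, misc=None):
        self.state = state
        self.action = action
        self.reward = reward
        self.misc = misc

frame = np.zeros((84, 84), dtype='uint8')    # hypothetical observation
probs = np.array([0.1, 0.2, 0.7])            # hypothetical policy output
exp = Experience(frame, action=2, reward=1.0, misc={'policy': probs})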
......@@ -91,7 +94,14 @@ class SimulatorMaster(threading.Thread):
def run(self):
self.clients = defaultdict(SimulatorMaster.ClientState)
while True:
ident, _, msg = self.socket.recv_multipart()
while True:
# avoid holding the lock forever while waiting for a message
try:
with self.socket_lock:
ident, _, msg = self.socket.recv_multipart(zmq.NOBLOCK)
break
except zmq.ZMQError:
time.sleep(0.01)
#assert _ == ""
client = self.clients[ident]
client.protocol_state = 1 - client.protocol_state # first flip the state
......
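The retry loop with zmq.NOBLOCK keeps socket_lock from being held inside a blocking recv_multipart, which would otherwise starve any other thread that sends on the same socket. A standalone sketch of the pattern, with a made-up endpoint and without the surrounding SimulatorMaster class:

import threading
import time
import zmq

context = zmq.Context()
socket = context.socket(zmq.ROUTER)
socket.bind("tcp://127.0.0.1:5555")     # hypothetical endpoint
socket_lock = threading.Lock()

def receive_one():
    """ Poll with a non-blocking recv so the lock is released while waiting. """
    while True:
        try:
            with socket_lock:
                # raises zmq.ZMQError (zmq.Again) when nothing is queued
                return socket.recv_multipart(zmq.NOBLOCK)
        except zmq.ZMQError:
            time.sleep(0.01)             # lock already released; retry shortly

A zmq.Poller with a timeout could replace the fixed 10 ms sleep, at the cost of a slightly larger change.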
......@@ -27,16 +27,23 @@ class GradientProcessor(object):
def _process(self, grads):
pass
_summaried_gradient = set()
class SummaryGradient(GradientProcessor):
"""
Summarize the histogram and RMS of each gradient variable
"""
def _process(self, grads):
for grad, var in grads:
tf.histogram_summary(var.op.name + '/grad', grad)
name = var.op.name
if name in _summaried_gradient:
continue
_summaried_gradient.add(name)
tf.histogram_summary(name + '/grad', grad)
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY,
tf.sqrt(tf.reduce_mean(tf.square(grad)),
name=var.op.name + '/gradRMS'))
name=name + '/gradRMS'))
return grads
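The module-level _summaried_gradient set makes _process idempotent per variable name, which matters once the trainer (further down in this commit) calls process_grads separately for every tower: without the guard, the same variable would get one histogram summary per tower. The guard pattern in isolation, with made-up names:

_already_summarized = set()

def add_summary_once(name, add_fn):
    """ Run add_fn(name) at most once per distinct variable name. """
    if name in _already_summarized:
        return
    _already_summarized.add(name)
    add_fn(name)

def fake_histogram_summary(name):
    print("histogram summary registered for " + name)

add_summary_once('conv1/W', fake_histogram_summary)
add_summary_once('conv1/W', fake_histogram_summary)   # second call is a no-op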
......@@ -46,7 +53,6 @@ class CheckGradient(GradientProcessor):
"""
def _process(self, grads):
for grad, var in grads:
assert grad is not None, "Grad is None for variable {}".format(var.name)
# TODO make assert work
tf.Assert(tf.reduce_all(tf.is_finite(var)), [var])
return grads
......
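The "TODO make assert work" remains because tf.Assert only fires if the op it returns is actually executed, and nothing in CheckGradient runs it. One way to force execution (a sketch under that assumption, not what this commit does) is to route each gradient through a control dependency on its check:

import tensorflow as tf

def check_grads(grads):
    """ Hypothetical variant of CheckGradient._process that ties each assert to its gradient. """
    checked = []
    for grad, var in grads:
        assert grad is not None, "Grad is None for variable {}".format(var.name)
        check_op = tf.Assert(tf.reduce_all(tf.is_finite(grad)), [grad])
        with tf.control_dependencies([check_op]):
            grad = tf.identity(grad)   # the assert now runs whenever this gradient is consumed
        checked.append((grad, var))
    return checked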
......@@ -191,13 +191,13 @@ class QueueInputTrainer(Trainer):
grads = QueueInputTrainer._average_grads(grad_list)
grads = self.process_grads(grads)
else:
grad_list = [self.process_grads(g) for g in grad_list]
# pretend to average the grads, in order to make async and
# sync have consistent effective learning rate
def scale(grads):
return [(grad / self.config.nr_tower, var) for grad, var in grads]
grad_list = map(scale, grad_list)
grads = grad_list[0] # use grad from the first tower for routinely stuff
grad_list = [self.process_grads(g) for g in grad_list]
grads = grad_list[0] # use grad from the first tower for the main iteration
else:
grads = self._single_tower_grad()
grads = self.process_grads(grads)
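Scaling every tower's gradients by 1/nr_tower makes one round of asynchronous updates add up to the same step the synchronous, averaged path would take (ignoring the staleness introduced by interleaved updates). A toy numeric check in plain Python with made-up numbers:

nr_tower = 4
lr = 0.1
tower_grads = [1.0, 2.0, 3.0, 4.0]   # hypothetical gradients of a single scalar variable

sync_step = lr * (sum(tower_grads) / nr_tower)                  # average once, apply once
async_step = sum(lr * (g / nr_tower) for g in tower_grads)      # scale each, apply one by one
assert abs(sync_step - async_step) < 1e-12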
......@@ -207,6 +207,7 @@ class QueueInputTrainer(Trainer):
summary_moving_average())
if self.async:
# prepare train_op for the rest of the towers
self.threads = []
for k in range(1, self.config.nr_tower):
train_op = self.config.optimizer.apply_gradients(grad_list[k])
......
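self.threads is initialized here presumably so that each extra tower's train_op can later be driven by its own thread; the rest of that change is outside this hunk. A generic sketch of running one tower's op in a background thread (a made-up helper, not tensorpack's actual mechanism):

import threading

def start_tower_thread(sess, train_op):
    """ Hypothetical helper: keep running one tower's train_op until stop_event is set. """
    stop_event = threading.Event()
    def loop():
        while not stop_event.is_set():
            sess.run(train_op)
    th = threading.Thread(target=loop)
    th.daemon = True
    th.start()
    return th, stop_event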