Commit 58a41fca authored by Yuxin Wu

simulator improved

parent b2fd9b0d
...@@ -88,7 +88,7 @@ class Evaluator(Callback): ...@@ -88,7 +88,7 @@ class Evaluator(Callback):
self.input_names = input_names self.input_names = input_names
self.output_names = output_names self.output_names = output_names
def _before_train(self): def _setup_graph(self):
NR_PROC = min(multiprocessing.cpu_count() // 2, 8) NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
self.pred_funcs = [self.trainer.get_predict_func( self.pred_funcs = [self.trainer.get_predict_func(
self.input_names, self.output_names)] * NR_PROC self.input_names, self.output_names)] * NR_PROC
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# File: simulator.py # File: simulator.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
import multiprocessing as mp import multiprocessing as mp
import time import time
import threading import threading
...@@ -13,12 +14,17 @@ import numpy as np ...@@ -13,12 +14,17 @@ import numpy as np
import six import six
from six.moves import queue from six.moves import queue
from ..callbacks import Callback
from ..tfutils.varmanip import SessionUpdate
from ..predict import OfflinePredictor
from ..utils import logger
from ..utils.timer import * from ..utils.timer import *
from ..utils.serialize import * from ..utils.serialize import *
from ..utils.concurrency import * from ..utils.concurrency import *
__all__ = ['SimulatorProcess', 'SimulatorMaster', __all__ = ['SimulatorProcess', 'SimulatorMaster',
'StateExchangeSimulatorProcess', 'SimulatorProcessSharedWeight'] 'SimulatorProcessStateExchange', 'SimulatorProcessSharedWeight',
'TransitionExperience', 'WeightSync']
try: try:
import zmq import zmq
...@@ -26,6 +32,16 @@ except ImportError: ...@@ -26,6 +32,16 @@ except ImportError:
logger.warn("Error in 'import zmq'. RL simulator won't be available.") logger.warn("Error in 'import zmq'. RL simulator won't be available.")
__all__ = [] __all__ = []
class TransitionExperience(object):
    """A transition of state, or experience.

    Holds ``state``, ``action`` and ``reward`` as plain attributes, plus any
    extra attributes supplied as keyword arguments.
    """

    def __init__(self, state, action, reward, **kwargs):
        """
        :param state: stored as ``self.state``
        :param action: stored as ``self.action``
        :param reward: stored as ``self.reward``
        :param kwargs: whatever other attributes you want to save; each
            key becomes an attribute of the same name on the instance
        """
        self.state, self.action, self.reward = state, action, reward
        # Extra attributes are set after the three fixed ones.
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])
class SimulatorProcessBase(mp.Process): class SimulatorProcessBase(mp.Process):
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
...@@ -39,7 +55,7 @@ class SimulatorProcessBase(mp.Process): ...@@ -39,7 +55,7 @@ class SimulatorProcessBase(mp.Process):
pass pass
class StateExchangeSimulatorProcess(SimulatorProcessBase): class SimulatorProcessStateExchange(SimulatorProcessBase):
""" """
A process that simulates a player and communicates to master to A process that simulates a player and communicates to master to
send states and receive the next action send states and receive the next action
...@@ -50,7 +66,7 @@ class StateExchangeSimulatorProcess(SimulatorProcessBase): ...@@ -50,7 +66,7 @@ class StateExchangeSimulatorProcess(SimulatorProcessBase):
""" """
:param idx: idx of this process :param idx: idx of this process
""" """
super(StateExchangeSimulatorProcess, self).__init__(idx) super(SimulatorProcessStateExchange, self).__init__(idx)
self.c2s = pipe_c2s self.c2s = pipe_c2s
self.s2c = pipe_s2c self.s2c = pipe_s2c
...@@ -78,7 +94,7 @@ class StateExchangeSimulatorProcess(SimulatorProcessBase): ...@@ -78,7 +94,7 @@ class StateExchangeSimulatorProcess(SimulatorProcessBase):
state = player.current_state() state = player.current_state()
# compatibility # compatibility
SimulatorProcess = StateExchangeSimulatorProcess SimulatorProcess = SimulatorProcessStateExchange
class SimulatorMaster(threading.Thread): class SimulatorMaster(threading.Thread):
""" A base thread to communicate with all StateExchangeSimulatorProcess. """ A base thread to communicate with all StateExchangeSimulatorProcess.
...@@ -91,16 +107,6 @@ class SimulatorMaster(threading.Thread): ...@@ -91,16 +107,6 @@ class SimulatorMaster(threading.Thread):
def __init__(self): def __init__(self):
self.memory = [] # list of Experience self.memory = [] # list of Experience
class Experience(object):
""" A transition of state, or experience"""
def __init__(self, state, action, reward, **kwargs):
""" kwargs: whatever other attribute you want to save"""
self.state = state
self.action = action
self.reward = reward
for k, v in six.iteritems(kwargs):
setattr(self, k, v)
def __init__(self, pipe_c2s, pipe_s2c): def __init__(self, pipe_c2s, pipe_s2c):
super(SimulatorMaster, self).__init__() super(SimulatorMaster, self).__init__()
self.daemon = True self.daemon = True
...@@ -170,8 +176,7 @@ class SimulatorMaster(threading.Thread): ...@@ -170,8 +176,7 @@ class SimulatorMaster(threading.Thread):
""" """
def __del__(self): def __del__(self):
self.socket.close() self.context.destroy(linger=0)
self.context.term()
class SimulatorProcessDF(SimulatorProcessBase): class SimulatorProcessDF(SimulatorProcessBase):
...@@ -191,18 +196,15 @@ class SimulatorProcessDF(SimulatorProcessBase): ...@@ -191,18 +196,15 @@ class SimulatorProcessDF(SimulatorProcessBase):
self.c2s_socket.connect(self.pipe_c2s) self.c2s_socket.connect(self.pipe_c2s)
self._prepare() self._prepare()
while True: for dp in self.get_data():
dp = self._produce_datapoint() self.c2s_socket.send(dumps(dp), copy=False)
self.c2s_socket.send(dumps(
(self.identity, dp)
), copy=False)
@abstractmethod @abstractmethod
def _prepare(self): def _prepare(self):
pass pass
@abstractmethod @abstractmethod
def _produce_datapoint(self): def get_data(self):
pass pass
...@@ -212,31 +214,61 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF): ...@@ -212,31 +214,61 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
Start me under some CUDA_VISIBLE_DEVICES set! Start me under some CUDA_VISIBLE_DEVICES set!
""" """
def __init__(self, idx, pipe_c2s, evt, shared_dic): def __init__(self, idx, pipe_c2s, condvar, shared_dic, pred_config):
super(SimulatorProcessSharedWeight, self).__init__(idx, pipe_c2s) super(SimulatorProcessSharedWeight, self).__init__(idx, pipe_c2s)
self.evt = evt self.condvar = condvar
self.shared_dic = shared_dic self.shared_dic = shared_dic
self.pred_config = pred_config
def _prepare(self): def _prepare(self):
self._build_session() self.predictor = OfflinePredictor(self.pred_config)
with self.predictor.graph.as_default():
vars_to_update = self._params_to_update()
self.sess_updater = SessionUpdate(
self.predictor.session, vars_to_update)
# TODO setup callback for explore?
self.predictor.graph.finalize()
# start a thread to wait for evt self.weight_lock = threading.Lock()
# start a thread to wait for notification
def func(): def func():
self.evt.wait() self.condvar.acquire()
self._trigger_evt() while True:
self.evt_th = LoopThread(func, pausable=False) self.condvar.wait()
self._trigger_evt()
self.evt_th = threading.Thread(target=func)
self.evt_th.daemon = True
self.evt_th.start() self.evt_th.start()
@abstractmethod
def _trigger_evt(self): def _trigger_evt(self):
pass with self.weight_lock:
#self.sess_updater.update(self.shared_dic['params']) self.sess_updater.update(self.shared_dic['params'])
@abstractmethod def _params_to_update(self):
def _build_session(self): # can be overwritten to update more params
# build session and self.sess_updaer return tf.trainable_variables()
pass
class WeightSync(Callback):
    """ Sync weight from main process to shared_dic and notify"""

    def __init__(self, condvar, shared_dic):
        """
        :param condvar: condition variable that is notified (``notify_all``)
            after new weights have been stored. It must support the context
            manager protocol.
        :param shared_dic: dict-like object; the weights are stored under
            the key ``'params'``.
            NOTE(review): presumably a dict shared across processes
            (e.g. a manager dict) -- confirm against the caller.
        """
        self.condvar = condvar
        self.shared_dic = shared_dic

    def _setup_graph(self):
        # Resolve the variable list once, when the graph is set up.
        self.vars = self._params_to_update()

    def _params_to_update(self):
        # can be overwritten to update more params
        return tf.trainable_variables()

    def _trigger_epoch(self):
        logger.info("Updating weights ...")
        # NOTE(review): v.eval() relies on a default session being active
        # when this callback is triggered -- confirm with the trainer.
        dic = {v.name: v.eval() for v in self.vars}
        # Store the new weights before notifying, so that anything woken
        # up by the notification finds them in shared_dic.
        self.shared_dic['params'] = dic
        # Use the condition as a context manager so its lock is released
        # even if notify_all() raises.
        with self.condvar:
            self.condvar.notify_all()
if __name__ == '__main__': if __name__ == '__main__':
import random import random
......
...@@ -6,8 +6,15 @@ ...@@ -6,8 +6,15 @@
import numpy as np import numpy as np
from six.moves import range from six.moves import range
from .base import DataFlow, RNGDataFlow from .base import DataFlow, RNGDataFlow
from ..utils.serialize import loads
__all__ = ['FakeData', 'DataFromQueue', 'DataFromList'] __all__ = ['FakeData', 'DataFromQueue', 'DataFromList']
# zmq is an optional dependency: DataFromSocket is only exported when the
# import succeeds.
try:
    import zmq
except ImportError:
    # Catch only a failed import; a bare ``except`` would also swallow
    # KeyboardInterrupt / SystemExit raised while importing.
    pass
else:
    __all__.append('DataFromSocket')
class FakeData(RNGDataFlow): class FakeData(RNGDataFlow):
""" Generate fake fixed data of given shapes""" """ Generate fake fixed data of given shapes"""
...@@ -43,7 +50,6 @@ class DataFromQueue(DataFlow): ...@@ -43,7 +50,6 @@ class DataFromQueue(DataFlow):
while True: while True:
yield self.queue.get() yield self.queue.get()
class DataFromList(RNGDataFlow): class DataFromList(RNGDataFlow):
""" Produce data from a list""" """ Produce data from a list"""
def __init__(self, lst, shuffle=True): def __init__(self, lst, shuffle=True):
...@@ -63,3 +69,20 @@ class DataFromList(RNGDataFlow): ...@@ -63,3 +69,20 @@ class DataFromList(RNGDataFlow):
for k in idxs: for k in idxs:
yield self.lst[k] yield self.lst[k]
class DataFromSocket(DataFlow):
    """ Produce data from a zmq socket"""

    def __init__(self, socket_name):
        """
        :param socket_name: the address passed to ``socket.bind`` when
            :meth:`get_data` starts iterating.
        """
        self._name = socket_name

    def get_data(self):
        """
        Bind a zmq PULL socket to the configured address and yield every
        received message after deserializing it with ``loads``.

        The loop never ends on its own. The zmq context is destroyed when
        the generator is closed or an error propagates out of it.
        """
        # Create the context before entering ``try``: if construction fails
        # there is nothing to clean up, and the ``finally`` clause must not
        # touch an unbound ``ctx`` (which would mask the real error).
        ctx = zmq.Context()
        try:
            socket = ctx.socket(zmq.PULL)
            socket.bind(self._name)
            while True:
                dp = loads(socket.recv(copy=False))
                yield dp
        finally:
            ctx.destroy(linger=0)
...@@ -61,9 +61,6 @@ class SessionUpdate(object): ...@@ -61,9 +61,6 @@ class SessionUpdate(object):
for name, value in six.iteritems(prms): for name, value in six.iteritems(prms):
assert name in self.assign_ops assert name in self.assign_ops
for p, v, op in self.assign_ops[name]: for p, v, op in self.assign_ops[name]:
if 'fc0/W' in name:
import IPython as IP;
IP.embed(config=IP.terminal.ipapp.load_default_config())
varshape = tuple(v.get_shape().as_list()) varshape = tuple(v.get_shape().as_list())
if varshape != value.shape: if varshape != value.shape:
# TODO only allow reshape when shape different by empty axis # TODO only allow reshape when shape different by empty axis
...@@ -71,13 +68,7 @@ class SessionUpdate(object): ...@@ -71,13 +68,7 @@ class SessionUpdate(object):
"{}: {}!={}".format(name, varshape, value.shape) "{}: {}!={}".format(name, varshape, value.shape)
logger.warn("Param {} is reshaped during assigning".format(name)) logger.warn("Param {} is reshaped during assigning".format(name))
value = value.reshape(varshape) value = value.reshape(varshape)
if 'fc0/W' in name:
import IPython as IP;
IP.embed(config=IP.terminal.ipapp.load_default_config())
self.sess.run(op, feed_dict={p: value}) self.sess.run(op, feed_dict={p: value})
if 'fc0/W' in name:
import IPython as IP;
IP.embed(config=IP.terminal.ipapp.load_default_config())
def dump_session_params(path): def dump_session_params(path):
""" Dump value of all trainable + to_save variables to a dict and save to `path` as """ Dump value of all trainable + to_save variables to a dict and save to `path` as
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment