Commit 58a41fca authored by Yuxin Wu

simulator improved

parent b2fd9b0d
......@@ -88,7 +88,7 @@ class Evaluator(Callback):
self.input_names = input_names
self.output_names = output_names
def _before_train(self):
def _setup_graph(self):
NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
self.pred_funcs = [self.trainer.get_predict_func(
self.input_names, self.output_names)] * NR_PROC
......
......@@ -3,6 +3,7 @@
# File: simulator.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
import multiprocessing as mp
import time
import threading
......@@ -13,12 +14,17 @@ import numpy as np
import six
from six.moves import queue
from ..callbacks import Callback
from ..tfutils.varmanip import SessionUpdate
from ..predict import OfflinePredictor
from ..utils import logger
from ..utils.timer import *
from ..utils.serialize import *
from ..utils.concurrency import *
__all__ = ['SimulatorProcess', 'SimulatorMaster',
'StateExchangeSimulatorProcess', 'SimulatorProcessSharedWeight']
'SimulatorProcessStateExchange', 'SimulatorProcessSharedWeight',
'TransitionExperience', 'WeightSync']
try:
import zmq
......@@ -26,6 +32,16 @@ except ImportError:
logger.warn("Error in 'import zmq'. RL simulator won't be available.")
__all__ = []
class TransitionExperience(object):
    """A single transition (experience) recorded by a simulator.

    Holds the state, the action and the reward of the transition, plus
    any extra keyword arguments as additional instance attributes.
    """

    def __init__(self, state, action, reward, **kwargs):
        """
        :param state: stored as ``self.state``
        :param action: stored as ``self.action``
        :param reward: stored as ``self.reward``
        :param kwargs: whatever other attributes you want to save; each
            key becomes an attribute of the same name
        """
        self.state, self.action, self.reward = state, action, reward
        # Attach every extra keyword as a plain attribute.
        for name in kwargs:
            setattr(self, name, kwargs[name])
class SimulatorProcessBase(mp.Process):
__metaclass__ = ABCMeta
......@@ -39,7 +55,7 @@ class SimulatorProcessBase(mp.Process):
pass
class StateExchangeSimulatorProcess(SimulatorProcessBase):
class SimulatorProcessStateExchange(SimulatorProcessBase):
"""
A process that simulates a player and communicates to master to
send states and receive the next action
......@@ -50,7 +66,7 @@ class StateExchangeSimulatorProcess(SimulatorProcessBase):
"""
:param idx: idx of this process
"""
super(StateExchangeSimulatorProcess, self).__init__(idx)
super(SimulatorProcessStateExchange, self).__init__(idx)
self.c2s = pipe_c2s
self.s2c = pipe_s2c
......@@ -78,7 +94,7 @@ class StateExchangeSimulatorProcess(SimulatorProcessBase):
state = player.current_state()
# compatibility
SimulatorProcess = StateExchangeSimulatorProcess
SimulatorProcess = SimulatorProcessStateExchange
class SimulatorMaster(threading.Thread):
""" A base thread to communicate with all StateExchangeSimulatorProcess.
......@@ -91,16 +107,6 @@ class SimulatorMaster(threading.Thread):
def __init__(self):
self.memory = [] # list of Experience
class Experience(object):
    """A transition of state, or experience.

    Keeps the state/action/reward triple and exposes any additional
    keyword arguments as attributes on the instance.
    """

    def __init__(self, state, action, reward, **kwargs):
        """
        :param state: stored as ``self.state``
        :param action: stored as ``self.action``
        :param reward: stored as ``self.reward``
        :param kwargs: whatever other attribute you want to save
        """
        self.state = state
        self.action = action
        self.reward = reward
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)
def __init__(self, pipe_c2s, pipe_s2c):
super(SimulatorMaster, self).__init__()
self.daemon = True
......@@ -170,8 +176,7 @@ class SimulatorMaster(threading.Thread):
"""
def __del__(self):
self.socket.close()
self.context.term()
self.context.destroy(linger=0)
class SimulatorProcessDF(SimulatorProcessBase):
......@@ -191,18 +196,15 @@ class SimulatorProcessDF(SimulatorProcessBase):
self.c2s_socket.connect(self.pipe_c2s)
self._prepare()
while True:
dp = self._produce_datapoint()
self.c2s_socket.send(dumps(
(self.identity, dp)
), copy=False)
for dp in self.get_data():
self.c2s_socket.send(dumps(dp), copy=False)
@abstractmethod
def _prepare(self):
pass
@abstractmethod
def _produce_datapoint(self):
def get_data(self):
pass
......@@ -212,31 +214,61 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
Start me under some CUDA_VISIBLE_DEVICES set!
"""
def __init__(self, idx, pipe_c2s, evt, shared_dic):
def __init__(self, idx, pipe_c2s, condvar, shared_dic, pred_config):
super(SimulatorProcessSharedWeight, self).__init__(idx, pipe_c2s)
self.evt = evt
self.condvar = condvar
self.shared_dic = shared_dic
self.pred_config = pred_config
def _prepare(self):
self._build_session()
self.predictor = OfflinePredictor(self.pred_config)
with self.predictor.graph.as_default():
vars_to_update = self._params_to_update()
self.sess_updater = SessionUpdate(
self.predictor.session, vars_to_update)
# TODO setup callback for explore?
self.predictor.graph.finalize()
# start a thread to wait for evt
self.weight_lock = threading.Lock()
# start a thread to wait for notification
def func():
self.evt.wait()
self.condvar.acquire()
while True:
self.condvar.wait()
self._trigger_evt()
self.evt_th = LoopThread(func, pausable=False)
self.evt_th = threading.Thread(target=func)
self.evt_th.daemon = True
self.evt_th.start()
@abstractmethod
def _trigger_evt(self):
pass
#self.sess_updater.update(self.shared_dic['params'])
with self.weight_lock:
self.sess_updater.update(self.shared_dic['params'])
@abstractmethod
def _build_session(self):
# build session and self.sess_updaer
pass
def _params_to_update(self):
# can be overwritten to update more params
return tf.trainable_variables()
class WeightSync(Callback):
    """Sync weight from main process to shared_dic and notify.

    After every epoch the current values of the tracked variables are
    stored in ``shared_dic['params']`` and every waiter on ``condvar``
    is woken up.
    """

    def __init__(self, condvar, shared_dic):
        """
        :param condvar: condition variable; must provide ``acquire``,
            ``notify_all`` and ``release``
        :param shared_dic: dict-like object that receives the weights
            under the key ``'params'``
        """
        self.condvar = condvar
        self.shared_dic = shared_dic

    def _setup_graph(self):
        # Resolve the variable list once, when the graph is set up.
        self.vars = self._params_to_update()

    def _params_to_update(self):
        # can be overwritten to update more params
        return tf.trainable_variables()

    def _trigger_epoch(self):
        logger.info("Updating weights ...")
        dic = {v.name: v.eval() for v in self.vars}
        # Publish the new weights before waking anyone up, so a notified
        # waiter always finds the fresh values.
        self.shared_dic['params'] = dic
        self.condvar.acquire()
        try:
            self.condvar.notify_all()
        finally:
            # Always release: a lock left held after a failed notify
            # would block every waiter forever.
            self.condvar.release()
if __name__ == '__main__':
import random
......
......@@ -6,8 +6,15 @@
import numpy as np
from six.moves import range
from .base import DataFlow, RNGDataFlow
from ..utils.serialize import loads
__all__ = ['FakeData', 'DataFromQueue', 'DataFromList']
try:
    import zmq
except ImportError:
    # zmq is an optional dependency: without it DataFromSocket is simply
    # not exported. Only ImportError is swallowed, so unrelated failures
    # (and KeyboardInterrupt / SystemExit) still propagate.
    pass
else:
    __all__.append('DataFromSocket')
class FakeData(RNGDataFlow):
""" Generate fake fixed data of given shapes"""
......@@ -43,7 +50,6 @@ class DataFromQueue(DataFlow):
while True:
yield self.queue.get()
class DataFromList(RNGDataFlow):
""" Produce data from a list"""
def __init__(self, lst, shuffle=True):
......@@ -63,3 +69,20 @@ class DataFromList(RNGDataFlow):
for k in idxs:
yield self.lst[k]
class DataFromSocket(DataFlow):
    """ Produce data from a zmq socket"""

    def __init__(self, socket_name):
        """
        :param socket_name: address the PULL socket is bound to
        """
        self._name = socket_name

    def get_data(self):
        """Yield datapoints deserialized from messages received on the socket.

        The generator loops forever; the zmq context is destroyed when
        the generator is closed or an error is raised.
        """
        # Create the context outside ``try``: if construction fails there
        # is nothing to clean up, and ``finally`` must not touch an
        # unbound ``ctx`` (which would mask the real error with NameError).
        ctx = zmq.Context()
        try:
            socket = ctx.socket(zmq.PULL)
            socket.bind(self._name)
            while True:
                dp = loads(socket.recv(copy=False))
                yield dp
        finally:
            # linger=0: drop pending messages instead of blocking on exit.
            ctx.destroy(linger=0)
......@@ -61,9 +61,6 @@ class SessionUpdate(object):
for name, value in six.iteritems(prms):
assert name in self.assign_ops
for p, v, op in self.assign_ops[name]:
if 'fc0/W' in name:
import IPython as IP;
IP.embed(config=IP.terminal.ipapp.load_default_config())
varshape = tuple(v.get_shape().as_list())
if varshape != value.shape:
# TODO only allow reshape when shape different by empty axis
......@@ -71,13 +68,7 @@ class SessionUpdate(object):
"{}: {}!={}".format(name, varshape, value.shape)
logger.warn("Param {} is reshaped during assigning".format(name))
value = value.reshape(varshape)
if 'fc0/W' in name:
import IPython as IP;
IP.embed(config=IP.terminal.ipapp.load_default_config())
self.sess.run(op, feed_dict={p: value})
if 'fc0/W' in name:
import IPython as IP;
IP.embed(config=IP.terminal.ipapp.load_default_config())
def dump_session_params(path):
""" Dump value of all trainable + to_save variables to a dict and save to `path` as
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment