Commit 6636791f authored by Yuxin Wu's avatar Yuxin Wu

start building simulator frameworks

parent fc2e6240
......@@ -13,7 +13,7 @@ See some interesting [examples](examples) to learn about the framework:
## Features:
You need to abstract your training task into three components:
Abstract your training task into three components:
1. Model, or graph. `models/` has some scoped abstraction of common models.
`LinearWrap` and `argscope` make large models look simpler.
......@@ -43,8 +43,8 @@ Multi-GPU training is ready to use by simply switching the trainer.
pip install --user -r requirements.txt
pip install --user -r opt-requirements.txt (some optional dependencies, you can install later if needed)
```
+ Use [tcmalloc](http://goog-perftools.sourceforge.net/doc/tcmalloc.html) whenever possible: see [TF issue](https://github.com/tensorflow/tensorflow/issues/2942)
+ allow `import tensorpack` everywhere:
+ Use [tcmalloc](http://goog-perftools.sourceforge.net/doc/tcmalloc.html) whenever possible
+ Enable `import tensorpack`:
```
export PYTHONPATH=$PYTHONPATH:`readlink -f path/to/tensorpack`
```
# tensorpack examples
Only allow examples with reproducible and meaningful performance.
Only allow examples with __reproducible__ and meaningful performance.
+ [An illustrative mnist example](mnist-convnet.py)
+ [A small Cifar10 ConvNet with 91% accuracy](cifar-convnet.py)
......
......@@ -3,7 +3,7 @@
# File: simulator.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import multiprocessing
import multiprocessing as mp
import time
import threading
import weakref
......@@ -17,7 +17,8 @@ from ..utils.timer import *
from ..utils.serialize import *
from ..utils.concurrency import *
__all__ = ['SimulatorProcess', 'SimulatorMaster']
__all__ = ['SimulatorProcess', 'SimulatorMaster',
'StateExchangeSimulatorProcess', 'SimulatorProcessSharedWeight']
try:
import zmq
......@@ -25,10 +26,23 @@ except ImportError:
logger.warn("Error in 'import zmq'. RL simulator won't be available.")
__all__ = []
class SimulatorProcessBase(mp.Process):
__metaclass__ = ABCMeta
def __init__(self, idx):
super(SimulatorProcessBase, self).__init__()
self.idx = int(idx)
self.identity = u'simulator-{}'.format(self.idx).encode('utf-8')
class SimulatorProcess(multiprocessing.Process):
@abstractmethod
def _build_player(self):
pass
class StateExchangeSimulatorProcess(SimulatorProcessBase):
"""
A process that simulates a player and communicates to master to get the next action
A process that simulates a player and communicates to master to
send states and receive the next action
"""
__metaclass__ = ABCMeta
......@@ -36,13 +50,10 @@ class SimulatorProcess(multiprocessing.Process):
"""
:param idx: idx of this process
"""
super(SimulatorProcess, self).__init__()
self.idx = int(idx)
super(StateExchangeSimulatorProcess, self).__init__(idx)
self.c2s = pipe_c2s
self.s2c = pipe_s2c
self.identity = u'simulator-{}'.format(self.idx).encode('utf-8')
def run(self):
player = self._build_player()
context = zmq.Context()
......@@ -66,12 +77,11 @@ class SimulatorProcess(multiprocessing.Process):
reward, isOver = player.action(action)
state = player.current_state()
@abstractmethod
def _build_player(self):
pass
# compatibility
SimulatorProcess = StateExchangeSimulatorProcess
class SimulatorMaster(threading.Thread):
""" A base thread to communicate with all simulator processes.
""" A base thread to communicate with all StateExchangeSimulatorProcess.
It should produce action for each simulator, as well as
defining callbacks when a transition or an episode is finished.
"""
......@@ -163,6 +173,71 @@ class SimulatorMaster(threading.Thread):
self.socket.close()
self.context.term()
class SimulatorProcessDF(SimulatorProcessBase):
    """
    A simulator process that carries its own forward model, so it can
    produce datapoints by itself and push them to the receiving end.

    Subclasses implement :meth:`_prepare` and :meth:`_produce_datapoint`
    (and ``_build_player``, which :meth:`run` calls first).
    """

    def __init__(self, idx, pipe_c2s):
        """
        :param idx: idx of this process
        :param pipe_c2s: address the PUSH socket connects to in :meth:`run`
        """
        super(SimulatorProcessDF, self).__init__(idx)
        self.pipe_c2s = pipe_c2s

    def run(self):
        """
        Process entry point: build the player, connect a PUSH socket to
        ``pipe_c2s``, call :meth:`_prepare` once, then send
        ``(identity, datapoint)`` messages forever.
        """
        self.player = self._build_player()

        # The context and socket are created here, inside run(), and kept
        # on self so that subclass hooks can reach them.
        self.ctx = zmq.Context()
        sock = self.ctx.socket(zmq.PUSH)
        self.c2s_socket = sock
        sock.setsockopt(zmq.IDENTITY, self.identity)
        sock.set_hwm(5)
        sock.connect(self.pipe_c2s)

        self._prepare()
        while True:
            datapoint = self._produce_datapoint()
            message = dumps((self.identity, datapoint))
            self.c2s_socket.send(message, copy=False)

    @abstractmethod
    def _prepare(self):
        """Hook called once in :meth:`run`, after the socket is connected
        and before the first datapoint is produced."""

    @abstractmethod
    def _produce_datapoint(self):
        """Return the next datapoint; it is serialized together with this
        process's identity and sent over the socket."""
class SimulatorProcessSharedWeight(SimulatorProcessDF):
    """
    A simulator process with an extra thread that waits for an event,
    and takes shared weight from shm.
    Start me under some CUDA_VISIBLE_DEVICES set!
    """

    def __init__(self, idx, pipe_c2s, evt, shared_dic):
        """
        :param idx: idx of this process
        :param pipe_c2s: forwarded unchanged to :class:`SimulatorProcessDF`
        :param evt: event object the extra thread waits on (via ``evt.wait()``)
        :param shared_dic: shared dict kept on ``self.shared_dic`` for
            subclasses; presumably holds the shared weights -- confirm
            against the caller
        """
        super(SimulatorProcessSharedWeight, self).__init__(idx, pipe_c2s)
        self.shared_dic = shared_dic
        self.evt = evt

    def _prepare(self):
        """Build the session, then start the thread which waits for ``evt``."""
        self._build_session()

        def _wait_then_trigger():
            # One iteration of the loop thread: block on the event,
            # then hand over to the subclass hook.
            self.evt.wait()
            self._trigger_evt()

        self.evt_th = LoopThread(_wait_then_trigger, pausable=False)
        self.evt_th.start()

    @abstractmethod
    def _trigger_evt(self):
        """
        Called from the extra thread each time the wait on ``evt`` returns.

        An implementation would presumably do something like
        ``self.sess_updater.update(self.shared_dic['params'])``.
        """

    @abstractmethod
    def _build_session(self):
        """Build the session and ``self.sess_updater``."""
if __name__ == '__main__':
import random
from tensorpack.RL import NaiveRLEnvironment
......
......@@ -115,7 +115,6 @@ class Callbacks(Callback):
cbs.remove(sp)
cbs.append(sp)
break
print(cbs)
self.cbs = cbs
self.test_callback_context = TestCallbackContext()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment