Commit 6636791f authored by Yuxin Wu's avatar Yuxin Wu

start buildling simulator frameworks

parent fc2e6240
...@@ -13,7 +13,7 @@ See some interesting [examples](examples) to learn about the framework: ...@@ -13,7 +13,7 @@ See some interesting [examples](examples) to learn about the framework:
## Features: ## Features:
You need to abstract your training task into three components: Abstract your training task into three components:
1. Model, or graph. `models/` has some scoped abstraction of common models. 1. Model, or graph. `models/` has some scoped abstraction of common models.
`LinearWrap` and `argscope` makes large models look simpler. `LinearWrap` and `argscope` makes large models look simpler.
...@@ -43,8 +43,8 @@ Multi-GPU training is ready to use by simply switching the trainer. ...@@ -43,8 +43,8 @@ Multi-GPU training is ready to use by simply switching the trainer.
pip install --user -r requirements.txt pip install --user -r requirements.txt
pip install --user -r opt-requirements.txt (some optional dependencies, you can install later if needed) pip install --user -r opt-requirements.txt (some optional dependencies, you can install later if needed)
``` ```
+ Use [tcmalloc](http://goog-perftools.sourceforge.net/doc/tcmalloc.html) whenever possible: see [TF issue](https://github.com/tensorflow/tensorflow/issues/2942) + Use [tcmalloc](http://goog-perftools.sourceforge.net/doc/tcmalloc.html) whenever possible
+ allow `import tensorpack` everywhere: + Enable `import tensorpack`:
``` ```
export PYTHONPATH=$PYTHONPATH:`readlink -f path/to/tensorpack` export PYTHONPATH=$PYTHONPATH:`readlink -f path/to/tensorpack`
``` ```
# tensorpack examples # tensorpack examples
Only allow examples with reproducible and meaningful performancce. Only allow examples with __reproducible__ and meaningful performancce.
+ [An illustrative mnist example](mnist-convnet.py) + [An illustrative mnist example](mnist-convnet.py)
+ [A small Cifar10 ConvNet with 91% accuracy](cifar-convnet.py) + [A small Cifar10 ConvNet with 91% accuracy](cifar-convnet.py)
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# File: simulator.py # File: simulator.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import multiprocessing import multiprocessing as mp
import time import time
import threading import threading
import weakref import weakref
...@@ -17,7 +17,8 @@ from ..utils.timer import * ...@@ -17,7 +17,8 @@ from ..utils.timer import *
from ..utils.serialize import * from ..utils.serialize import *
from ..utils.concurrency import * from ..utils.concurrency import *
__all__ = ['SimulatorProcess', 'SimulatorMaster'] __all__ = ['SimulatorProcess', 'SimulatorMaster',
'StateExchangeSimulatorProcess', 'SimulatorProcessSharedWeight']
try: try:
import zmq import zmq
...@@ -25,10 +26,23 @@ except ImportError: ...@@ -25,10 +26,23 @@ except ImportError:
logger.warn("Error in 'import zmq'. RL simulator won't be available.") logger.warn("Error in 'import zmq'. RL simulator won't be available.")
__all__ = [] __all__ = []
class SimulatorProcessBase(mp.Process):
__metaclass__ = ABCMeta
def __init__(self, idx):
super(SimulatorProcessBase, self).__init__()
self.idx = int(idx)
self.identity = u'simulator-{}'.format(self.idx).encode('utf-8')
class SimulatorProcess(multiprocessing.Process): @abstractmethod
def _build_player(self):
pass
class StateExchangeSimulatorProcess(SimulatorProcessBase):
""" """
A process that simulates a player and communicates to master to get the next action A process that simulates a player and communicates to master to
send states and receive the next action
""" """
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
...@@ -36,13 +50,10 @@ class SimulatorProcess(multiprocessing.Process): ...@@ -36,13 +50,10 @@ class SimulatorProcess(multiprocessing.Process):
""" """
:param idx: idx of this process :param idx: idx of this process
""" """
super(SimulatorProcess, self).__init__() super(StateExchangeSimulatorProcess, self).__init__(idx)
self.idx = int(idx)
self.c2s = pipe_c2s self.c2s = pipe_c2s
self.s2c = pipe_s2c self.s2c = pipe_s2c
self.identity = u'simulator-{}'.format(self.idx).encode('utf-8')
def run(self): def run(self):
player = self._build_player() player = self._build_player()
context = zmq.Context() context = zmq.Context()
...@@ -66,12 +77,11 @@ class SimulatorProcess(multiprocessing.Process): ...@@ -66,12 +77,11 @@ class SimulatorProcess(multiprocessing.Process):
reward, isOver = player.action(action) reward, isOver = player.action(action)
state = player.current_state() state = player.current_state()
@abstractmethod # compatibility
def _build_player(self): SimulatorProcess = StateExchangeSimulatorProcess
pass
class SimulatorMaster(threading.Thread): class SimulatorMaster(threading.Thread):
""" A base thread to communicate with all simulator processes. """ A base thread to communicate with all StateExchangeSimulatorProcess.
It should produce action for each simulator, as well as It should produce action for each simulator, as well as
defining callbacks when a transition or an episode is finished. defining callbacks when a transition or an episode is finished.
""" """
...@@ -163,6 +173,71 @@ class SimulatorMaster(threading.Thread): ...@@ -163,6 +173,71 @@ class SimulatorMaster(threading.Thread):
self.socket.close() self.socket.close()
self.context.term() self.context.term()
class SimulatorProcessDF(SimulatorProcessBase):
""" A simulator which contains a forward model itself, allowing
it to produce data points directly """
def __init__(self, idx, pipe_c2s):
super(SimulatorProcessDF, self).__init__(idx)
self.pipe_c2s = pipe_c2s
def run(self):
self.player = self._build_player()
self.ctx = zmq.Context()
self.c2s_socket = self.ctx.socket(zmq.PUSH)
self.c2s_socket.setsockopt(zmq.IDENTITY, self.identity)
self.c2s_socket.set_hwm(5)
self.c2s_socket.connect(self.pipe_c2s)
self._prepare()
while True:
dp = self._produce_datapoint()
self.c2s_socket.send(dumps(
(self.identity, dp)
), copy=False)
@abstractmethod
def _prepare(self):
pass
@abstractmethod
def _produce_datapoint(self):
pass
class SimulatorProcessSharedWeight(SimulatorProcessDF):
""" A simulator process with an extra thread waiting for event,
and take shared weight from shm.
Start me under some CUDA_VISIBLE_DEVICES set!
"""
def __init__(self, idx, pipe_c2s, evt, shared_dic):
super(SimulatorProcessSharedWeight, self).__init__(idx, pipe_c2s)
self.evt = evt
self.shared_dic = shared_dic
def _prepare(self):
self._build_session()
# start a thread to wait for evt
def func():
self.evt.wait()
self._trigger_evt()
self.evt_th = LoopThread(func, pausable=False)
self.evt_th.start()
@abstractmethod
def _trigger_evt(self):
pass
#self.sess_updater.update(self.shared_dic['params'])
@abstractmethod
def _build_session(self):
# build session and self.sess_updaer
pass
if __name__ == '__main__': if __name__ == '__main__':
import random import random
from tensorpack.RL import NaiveRLEnvironment from tensorpack.RL import NaiveRLEnvironment
......
...@@ -115,7 +115,6 @@ class Callbacks(Callback): ...@@ -115,7 +115,6 @@ class Callbacks(Callback):
cbs.remove(sp) cbs.remove(sp)
cbs.append(sp) cbs.append(sp)
break break
print(cbs)
self.cbs = cbs self.cbs = cbs
self.test_callback_context = TestCallbackContext() self.test_callback_context = TestCallbackContext()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment