Commit 0a563252 authored by Yuxin Wu's avatar Yuxin Wu

catch some errors and stop the proecsses when training exited

parent ff460491
......@@ -144,6 +144,7 @@ class SimulatorMaster(threading.Thread):
def run(self):
self.clients = defaultdict(self.ClientState)
try:
while True:
msg = loads(self.c2s_socket.recv(copy=False).bytes)
ident, state, reward, isOver = msg
......@@ -160,6 +161,8 @@ class SimulatorMaster(threading.Thread):
self._on_datapoint(ident)
# feed state and return action
self._on_state(state, ident)
except zmq.ContextTerminated:
logger.info("[Simulator] Context was terminated.")
@abstractmethod
def _on_state(self, state, ident):
......
......@@ -138,7 +138,6 @@ class Model(ModelDesc):
class MySimulatorMaster(SimulatorMaster, Callback):
def __init__(self, pipe_c2s, pipe_s2c, model):
super(MySimulatorMaster, self).__init__(pipe_c2s, pipe_s2c)
self.M = model
......
......@@ -17,8 +17,8 @@ def get_dorefa(bitW, bitA, bitG):
def quantize(x, k):
n = float(2**k - 1)
with G.gradient_override_map({"Floor": "Identity"}):
return tf.floor(x * n + 0.5) / n
with G.gradient_override_map({"Round": "Identity"}):
return tf.round(x * n) / n
def fw(x):
if bitW == 32:
......
......@@ -3,8 +3,9 @@
# File: concurrency.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import multiprocessing as mp
from .base import Callback
from ..utils.concurrency import start_proc_mask_signal
from ..utils.concurrency import start_proc_mask_signal, StoppableThread
from ..utils import logger
__all__ = ['StartProcOrThread']
......@@ -15,18 +16,37 @@ class StartProcOrThread(Callback):
Start some threads or processes before training.
"""
def __init__(self, startable):
def __init__(self, startable, stop_at_last=True):
"""
Args:
startable(list): list of processes or threads which have ``start()`` method.
startable (list): list of processes or threads which have ``start()`` method.
Can also be a single instance of process of thread.
stop_at_last (bool): whether to stop the processes or threads
after training. It will use :meth:`Process.terminate()` or
:meth:`StoppableThread.stop()`, but will do nothing on normal
`threading.Thread` or other startable objects.
"""
if not isinstance(startable, list):
startable = [startable]
self._procs_threads = startable
self._stop_at_last = stop_at_last
def _before_train(self):
logger.info("Starting " +
', '.join([k.name for k in self._procs_threads]) + ' ...')
', '.join([k.name for k in self._procs_threads]))
# avoid sigint get handled by other processes
start_proc_mask_signal(self._procs_threads)
def _after_train(self):
if not self._stop_at_last:
return
for k in self._procs_threads:
if isinstance(k, mp.Process):
k.terminate()
k.join()
elif isinstance(k, StoppableThread):
k.stop()
else:
logger.warn(
"[StartProcOrThread] {} "
"is neither a Process nor a StoppableThread, won't stop it.".format(k.name))
......@@ -4,11 +4,10 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import multiprocessing
import threading
import six
from six.moves import queue, range
from ..utils.concurrency import DIE
from ..utils.concurrency import DIE, StoppableThread
from ..tfutils.modelutils import describe_model
from ..utils import logger
......@@ -83,8 +82,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
self.outqueue.put((tid, self.predictor(dp)))
class PredictorWorkerThread(threading.Thread):
class PredictorWorkerThread(StoppableThread):
def __init__(self, queue, pred_func, id, batch_size=5):
super(PredictorWorkerThread, self).__init__()
self.queue = queue
......@@ -94,7 +92,7 @@ class PredictorWorkerThread(threading.Thread):
self.id = id
def run(self):
while True:
while not self.stopped():
batched, futures = self.fetch_batch()
outputs = self.func(batched)
# print "Worker {} batched {} Queue {}".format(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment