Commit 0a563252 authored by Yuxin Wu's avatar Yuxin Wu

catch some errors and stop the proecsses when training exited

parent ff460491
...@@ -144,22 +144,25 @@ class SimulatorMaster(threading.Thread): ...@@ -144,22 +144,25 @@ class SimulatorMaster(threading.Thread):
def run(self): def run(self):
self.clients = defaultdict(self.ClientState) self.clients = defaultdict(self.ClientState)
while True: try:
msg = loads(self.c2s_socket.recv(copy=False).bytes) while True:
ident, state, reward, isOver = msg msg = loads(self.c2s_socket.recv(copy=False).bytes)
# TODO check history and warn about dead client ident, state, reward, isOver = msg
client = self.clients[ident] # TODO check history and warn about dead client
client = self.clients[ident]
# check if reward&isOver is valid
# in the first message, only state is valid # check if reward&isOver is valid
if len(client.memory) > 0: # in the first message, only state is valid
client.memory[-1].reward = reward if len(client.memory) > 0:
if isOver: client.memory[-1].reward = reward
self._on_episode_over(ident) if isOver:
else: self._on_episode_over(ident)
self._on_datapoint(ident) else:
# feed state and return action self._on_datapoint(ident)
self._on_state(state, ident) # feed state and return action
self._on_state(state, ident)
except zmq.ContextTerminated:
logger.info("[Simulator] Context was terminated.")
@abstractmethod @abstractmethod
def _on_state(self, state, ident): def _on_state(self, state, ident):
......
...@@ -138,7 +138,6 @@ class Model(ModelDesc): ...@@ -138,7 +138,6 @@ class Model(ModelDesc):
class MySimulatorMaster(SimulatorMaster, Callback): class MySimulatorMaster(SimulatorMaster, Callback):
def __init__(self, pipe_c2s, pipe_s2c, model): def __init__(self, pipe_c2s, pipe_s2c, model):
super(MySimulatorMaster, self).__init__(pipe_c2s, pipe_s2c) super(MySimulatorMaster, self).__init__(pipe_c2s, pipe_s2c)
self.M = model self.M = model
......
...@@ -17,8 +17,8 @@ def get_dorefa(bitW, bitA, bitG): ...@@ -17,8 +17,8 @@ def get_dorefa(bitW, bitA, bitG):
def quantize(x, k): def quantize(x, k):
n = float(2**k - 1) n = float(2**k - 1)
with G.gradient_override_map({"Floor": "Identity"}): with G.gradient_override_map({"Round": "Identity"}):
return tf.floor(x * n + 0.5) / n return tf.round(x * n) / n
def fw(x): def fw(x):
if bitW == 32: if bitW == 32:
......
...@@ -3,8 +3,9 @@ ...@@ -3,8 +3,9 @@
# File: concurrency.py # File: concurrency.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import multiprocessing as mp
from .base import Callback from .base import Callback
from ..utils.concurrency import start_proc_mask_signal from ..utils.concurrency import start_proc_mask_signal, StoppableThread
from ..utils import logger from ..utils import logger
__all__ = ['StartProcOrThread'] __all__ = ['StartProcOrThread']
...@@ -15,18 +16,37 @@ class StartProcOrThread(Callback): ...@@ -15,18 +16,37 @@ class StartProcOrThread(Callback):
Start some threads or processes before training. Start some threads or processes before training.
""" """
def __init__(self, startable): def __init__(self, startable, stop_at_last=True):
""" """
Args: Args:
startable(list): list of processes or threads which have ``start()`` method. startable (list): list of processes or threads which have ``start()`` method.
Can also be a single instance of process of thread. Can also be a single instance of process of thread.
stop_at_last (bool): whether to stop the processes or threads
after training. It will use :meth:`Process.terminate()` or
:meth:`StoppableThread.stop()`, but will do nothing on normal
`threading.Thread` or other startable objects.
""" """
if not isinstance(startable, list): if not isinstance(startable, list):
startable = [startable] startable = [startable]
self._procs_threads = startable self._procs_threads = startable
self._stop_at_last = stop_at_last
def _before_train(self): def _before_train(self):
logger.info("Starting " + logger.info("Starting " +
', '.join([k.name for k in self._procs_threads]) + ' ...') ', '.join([k.name for k in self._procs_threads]))
# avoid sigint get handled by other processes # avoid sigint get handled by other processes
start_proc_mask_signal(self._procs_threads) start_proc_mask_signal(self._procs_threads)
def _after_train(self):
if not self._stop_at_last:
return
for k in self._procs_threads:
if isinstance(k, mp.Process):
k.terminate()
k.join()
elif isinstance(k, StoppableThread):
k.stop()
else:
logger.warn(
"[StartProcOrThread] {} "
"is neither a Process nor a StoppableThread, won't stop it.".format(k.name))
...@@ -4,11 +4,10 @@ ...@@ -4,11 +4,10 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import multiprocessing import multiprocessing
import threading
import six import six
from six.moves import queue, range from six.moves import queue, range
from ..utils.concurrency import DIE from ..utils.concurrency import DIE, StoppableThread
from ..tfutils.modelutils import describe_model from ..tfutils.modelutils import describe_model
from ..utils import logger from ..utils import logger
...@@ -83,8 +82,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker): ...@@ -83,8 +82,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
self.outqueue.put((tid, self.predictor(dp))) self.outqueue.put((tid, self.predictor(dp)))
class PredictorWorkerThread(threading.Thread): class PredictorWorkerThread(StoppableThread):
def __init__(self, queue, pred_func, id, batch_size=5): def __init__(self, queue, pred_func, id, batch_size=5):
super(PredictorWorkerThread, self).__init__() super(PredictorWorkerThread, self).__init__()
self.queue = queue self.queue = queue
...@@ -94,7 +92,7 @@ class PredictorWorkerThread(threading.Thread): ...@@ -94,7 +92,7 @@ class PredictorWorkerThread(threading.Thread):
self.id = id self.id = id
def run(self): def run(self):
while True: while not self.stopped():
batched, futures = self.fetch_batch() batched, futures = self.fetch_batch()
outputs = self.func(batched) outputs = self.func(batched)
# print "Worker {} batched {} Queue {}".format( # print "Worker {} batched {} Queue {}".format(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment