Commit c52c8cc4 authored by Yuxin Wu's avatar Yuxin Wu

Warn about 'fork' multiprocessing start method (#1134)

parent 839ca071
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: concurrency.py # File: concurrency.py
import multiprocessing as mp import multiprocessing as mp
from ..utils import logger from ..utils import logger
......
...@@ -53,8 +53,7 @@ class GPUUtilizationTracker(Callback): ...@@ -53,8 +53,7 @@ class GPUUtilizationTracker(Callback):
self._devices = devices self._devices = devices
assert len(self._devices), "[GPUUtilizationTracker] No GPU device given!" assert len(self._devices), "[GPUUtilizationTracker] No GPU device given!"
def _before_train(self): def _setup_graph(self):
assert gpu_available_in_session(), "[GPUUtilizationTracker] needs GPU!"
self._evt = mp.Event() self._evt = mp.Event()
self._stop_evt = mp.Event() self._stop_evt = mp.Event()
self._queue = mp.Queue() self._queue = mp.Queue()
...@@ -63,6 +62,9 @@ class GPUUtilizationTracker(Callback): ...@@ -63,6 +62,9 @@ class GPUUtilizationTracker(Callback):
ensure_proc_terminate(self._proc) ensure_proc_terminate(self._proc)
start_proc_mask_signal(self._proc) start_proc_mask_signal(self._proc)
def _before_train(self):
assert gpu_available_in_session(), "[GPUUtilizationTracker] needs GPU!"
def _before_epoch(self): def _before_epoch(self):
self._evt.set() self._evt.set()
......
...@@ -3,9 +3,10 @@ ...@@ -3,9 +3,10 @@
# Some code taken from zxytim # Some code taken from zxytim
import sys
import atexit import atexit
import bisect import bisect
import multiprocessing import multiprocessing as mp
import platform import platform
import signal import signal
import threading import threading
...@@ -169,7 +170,7 @@ def ensure_proc_terminate(proc): ...@@ -169,7 +170,7 @@ def ensure_proc_terminate(proc):
proc.terminate() proc.terminate()
proc.join() proc.join()
assert isinstance(proc, multiprocessing.Process) assert isinstance(proc, mp.Process)
atexit.register(stop_proc_by_weak_ref, weakref.ref(proc)) atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
...@@ -223,7 +224,7 @@ def start_proc_mask_signal(proc): ...@@ -223,7 +224,7 @@ def start_proc_mask_signal(proc):
Start process(es) with SIGINT ignored. Start process(es) with SIGINT ignored.
Args: Args:
proc: (multiprocessing.Process or list) proc: (mp.Process or list)
Note: Note:
The signal mask is only applied when called from main thread. The signal mask is only applied when called from main thread.
...@@ -233,6 +234,13 @@ def start_proc_mask_signal(proc): ...@@ -233,6 +234,13 @@ def start_proc_mask_signal(proc):
with mask_sigint(): with mask_sigint():
for p in proc: for p in proc:
if isinstance(p, mp.Process):
if sys.version_info < (3, 4) or mp.get_start_method() == 'fork':
log_once(
"Starting a process with 'fork' method is not safe and may consume unnecessary extra memory."
" Use 'forkserver' method (available after Py3.4) instead if you run into any issues. "
"See https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods",
'warn') # noqa
p.start() p.start()
...@@ -308,7 +316,7 @@ class OrderedContainer(object): ...@@ -308,7 +316,7 @@ class OrderedContainer(object):
return rank, ret return rank, ret
class OrderedResultGatherProc(multiprocessing.Process): class OrderedResultGatherProc(mp.Process):
""" """
Gather indexed data from a data queue, and produce results with the Gather indexed data from a data queue, and produce results with the
original index-based order. original index-based order.
...@@ -317,7 +325,7 @@ class OrderedResultGatherProc(multiprocessing.Process): ...@@ -317,7 +325,7 @@ class OrderedResultGatherProc(multiprocessing.Process):
def __init__(self, data_queue, nr_producer, start=0): def __init__(self, data_queue, nr_producer, start=0):
""" """
Args: Args:
data_queue(multiprocessing.Queue): a queue which contains datapoints. data_queue(mp.Queue): a queue which contains datapoints.
nr_producer(int): number of producer processes. This process will nr_producer(int): number of producer processes. This process will
terminate after receiving this many of :class:`DIE` sentinel. terminate after receiving this many of :class:`DIE` sentinel.
start(int): the rank of the first object start(int): the rank of the first object
...@@ -325,7 +333,7 @@ class OrderedResultGatherProc(multiprocessing.Process): ...@@ -325,7 +333,7 @@ class OrderedResultGatherProc(multiprocessing.Process):
super(OrderedResultGatherProc, self).__init__() super(OrderedResultGatherProc, self).__init__()
self.data_queue = data_queue self.data_queue = data_queue
self.ordered_container = OrderedContainer(start=start) self.ordered_container = OrderedContainer(start=start)
self.result_queue = multiprocessing.Queue() self.result_queue = mp.Queue()
self.nr_producer = nr_producer self.nr_producer = nr_producer
def run(self): def run(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment