Commit c52c8cc4 authored by Yuxin Wu's avatar Yuxin Wu

Warn about 'fork' multiprocessing start method (#1134)

parent 839ca071
# -*- coding: utf-8 -*-
# File: concurrency.py
import multiprocessing as mp
from ..utils import logger
......
......@@ -53,8 +53,7 @@ class GPUUtilizationTracker(Callback):
self._devices = devices
assert len(self._devices), "[GPUUtilizationTracker] No GPU device given!"
def _before_train(self):
assert gpu_available_in_session(), "[GPUUtilizationTracker] needs GPU!"
def _setup_graph(self):
self._evt = mp.Event()
self._stop_evt = mp.Event()
self._queue = mp.Queue()
......@@ -63,6 +62,9 @@ class GPUUtilizationTracker(Callback):
ensure_proc_terminate(self._proc)
start_proc_mask_signal(self._proc)
def _before_train(self):
assert gpu_available_in_session(), "[GPUUtilizationTracker] needs GPU!"
def _before_epoch(self):
self._evt.set()
......
......@@ -3,9 +3,10 @@
# Some code taken from zxytim
import sys
import atexit
import bisect
import multiprocessing
import multiprocessing as mp
import platform
import signal
import threading
......@@ -169,7 +170,7 @@ def ensure_proc_terminate(proc):
proc.terminate()
proc.join()
assert isinstance(proc, multiprocessing.Process)
assert isinstance(proc, mp.Process)
atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
......@@ -223,7 +224,7 @@ def start_proc_mask_signal(proc):
Start process(es) with SIGINT ignored.
Args:
proc: (multiprocessing.Process or list)
proc: (mp.Process or list)
Note:
The signal mask is only applied when called from main thread.
......@@ -233,6 +234,13 @@ def start_proc_mask_signal(proc):
with mask_sigint():
for p in proc:
if isinstance(p, mp.Process):
if sys.version_info < (3, 4) or mp.get_start_method() == 'fork':
log_once(
"Starting a process with 'fork' method is not safe and may consume unnecessary extra memory."
" Use 'forkserver' method (available after Py3.4) instead if you run into any issues. "
"See https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods",
'warn') # noqa
p.start()
......@@ -308,7 +316,7 @@ class OrderedContainer(object):
return rank, ret
class OrderedResultGatherProc(multiprocessing.Process):
class OrderedResultGatherProc(mp.Process):
"""
Gather indexed data from a data queue, and produce results with the
original index-based order.
......@@ -317,7 +325,7 @@ class OrderedResultGatherProc(multiprocessing.Process):
def __init__(self, data_queue, nr_producer, start=0):
"""
Args:
data_queue(multiprocessing.Queue): a queue which contains datapoints.
data_queue(mp.Queue): a queue which contains datapoints.
nr_producer(int): number of producer processes. This process will
terminate after receiving this many of :class:`DIE` sentinel.
start(int): the rank of the first object
......@@ -325,7 +333,7 @@ class OrderedResultGatherProc(multiprocessing.Process):
super(OrderedResultGatherProc, self).__init__()
self.data_queue = data_queue
self.ordered_container = OrderedContainer(start=start)
self.result_queue = multiprocessing.Queue()
self.result_queue = mp.Queue()
self.nr_producer = nr_producer
def run(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment