Commit 8f10b0f8 authored by Yuxin Wu

[MaskRCNN] use "spawn" for safer multiprocessing

parent 2902bfbe
@@ -93,7 +93,10 @@ _C.DATA.NUM_CATEGORY = 80  # without the background class (e.g., 80 for COCO)
 _C.DATA.CLASS_NAMES = []  # NUM_CLASS (NUM_CATEGORY+1) strings, the first is "BG".
 # whether the coordinates in the annotations are absolute pixel values, or a relative value in [0, 1]
 _C.DATA.ABSOLUTE_COORD = True
-_C.DATA.NUM_WORKERS = 5  # number of data loading workers. set to 0 to disable parallel data loading
+# Number of data loading workers.
+# In case of horovod training, this is the number of workers per-GPU (so you may want to use a smaller number).
+# Set to 0 to disable parallel data loading
+_C.DATA.NUM_WORKERS = 10

 # backbone ----------------------
 _C.BACKBONE.WEIGHTS = ''  # /path/to/weights.npz
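Not part of the diff, but as a reading aid for the new comment: under horovod every GPU runs its own training process, and each of those processes starts its own set of data-loading workers, so the number of loader processes on a machine scales with the GPU count. A trivial sketch (the function name is ours, purely illustrative):

    # Illustrative only -- not part of the commit. Rough count of loader
    # processes across a horovod job, where each GPU process owns
    # DATA.NUM_WORKERS data-loading workers of its own.
    def total_loader_processes(num_gpu_processes, workers_per_gpu):
        if workers_per_gpu == 0:          # 0 disables parallel loading entirely
            return 0
        return num_gpu_processes * workers_per_gpu

    # e.g. 8 GPU processes x 10 workers each = 80 loader processes on one machine,
    # which is why a smaller per-GPU NUM_WORKERS may be preferable under horovod.
    print(total_loader_processes(8, 10))  # -> 80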
This diff is collapsed.
@@ -101,7 +101,11 @@ class ResNetC4Model(GeneralizedRCNN):
     def rpn(self, image, features, inputs):
         featuremap = features[0]
         rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, cfg.RPN.HEAD_DIM, cfg.RPN.NUM_ANCHOR)
-        anchors = RPNAnchors(get_all_anchors(), inputs['anchor_labels'], inputs['anchor_boxes'])
+        anchors = RPNAnchors(
+            get_all_anchors(
+                stride=cfg.RPN.ANCHOR_STRIDE, sizes=cfg.RPN.ANCHOR_SIZES,
+                ratios=cfg.RPN.ANCHOR_RATIOS, max_size=cfg.PREPROC.MAX_SIZE),
+            inputs['anchor_labels'], inputs['anchor_boxes'])
         anchors = anchors.narrow_to(featuremap)

         image_shape2d = tf.shape(image)[2:]  # h,w
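Threading stride/sizes/ratios/max_size through as explicit arguments, rather than having get_all_anchors() read the global cfg internally, makes the anchor generator a pure function of its inputs, which is friendlier to spawned worker processes that do not inherit the parent's runtime-modified config. Below is a hedged sketch of what such a pure, cached generator could look like; the name make_all_anchors and its body are our own assumptions, not tensorpack's implementation:

    # Hypothetical sketch of a pure, explicitly-parameterized anchor generator.
    # Argument names mirror the call in the diff; the body is illustrative only.
    from functools import lru_cache
    import numpy as np

    @lru_cache(maxsize=None)
    def make_all_anchors(stride, sizes, ratios, max_size):
        """Return anchors of shape (field, field, num_sizes*num_ratios, 4) in x1y1x2y2."""
        # Base anchors centered at the origin, one per (size, ratio) pair.
        base = []
        for size in sizes:
            for ratio in ratios:
                w = size / np.sqrt(ratio)
                h = size * np.sqrt(ratio)
                base.append([-w / 2, -h / 2, w / 2, h / 2])
        base = np.asarray(base, dtype='float32')           # (A, 4)

        # Tile the base anchors over the feature grid covering max_size pixels.
        field = int(np.ceil(max_size / stride))
        shifts = (np.arange(field) * stride).astype('float32')
        sx, sy = np.meshgrid(shifts, shifts)
        shift = np.stack([sx, sy, sx, sy], axis=-1)        # (field, field, 4)
        return shift[:, :, None, :] + base                 # (field, field, A, 4)

    # Because lru_cache needs hashable arguments, pass tuples, e.g.
    # make_all_anchors(16, (32, 64, 128, 256, 512), (0.5, 1.0, 2.0), 1333)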
@@ -216,7 +220,11 @@ class ResNetFPNModel(GeneralizedRCNN):
         assert len(cfg.RPN.ANCHOR_SIZES) == len(cfg.FPN.ANCHOR_STRIDES)

         image_shape2d = tf.shape(image)[2:]  # h,w
-        all_anchors_fpn = get_all_anchors_fpn()
+        all_anchors_fpn = get_all_anchors_fpn(
+            strides=cfg.FPN.ANCHOR_STRIDES,
+            sizes=cfg.RPN.ANCHOR_SIZES,
+            ratios=cfg.RPN.ANCHOR_RATIOS,
+            max_size=cfg.PREPROC.MAX_SIZE)
         multilevel_anchors = [RPNAnchors(
             all_anchors_fpn[i],
             inputs['anchor_labels_lvl{}'.format(i + 2)],
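The FPN path gets the same treatment; conceptually get_all_anchors_fpn() is just the single-level generator applied once per stride. A short sketch building on the hypothetical make_all_anchors above (again an assumption, not the library's code):

    # Illustrative multi-level variant: one anchor array per FPN level, built
    # by calling the (hypothetical) single-level generator once per stride.
    def make_all_anchors_fpn(strides, sizes, ratios, max_size):
        assert len(strides) == len(sizes), "one anchor size per FPN level"
        return [make_all_anchors(stride, (size,), tuple(ratios), max_size)
                for stride, size in zip(strides, sizes)]

    # Each level i is then paired with the matching 'anchor_labels_lvl{i+2}' /
    # 'anchor_boxes_lvl{i+2}' inputs, as in the list comprehension in the diff.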
@@ -25,6 +25,8 @@ except ImportError:
 if __name__ == '__main__':
+    import multiprocessing as mp
+    mp.set_start_method('spawn')  # safer behavior & memory saving
     parser = argparse.ArgumentParser()
     parser.add_argument('--load', help='load a model to start training from. Can overwrite BACKBONE.WEIGHTS')
     parser.add_argument('--logdir', help='log directory', default='train_log/maskrcnn')
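For the commit title, the relevant difference between 'spawn' and the default 'fork' on Linux is that spawned children start from a fresh interpreter and re-import the main module instead of inheriting the parent's memory; this avoids duplicating the parent's large address space (the memory saving noted in the comment) but requires that the start method be set under the __main__ guard and that process targets and arguments be picklable. A minimal standalone sketch, independent of the MaskRCNN code:

    # Standalone illustration of the 'spawn' start method (not MaskRCNN code).
    import multiprocessing as mp

    def work(x):
        # Runs in a fresh interpreter: module state is re-imported, not copied
        # from the parent as it would be under 'fork'.
        return x * x

    if __name__ == '__main__':
        # Must be called at most once, before any Process/Pool is created,
        # and only under the __main__ guard (spawned children re-import this module).
        mp.set_start_method('spawn')
        with mp.Pool(2) as pool:
            print(pool.map(work, range(4)))   # [0, 1, 4, 9]
        # Lambdas and other unpicklable callables cannot be sent to spawned
        # workers, since the target and its arguments cross a pickle boundary.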
@@ -59,7 +59,7 @@ class GPUUtilizationTracker(Callback):
         self._stop_evt = mp.Event()
         self._queue = mp.Queue()
         self._proc = mp.Process(target=self.worker, args=(
-            self._evt, self._queue, self._stop_evt))
+            self._evt, self._queue, self._stop_evt, self._devices))
         ensure_proc_terminate(self._proc)
         start_proc_mask_signal(self._proc)
@@ -96,9 +96,14 @@ class GPUUtilizationTracker(Callback):
         self._evt.set()
         self._proc.terminate()

-    def worker(self, evt, rst_queue, stop_evt):
+    @staticmethod
+    def worker(evt, rst_queue, stop_evt, devices):
+        """
+        Args:
+            devices (list[int])
+        """
         with NVMLContext() as ctx:
-            devices = [ctx.device(i) for i in self._devices]
+            devices = [ctx.device(i) for i in devices]
             while True:
                 try:
                     evt.wait()  # start epoch
@@ -106,7 +111,7 @@ class GPUUtilizationTracker(Callback):
                     if stop_evt.is_set():  # or on exit
                         return
-                    stats = np.zeros((len(self._devices),), dtype='f4')
+                    stats = np.zeros((len(devices),), dtype='f4')
                     cnt = 0
                     while True:
                         time.sleep(1)
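Making worker a @staticmethod that receives devices explicitly is what keeps this callback compatible with the 'spawn' start method: the Process target and its args are pickled into a fresh interpreter, and a plain function plus a list of ints serializes cheaply, whereas a bound method would pull the whole callback object into the pickle. A self-contained sketch of the pattern; the Tracker class below is a stand-in, not tensorpack's GPUUtilizationTracker:

    # Minimal sketch of the "staticmethod worker + explicit args" pattern used
    # above; class and method names here are illustrative, not tensorpack's API.
    import multiprocessing as mp

    class Tracker:
        def __init__(self, devices):
            self._devices = list(devices)
            self._queue = mp.Queue()

        def start(self):
            # Under 'spawn', the target and args below are pickled. A static
            # function and a list of ints serialize cleanly; a bound method
            # (self.worker) would drag the whole Tracker instance along.
            self._proc = mp.Process(
                target=Tracker.worker, args=(self._queue, self._devices))
            self._proc.start()

        @staticmethod
        def worker(queue, devices):
            # No access to `self` here -- everything it needs was passed in.
            queue.put(sum(devices))

    if __name__ == '__main__':
        mp.set_start_method('spawn')
        t = Tracker([0, 1, 2])
        t.start()
        print(t._queue.get())   # -> 3
        t._proc.join()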