Commit 59829770 authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] use small buffer size in data loader to reduce memory usage (#1164,#1152,#1111)

parent bae419ca
......@@ -93,7 +93,7 @@ _C.DATA.NUM_CATEGORY = 0 # without the background class (e.g., 80 for COCO)
_C.DATA.CLASS_NAMES = [] # NUM_CLASS (NUM_CATEGORY+1) strings, the first is "BG".
# whether the coordinates in the annotations are absolute pixel values, or a relative value in [0, 1]
_C.DATA.ABSOLUTE_COORD = True
_C.DATA.NUM_WORKERS = 5 # number of data loading workers
_C.DATA.NUM_WORKERS = 5 # number of data loading workers. set to 0 to disable parallel data loading
# backbone ----------------------
_C.BACKBONE.WEIGHTS = '' # /path/to/weights.npz
......
......@@ -8,7 +8,7 @@ from tabulate import tabulate
from termcolor import colored
from tensorpack.dataflow import (
DataFromList, MapDataComponent, MultiProcessMapDataZMQ, MultiThreadMapData, TestDataSpeed, imgaug)
DataFromList, MapDataComponent, MapData, MultiProcessMapDataZMQ, MultiThreadMapData, TestDataSpeed, imgaug)
from tensorpack.utils import logger
from tensorpack.utils.argtools import log_once, memoized
......@@ -368,11 +368,15 @@ def get_train_dataflow():
# tpviz.interactive_imshow(viz)
return ret
if cfg.TRAINER == 'horovod':
ds = MultiThreadMapData(ds, cfg.DATA.NUM_WORKERS, preprocess)
# MPI does not like fork()
if cfg.DATA.NUM_WORKERS > 0:
buffer_size = cfg.DATA.NUM_WORKERS * 20
if cfg.TRAINER == 'horovod':
ds = MultiThreadMapData(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
# MPI does not like fork()
else:
ds = MultiProcessMapDataZMQ(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
else:
ds = MultiProcessMapDataZMQ(ds, cfg.DATA.NUM_WORKERS, preprocess)
ds = MapData(ds, preprocess)
return ds
......
......@@ -154,6 +154,7 @@ class MultiThreadMapData(_ParallelMapData):
strict (bool): use "strict mode", see notes above.
"""
super(MultiThreadMapData, self).__init__(ds, buffer_size, strict)
assert nr_thread > 0, nr_thread
self._strict = strict
self.nr_thread = nr_thread
......@@ -259,6 +260,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
"""
_ParallelMapData.__init__(self, ds, buffer_size, strict)
_MultiProcessZMQDataFlow.__init__(self)
assert nr_proc > 0, nr_proc
self.nr_proc = nr_proc
self.map_func = map_func
self._strict = strict
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment