Commit 59829770 authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] use small buffer size in data loader to reduce memory usage (#1164,#1152,#1111)

parent bae419ca
...@@ -93,7 +93,7 @@ _C.DATA.NUM_CATEGORY = 0 # without the background class (e.g., 80 for COCO) ...@@ -93,7 +93,7 @@ _C.DATA.NUM_CATEGORY = 0 # without the background class (e.g., 80 for COCO)
_C.DATA.CLASS_NAMES = [] # NUM_CLASS (NUM_CATEGORY+1) strings, the first is "BG". _C.DATA.CLASS_NAMES = [] # NUM_CLASS (NUM_CATEGORY+1) strings, the first is "BG".
# whether the coordinates in the annotations are absolute pixel values, or a relative value in [0, 1] # whether the coordinates in the annotations are absolute pixel values, or a relative value in [0, 1]
_C.DATA.ABSOLUTE_COORD = True _C.DATA.ABSOLUTE_COORD = True
_C.DATA.NUM_WORKERS = 5 # number of data loading workers _C.DATA.NUM_WORKERS = 5 # number of data loading workers. set to 0 to disable parallel data loading
# backbone ---------------------- # backbone ----------------------
_C.BACKBONE.WEIGHTS = '' # /path/to/weights.npz _C.BACKBONE.WEIGHTS = '' # /path/to/weights.npz
......
...@@ -8,7 +8,7 @@ from tabulate import tabulate ...@@ -8,7 +8,7 @@ from tabulate import tabulate
from termcolor import colored from termcolor import colored
from tensorpack.dataflow import ( from tensorpack.dataflow import (
DataFromList, MapDataComponent, MultiProcessMapDataZMQ, MultiThreadMapData, TestDataSpeed, imgaug) DataFromList, MapDataComponent, MapData, MultiProcessMapDataZMQ, MultiThreadMapData, TestDataSpeed, imgaug)
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.argtools import log_once, memoized from tensorpack.utils.argtools import log_once, memoized
...@@ -368,11 +368,15 @@ def get_train_dataflow(): ...@@ -368,11 +368,15 @@ def get_train_dataflow():
# tpviz.interactive_imshow(viz) # tpviz.interactive_imshow(viz)
return ret return ret
if cfg.DATA.NUM_WORKERS > 0:
buffer_size = cfg.DATA.NUM_WORKERS * 20
if cfg.TRAINER == 'horovod': if cfg.TRAINER == 'horovod':
ds = MultiThreadMapData(ds, cfg.DATA.NUM_WORKERS, preprocess) ds = MultiThreadMapData(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
# MPI does not like fork() # MPI does not like fork()
else: else:
ds = MultiProcessMapDataZMQ(ds, cfg.DATA.NUM_WORKERS, preprocess) ds = MultiProcessMapDataZMQ(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
else:
ds = MapData(ds, preprocess)
return ds return ds
......
...@@ -154,6 +154,7 @@ class MultiThreadMapData(_ParallelMapData): ...@@ -154,6 +154,7 @@ class MultiThreadMapData(_ParallelMapData):
strict (bool): use "strict mode", see notes above. strict (bool): use "strict mode", see notes above.
""" """
super(MultiThreadMapData, self).__init__(ds, buffer_size, strict) super(MultiThreadMapData, self).__init__(ds, buffer_size, strict)
assert nr_thread > 0, nr_thread
self._strict = strict self._strict = strict
self.nr_thread = nr_thread self.nr_thread = nr_thread
...@@ -259,6 +260,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow): ...@@ -259,6 +260,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
""" """
_ParallelMapData.__init__(self, ds, buffer_size, strict) _ParallelMapData.__init__(self, ds, buffer_size, strict)
_MultiProcessZMQDataFlow.__init__(self) _MultiProcessZMQDataFlow.__init__(self)
assert nr_proc > 0, nr_proc
self.nr_proc = nr_proc self.nr_proc = nr_proc
self.map_func = map_func self.map_func = map_func
self._strict = strict self._strict = strict
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment