Commit 2069cfdc authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] Allow cfg.DATA.TRAIN to be a string (#1276)

parent c5a47192
...@@ -221,6 +221,8 @@ def finalize_configs(is_training): ...@@ -221,6 +221,8 @@ def finalize_configs(is_training):
_C.freeze(False) # populate new keys now _C.freeze(False) # populate new keys now
if isinstance(_C.DATA.VAL, six.string_types): # support single string (the typical case) as well if isinstance(_C.DATA.VAL, six.string_types): # support single string (the typical case) as well
_C.DATA.VAL = (_C.DATA.VAL, ) _C.DATA.VAL = (_C.DATA.VAL, )
if isinstance(_C.DATA.TRAIN, six.string_types): # support single string
_C.DATA.TRAIN = (_C.DATA.TRAIN, )
assert _C.BACKBONE.NORM in ['FreezeBN', 'SyncBN', 'GN', 'None'], _C.BACKBONE.NORM assert _C.BACKBONE.NORM in ['FreezeBN', 'SyncBN', 'GN', 'None'], _C.BACKBONE.NORM
if _C.BACKBONE.NORM != 'FreezeBN': if _C.BACKBONE.NORM != 'FreezeBN':
......
...@@ -332,7 +332,6 @@ def get_train_dataflow(): ...@@ -332,7 +332,6 @@ def get_train_dataflow():
If MODE_MASK, gt_masks: (N, h, w) If MODE_MASK, gt_masks: (N, h, w)
""" """
roidbs = list(itertools.chain.from_iterable(DatasetRegistry.get(x).training_roidbs() for x in cfg.DATA.TRAIN)) roidbs = list(itertools.chain.from_iterable(DatasetRegistry.get(x).training_roidbs() for x in cfg.DATA.TRAIN))
print_class_histogram(roidbs) print_class_histogram(roidbs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment