Commit 4df295ca authored by Yuxin Wu's avatar Yuxin Wu

Don't sync horovod vars every epoch (slow in some situation).

parent 8156810d
......@@ -35,19 +35,31 @@ class COCODetection(DatasetSplit):
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"] # noqa
cfg.DATA.CLASS_NAMES = ["BG"] + class_names
def __init__(self, basedir, name):
def __init__(self, basedir, split):
"""
Args:
basedir (str): root to the dataset
name (str): the name of the split, e.g. "train2017"
basedir (str): root of the dataset which contains the subdirectories for each split and annotations
split (str): the name of the split, e.g. "train2017".
The split has to match an annotation file in "annotations/" and a directory of images.
Examples:
For a directory of this structure:
DIR/
annotations/
instances_XX.json
instances_YY.json
XX/
YY/
use `COCODetection(DIR, 'XX')` and `COCODetection(DIR, 'YY')`
"""
basedir = os.path.expanduser(basedir)
self.name = name
self._imgdir = os.path.realpath(os.path.join(
basedir, self._INSTANCE_TO_BASEDIR.get(name, name)))
basedir, self._INSTANCE_TO_BASEDIR.get(split, split)))
assert os.path.isdir(self._imgdir), "{} is not a directory!".format(self._imgdir)
annotation_file = os.path.join(
basedir, 'annotations/instances_{}.json'.format(name))
basedir, 'annotations/instances_{}.json'.format(split))
assert os.path.isfile(annotation_file), annotation_file
from pycocotools.coco import COCO
......
......@@ -144,7 +144,7 @@ if __name__ == '__main__':
callbacks=[
ModelSaver(),
StatMonitorParamSetter(
'learning_rate', 'measure', lambda x: x * 0.5, 0, 10)
'learning_rate', 'losses/measure', lambda x: x * 0.5, 0, 10)
],
session_init=SaverRestore(args.load) if args.load else None,
steps_per_epoch=500, max_epoch=400)
......@@ -325,3 +325,15 @@ class CallbackFactory(Callback):
def _after_train(self):
if self._cb_after_train:
self._cb_after_train(self)
def __str__(self):
strs = []
if self._cb_setup_graph is not None:
strs.append("setup_graph=" + str(self._cb_setup_graph))
if self._cb_before_train is not None:
strs.append("before_train=" + str(self._cb_before_train))
if self._cb_trigger is not None:
strs.append("trigger=" + str(self._cb_trigger))
if self._cb_after_train is not None:
strs.append("after_train=" + str(self._cb_after_train))
return "CallbackFactory({})".format(', '.join(strs))
......@@ -424,7 +424,8 @@ class HorovodTrainer(SingleCostTrainer):
# the op will be created later in initialize()
self.trainer._broadcast_op.run()
cb = CallbackFactory(trigger=broadcast).set_chief_only(False)
# TODO provide a way to sync manually
cb = CallbackFactory(before_train=broadcast).set_chief_only(False)
return [cb]
@HIDE_DOC
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment