Commit e027bc2a authored by Yuxin Wu's avatar Yuxin Wu

[MaskRCNN] Circumvent TF bug in EvalCallback

parent 96f8f96e
......@@ -407,8 +407,11 @@ class EvalCallback(Callback):
def _setup_graph(self):
num_gpu = cfg.TRAIN.NUM_GPUS
if cfg.TRAINER == 'replicated':
# TF bug in version 1.11, 1.12: https://github.com/tensorflow/tensorflow/issues/22750
buggy_tf = get_tf_version_tuple() in [(1, 11), (1, 12)]
# Use two predictor threads per GPU to get better throughput
self.num_predictor = num_gpu * 2
self.num_predictor = num_gpu if buggy_tf else num_gpu * 2
self.predictors = [self._build_coco_predictor(k % num_gpu) for k in range(self.num_predictor)]
self.dataflows = [get_eval_dataflow(shard=k, num_shards=self.num_predictor)
for k in range(self.num_predictor)]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment