Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e027bc2a
Commit
e027bc2a
authored
Nov 16, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[MaskRCNN] Circumvent TF bug in EvalCallback
parent
96f8f96e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
1 deletion
+4
-1
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+4
-1
No files found.
examples/FasterRCNN/train.py
View file @
e027bc2a
...
@@ -407,8 +407,11 @@ class EvalCallback(Callback):
...
@@ -407,8 +407,11 @@ class EvalCallback(Callback):
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
num_gpu
=
cfg
.
TRAIN
.
NUM_GPUS
num_gpu
=
cfg
.
TRAIN
.
NUM_GPUS
if
cfg
.
TRAINER
==
'replicated'
:
if
cfg
.
TRAINER
==
'replicated'
:
# TF bug in version 1.11, 1.12: https://github.com/tensorflow/tensorflow/issues/22750
buggy_tf
=
get_tf_version_tuple
()
in
[(
1
,
11
),
(
1
,
12
)]
# Use two predictor threads per GPU to get better throughput
# Use two predictor threads per GPU to get better throughput
self
.
num_predictor
=
num_gpu
*
2
self
.
num_predictor
=
num_gpu
if
buggy_tf
else
num_gpu
*
2
self
.
predictors
=
[
self
.
_build_coco_predictor
(
k
%
num_gpu
)
for
k
in
range
(
self
.
num_predictor
)]
self
.
predictors
=
[
self
.
_build_coco_predictor
(
k
%
num_gpu
)
for
k
in
range
(
self
.
num_predictor
)]
self
.
dataflows
=
[
get_eval_dataflow
(
shard
=
k
,
num_shards
=
self
.
num_predictor
)
self
.
dataflows
=
[
get_eval_dataflow
(
shard
=
k
,
num_shards
=
self
.
num_predictor
)
for
k
in
range
(
self
.
num_predictor
)]
for
k
in
range
(
self
.
num_predictor
)]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment