Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
35599506
Commit
35599506
authored
Sep 20, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[MaskRCNN] fallback to tf.gather_nd for older TF versions
parent
2aa760b1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
5 deletions
+17
-5
examples/FasterRCNN/modeling/model_frcnn.py
examples/FasterRCNN/modeling/model_frcnn.py
+7
-2
examples/FasterRCNN/modeling/model_mrcnn.py
examples/FasterRCNN/modeling/model_mrcnn.py
+8
-1
tensorpack/tfutils/varmanip.py
tensorpack/tfutils/varmanip.py
+2
-2
No files found.
examples/FasterRCNN/modeling/model_frcnn.py
View file @
35599506
...
@@ -151,8 +151,13 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
...
@@ -151,8 +151,13 @@ def fastrcnn_losses(labels, label_logits, fg_boxes, fg_box_logits):
num_fg
=
tf
.
size
(
fg_inds
,
out_type
=
tf
.
int64
)
num_fg
=
tf
.
size
(
fg_inds
,
out_type
=
tf
.
int64
)
empty_fg
=
tf
.
equal
(
num_fg
,
0
)
empty_fg
=
tf
.
equal
(
num_fg
,
0
)
if
int
(
fg_box_logits
.
shape
[
1
])
>
1
:
if
int
(
fg_box_logits
.
shape
[
1
])
>
1
:
fg_box_logits
=
tf
.
batch_gather
(
fg_box_logits
,
tf
.
expand_dims
(
fg_labels
,
axis
=
1
))
if
get_tf_version_tuple
()
>=
(
1
,
14
):
fg_box_logits
=
tf
.
reshape
(
fg_box_logits
,
[
-
1
,
4
])
fg_labels
=
tf
.
expand_dims
(
fg_labels
,
axis
=
1
)
# nfg x 1
fg_box_logits
=
tf
.
gather
(
fg_box_logits
,
fg_labels
,
batch_dims
=
1
)
else
:
indices
=
tf
.
stack
([
tf
.
range
(
num_fg
),
fg_labels
],
axis
=
1
)
# nfgx2
fg_box_logits
=
tf
.
gather_nd
(
fg_box_logits
,
indices
)
fg_box_logits
=
tf
.
reshape
(
fg_box_logits
,
[
-
1
,
4
])
# nfg x 4
with
tf
.
name_scope
(
'label_metrics'
),
tf
.
device
(
'/cpu:0'
):
with
tf
.
name_scope
(
'label_metrics'
),
tf
.
device
(
'/cpu:0'
):
prediction
=
tf
.
argmax
(
label_logits
,
axis
=
1
,
name
=
'label_prediction'
)
prediction
=
tf
.
argmax
(
label_logits
,
axis
=
1
,
name
=
'label_prediction'
)
...
...
examples/FasterRCNN/modeling/model_mrcnn.py
View file @
35599506
...
@@ -20,7 +20,14 @@ def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks):
...
@@ -20,7 +20,14 @@ def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks):
fg_labels: #fg, in 1~#class, int64
fg_labels: #fg, in 1~#class, int64
fg_target_masks: #fgxhxw, float32
fg_target_masks: #fgxhxw, float32
"""
"""
mask_logits
=
tf
.
batch_gather
(
mask_logits
,
tf
.
reshape
(
fg_labels
,
[
-
1
,
1
])
-
1
)
if
get_tf_version_tuple
()
>=
(
1
,
14
):
mask_logits
=
tf
.
gather
(
mask_logits
,
tf
.
reshape
(
fg_labels
-
1
,
[
-
1
,
1
]),
batch_dims
=
1
)
else
:
indices
=
tf
.
stack
([
tf
.
range
(
tf
.
size
(
fg_labels
,
out_type
=
tf
.
int64
)),
fg_labels
-
1
],
axis
=
1
)
# #fgx2
mask_logits
=
tf
.
gather_nd
(
mask_logits
,
indices
)
# #fg x h x w
mask_logits
=
tf
.
squeeze
(
mask_logits
,
axis
=
1
)
mask_logits
=
tf
.
squeeze
(
mask_logits
,
axis
=
1
)
mask_probs
=
tf
.
sigmoid
(
mask_logits
)
mask_probs
=
tf
.
sigmoid
(
mask_logits
)
...
...
tensorpack/tfutils/varmanip.py
View file @
35599506
...
@@ -76,12 +76,12 @@ class SessionUpdate(object):
...
@@ -76,12 +76,12 @@ class SessionUpdate(object):
if
np
.
prod
(
varshape
)
!=
np
.
prod
(
value
.
shape
):
if
np
.
prod
(
varshape
)
!=
np
.
prod
(
value
.
shape
):
if
ignore_mismatch
:
if
ignore_mismatch
:
logger
.
warn
(
logger
.
warn
(
"Cannot load a
tensor of shape {} into the
variable '{}' whose shape is {}."
.
format
(
"Cannot load a
n array of shape {} into
variable '{}' whose shape is {}."
.
format
(
value
.
shape
,
name
,
varshape
))
value
.
shape
,
name
,
varshape
))
return
None
return
None
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"Trying to load a
tensor of shape {} into the
variable '{}' whose shape is {}."
.
format
(
"Trying to load a
n array of shape {} into
variable '{}' whose shape is {}."
.
format
(
value
.
shape
,
name
,
varshape
))
value
.
shape
,
name
,
varshape
))
# TODO only allow reshape when shape different by empty axis
# TODO only allow reshape when shape different by empty axis
logger
.
warn
(
"The tensor is reshaped from {} to {} when assigned to '{}'"
.
format
(
logger
.
warn
(
"The tensor is reshaped from {} to {} when assigned to '{}'"
.
format
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment