Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
99d99e7d
Commit
99d99e7d
authored
Nov 09, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[FasterRCNN] split some methods from build_graph
parent
c878fb1f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
8 deletions
+17
-8
examples/FasterRCNN/basemodel.py
examples/FasterRCNN/basemodel.py
+2
-2
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+15
-6
No files found.
examples/FasterRCNN/basemodel.py
View file @
99d99e7d
...
...
@@ -39,7 +39,7 @@ def resnet_shortcut(l, n_out, stride, nl=tf.identity):
data_format
=
get_arg_scope
()[
'Conv2D'
][
'data_format'
]
n_in
=
l
.
get_shape
()
.
as_list
()[
1
if
data_format
==
'NCHW'
else
3
]
if
n_in
!=
n_out
:
# change dimension when channel is not the same
if
stride
==
2
and
'group3'
not
in
tf
.
get_variable_scope
()
.
name
:
if
stride
==
2
:
l
=
l
[:,
:,
:
-
1
,
:
-
1
]
return
Conv2D
(
'convshortcut'
,
l
,
n_out
,
1
,
stride
=
stride
,
padding
=
'VALID'
,
nl
=
nl
)
...
...
@@ -53,7 +53,7 @@ def resnet_shortcut(l, n_out, stride, nl=tf.identity):
def
resnet_bottleneck
(
l
,
ch_out
,
stride
):
l
,
shortcut
=
l
,
l
l
=
Conv2D
(
'conv1'
,
l
,
ch_out
,
1
,
nl
=
BNReLU
)
if
stride
==
2
and
'group3'
not
in
tf
.
get_variable_scope
()
.
name
:
if
stride
==
2
:
l
=
tf
.
pad
(
l
,
[[
0
,
0
],
[
0
,
0
],
[
0
,
1
],
[
0
,
1
]])
l
=
Conv2D
(
'conv2'
,
l
,
ch_out
,
3
,
stride
=
2
,
nl
=
BNReLU
,
padding
=
'VALID'
)
else
:
...
...
examples/FasterRCNN/train.py
View file @
99d99e7d
...
...
@@ -60,11 +60,16 @@ class Model(ModelDesc):
InputDesc
(
tf
.
int64
,
(
None
,),
'gt_labels'
),
]
def
_build_graph
(
self
,
inputs
):
is_training
=
get_current_tower_context
()
.
is_training
image
,
anchor_labels
,
anchor_boxes
,
gt_boxes
,
gt_labels
=
inputs
def
_preprocess
(
self
,
image
):
image
=
tf
.
expand_dims
(
image
,
0
)
image
=
image_preprocess
(
image
,
bgr
=
True
)
return
tf
.
transpose
(
image
,
[
0
,
3
,
1
,
2
])
def
_get_anchors
(
self
,
image
):
"""
Returns:
FSxFSxNAx4 anchors,
"""
# FSxFSxNAx4 (FS=MAX_SIZE//ANCHOR_STRIDE)
with
tf
.
name_scope
(
'anchors'
):
all_anchors
=
tf
.
constant
(
get_all_anchors
(),
name
=
'all_anchors'
,
dtype
=
tf
.
float32
)
...
...
@@ -73,11 +78,15 @@ class Model(ModelDesc):
tf
.
shape
(
image
)[
1
]
//
config
.
ANCHOR_STRIDE
,
tf
.
shape
(
image
)[
2
]
//
config
.
ANCHOR_STRIDE
,
-
1
,
-
1
]),
name
=
'fm_anchors'
)
anchor_boxes_encoded
=
encode_bbox_target
(
anchor_boxes
,
fm_anchors
)
return
fm_anchors
image
=
image_preprocess
(
image
,
bgr
=
True
)
image
=
tf
.
transpose
(
image
,
[
0
,
3
,
1
,
2
])
def
_build_graph
(
self
,
inputs
):
is_training
=
get_current_tower_context
()
.
is_training
image
,
anchor_labels
,
anchor_boxes
,
gt_boxes
,
gt_labels
=
inputs
image
=
self
.
_preprocess
(
image
)
fm_anchors
=
self
.
_get_anchors
(
image
)
anchor_boxes_encoded
=
encode_bbox_target
(
anchor_boxes
,
fm_anchors
)
# resnet50
featuremap
=
pretrained_resnet_conv4
(
image
,
[
3
,
4
,
6
])
rpn_label_logits
,
rpn_box_logits
=
rpn_head
(
featuremap
,
1024
,
config
.
NR_ANCHOR
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment