Commit 49e04ffa
Authored Aug 22, 2018 by Yuxin Wu
[MaskRCNN] use dict as input
Parent: 787be08e
Showing 4 changed files with 51 additions and 36 deletions (+51 −36)
examples/FasterRCNN/data.py              +20 −19
examples/FasterRCNN/train.py             +15 −16
tensorpack/graph_builder/utils.py        +14 −1
tensorpack/input_source/input_source.py  +2 −0
examples/FasterRCNN/data.py

@@ -4,7 +4,6 @@
 import cv2
 import numpy as np
 import copy
-import itertools
 from tensorpack.utils.argtools import memoized, log_once
 from tensorpack.dataflow import (
@@ -282,11 +281,11 @@ def get_train_dataflow():
         If MODE_MASK, gt_masks: (N, h, w)
     """
-    imgs = COCODetection.load_many(
+    roidbs = COCODetection.load_many(
         cfg.DATA.BASEDIR, cfg.DATA.TRAIN, add_gt=True, add_mask=cfg.MODE_MASK)
     """
     To train on your own data, change this to your loader.
-    Produce "imgs" as a list of dict, in the dict the following keys are needed for training:
+    Produce "roidbs" as a list of dict, in the dict the following keys are needed for training:
     height, width: integer
     file_name: str, full path to the image
     boxes: numpy array of kx4 floats
@@ -304,19 +303,19 @@ def get_train_dataflow():
     # Valid training images should have at least one fg box.
     # But this filter shall not be applied for testing.
-    num = len(imgs)
-    imgs = list(filter(lambda img: len(img['boxes'][img['is_crowd'] == 0]) > 0, imgs))
+    num = len(roidbs)
+    roidbs = list(filter(lambda img: len(img['boxes'][img['is_crowd'] == 0]) > 0, roidbs))
     logger.info("Filtered {} images which contain no non-crowd groudtruth boxes. Total #images for training: {}".format(
-        num - len(imgs), len(imgs)))
+        num - len(roidbs), len(roidbs)))

-    ds = DataFromList(imgs, shuffle=True)
+    ds = DataFromList(roidbs, shuffle=True)

     aug = imgaug.AugmentorList(
         [CustomResize(cfg.PREPROC.SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE),
          imgaug.Flip(horiz=True)])

-    def preprocess(img):
-        fname, boxes, klass, is_crowd = img['file_name'], img['boxes'], img['class'], img['is_crowd']
+    def preprocess(roidb):
+        fname, boxes, klass, is_crowd = roidb['file_name'], roidb['boxes'], roidb['class'], roidb['is_crowd']
         boxes = np.copy(boxes)
         im = cv2.imread(fname, cv2.IMREAD_COLOR)
         assert im is not None, fname
@@ -331,29 +330,31 @@ def get_train_dataflow():
         boxes = point8_to_box(points)
         assert np.min(np_area(boxes)) > 0, "Some boxes have zero area!"

+        ret = {'image': im}
         # rpn anchor:
         try:
             if cfg.MODE_FPN:
                 multilevel_anchor_inputs = get_multilevel_rpn_anchor_input(im, boxes, is_crowd)
-                anchor_inputs = itertools.chain.from_iterable(multilevel_anchor_inputs)
+                for i, (anchor_labels, anchor_boxes) in enumerate(multilevel_anchor_inputs):
+                    ret['anchor_labels_lvl{}'.format(i + 2)] = anchor_labels
+                    ret['anchor_boxes_lvl{}'.format(i + 2)] = anchor_boxes
             else:
                 # anchor_labels, anchor_boxes
-                anchor_inputs = get_rpn_anchor_input(im, boxes, is_crowd)
-                assert len(anchor_inputs) == 2
+                ret['anchor_labels'], ret['anchor_boxes'] = get_rpn_anchor_input(im, boxes, is_crowd)

             boxes = boxes[is_crowd == 0]    # skip crowd boxes in training target
             klass = klass[is_crowd == 0]
+            ret['gt_boxes'] = boxes
+            ret['gt_labels'] = klass
             if not len(boxes):
                 raise MalformedData("No valid gt_boxes!")
         except MalformedData as e:
             log_once("Input {} is filtered for training: {}".format(fname, str(e)), 'warn')
             return None

-        ret = [im] + list(anchor_inputs) + [boxes, klass]
-
         if cfg.MODE_MASK:
             # augmentation will modify the polys in-place
-            segmentation = copy.deepcopy(img['segmentation'])
+            segmentation = copy.deepcopy(roidb['segmentation'])
             segmentation = [segmentation[k] for k in range(len(segmentation)) if not is_crowd[k]]
             assert len(segmentation) == len(boxes)
@@ -364,7 +365,7 @@ def get_train_dataflow():
                 polys = [aug.augment_coords(p, params) for p in polys]
                 masks.append(segmentation_to_mask(polys, im.shape[0], im.shape[1]))
             masks = np.asarray(masks, dtype='uint8')    # values in {0, 1}
-            ret.append(masks)
+            ret['gt_masks'] = masks

             # from viz import draw_annotation, draw_mask
             # viz = draw_annotation(im, boxes, klass)
@@ -386,13 +387,13 @@ def get_eval_dataflow(shard=0, num_shards=1):
     Args:
         shard, num_shards: to get subset of evaluation data
     """
-    imgs = COCODetection.load_many(cfg.DATA.BASEDIR, cfg.DATA.VAL, add_gt=False)
-    num_imgs = len(imgs)
+    roidbs = COCODetection.load_many(cfg.DATA.BASEDIR, cfg.DATA.VAL, add_gt=False)
+    num_imgs = len(roidbs)
     img_per_shard = num_imgs // num_shards
     img_range = (shard * img_per_shard, (shard + 1) * img_per_shard if shard + 1 < num_shards else num_imgs)

     # no filter for training
-    ds = DataFromListOfDict(imgs[img_range[0]: img_range[1]], ['file_name', 'id'])
+    ds = DataFromListOfDict(roidbs[img_range[0]: img_range[1]], ['file_name', 'id'])

     def f(fname):
         im = cv2.imread(fname, cv2.IMREAD_COLOR)
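After this change, each datapoint the training dataflow yields is a dict keyed by input name, rather than a positional list whose layout depended on MODE_FPN and MODE_MASK. A minimal sketch (not part of the commit; the keys come from the diff above, but all shapes and dtypes are hypothetical) of what one datapoint from preprocess() could look like in the non-FPN case:

# Sketch of a dict datapoint from preprocess(), non-FPN case.
# Keys match the diff above; shapes/dtypes are made up for illustration.
import numpy as np

datapoint = {
    'image': np.zeros((600, 800, 3), dtype='float32'),          # augmented image, HWC
    'anchor_labels': np.zeros((38, 50, 15), dtype='int32'),     # per-anchor RPN labels
    'anchor_boxes': np.zeros((38, 50, 15, 4), dtype='float32'), # per-anchor target boxes
    'gt_boxes': np.zeros((7, 4), dtype='float32'),              # crowd boxes removed
    'gt_labels': np.zeros((7,), dtype='int64'),
    'gt_masks': np.zeros((7, 600, 800), dtype='uint8'),         # only when cfg.MODE_MASK; values in {0, 1}
}

In the FPN case the flat 'anchor_labels'/'anchor_boxes' keys are replaced by the per-level 'anchor_labels_lvl{2..6}' and 'anchor_boxes_lvl{2..6}' keys filled in by the loop above.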
examples/FasterRCNN/train.py

@@ -160,17 +160,14 @@ class ResNetC4Model(DetectionModel):
         return ret

     def build_graph(self, *inputs):
+        inputs = dict(zip(self.input_names, inputs))
         is_training = get_current_tower_context().is_training
-        if cfg.MODE_MASK:
-            image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks = inputs
-        else:
-            image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
-        image = self.preprocess(image)     # 1CHW
+        image = self.preprocess(inputs['image'])     # 1CHW

         featuremap = resnet_c4_backbone(image, cfg.BACKBONE.RESNET_NUM_BLOCK[:3])
         rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, cfg.RPN.HEAD_DIM, cfg.RPN.NUM_ANCHOR)

-        anchors = RPNAnchors(get_all_anchors(), anchor_labels, anchor_boxes)
+        anchors = RPNAnchors(get_all_anchors(), inputs['anchor_labels'], inputs['anchor_boxes'])
         anchors = anchors.narrow_to(featuremap)

         image_shape2d = tf.shape(image)[2:]     # h,w
@@ -182,6 +179,7 @@ class ResNetC4Model(DetectionModel):
             cfg.RPN.TRAIN_PRE_NMS_TOPK if is_training else cfg.RPN.TEST_PRE_NMS_TOPK,
             cfg.RPN.TRAIN_POST_NMS_TOPK if is_training else cfg.RPN.TEST_POST_NMS_TOPK)

+        gt_boxes, gt_labels = inputs['gt_boxes'], inputs['gt_labels']
         if is_training:
             # sample proposal boxes in training
             rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
@@ -224,7 +222,7 @@ class ResNetC4Model(DetectionModel):
                     'maskrcnn', fg_feature, cfg.DATA.NUM_CATEGORY, num_convs=0)   # #fg x #cat x 14x14
                 target_masks_for_fg = crop_and_resize(
-                    tf.expand_dims(gt_masks, 1),
+                    tf.expand_dims(inputs['gt_masks'], 1),
                     fg_sampled_boxes,
                     fg_inds_wrt_gt, 14,
                     pad_border=False)  # nfg x 1x14x14
@@ -293,18 +291,18 @@ class ResNetFPNModel(DetectionModel):
             anchors[i] = anchors[i].narrow_to(p23456[i])

     def build_graph(self, *inputs):
+        inputs = dict(zip(self.input_names, inputs))
         num_fpn_level = len(cfg.FPN.ANCHOR_STRIDES)
         assert len(cfg.RPN.ANCHOR_SIZES) == num_fpn_level
         is_training = get_current_tower_context().is_training
-        image = inputs[0]
-        input_anchors = inputs[1: 1 + 2 * num_fpn_level]
-        multilevel_anchors = [RPNAnchors(*args) for args in zip(
-            get_all_anchors_fpn(),
-            input_anchors[0::2],
-            input_anchors[1::2])]
-        gt_boxes, gt_labels = inputs[11], inputs[12]
-        if cfg.MODE_MASK:
-            gt_masks = inputs[-1]

-        image = self.preprocess(image)     # 1CHW
+        all_anchors_fpn = get_all_anchors_fpn()
+        multilevel_anchors = [RPNAnchors(
+            all_anchors_fpn[i],
+            inputs['anchor_labels_lvl{}'.format(i + 2)],
+            inputs['anchor_boxes_lvl{}'.format(i + 2)]) for i in range(len(all_anchors_fpn))]
+
+        image = self.preprocess(inputs['image'])     # 1CHW
         image_shape2d = tf.shape(image)[2:]     # h,w
         c2345 = resnet_fpn_backbone(image, cfg.BACKBONE.RESNET_NUM_BLOCK)
@@ -321,6 +319,7 @@ class ResNetFPNModel(DetectionModel):
             multilevel_anchors, multilevel_label_logits,
             multilevel_box_logits, image_shape2d)

+        gt_boxes, gt_labels = inputs['gt_boxes'], inputs['gt_labels']
         if is_training:
             rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
                 proposal_boxes, gt_boxes, gt_labels)
@@ -361,7 +360,7 @@ class ResNetFPNModel(DetectionModel):
                     'maskrcnn', roi_feature_maskrcnn, cfg.DATA.NUM_CATEGORY)   # #fg x #cat x 28 x 28
                 target_masks_for_fg = crop_and_resize(
-                    tf.expand_dims(gt_masks, 1),
+                    tf.expand_dims(inputs['gt_masks'], 1),
                     fg_sampled_boxes,
                     fg_inds_wrt_gt, 28,
                     pad_border=False)  # fg x 1x28x28
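The core of this file's change is the new first line of each build_graph: zipping the model's declared input names with the positional tensors turns the argument list into a name-to-tensor dict, so the body no longer unpacks by position (and no longer needs the brittle index arithmetic that depended on MODE_MASK and the number of FPN levels). A minimal, self-contained sketch of the idiom, with hypothetical names standing in for the real tensors:

# Sketch of the dict(zip(...)) idiom used in build_graph().
input_names = ['image', 'anchor_labels', 'anchor_boxes', 'gt_boxes', 'gt_labels']

def build_graph(*inputs):
    # Positional arguments arrive in the same order as input_names,
    # so zipping yields a name -> tensor mapping.
    inputs = dict(zip(input_names, inputs))
    return inputs['gt_boxes'], inputs['gt_labels']

# Strings stand in for the real tensors here:
print(build_graph('img_t', 'albl_t', 'abox_t', 'gtb_t', 'gtl_t'))
# -> ('gtb_t', 'gtl_t')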
tensorpack/graph_builder/utils.py

@@ -41,6 +41,10 @@ def _replace_global_by_local(kwargs):
 @contextmanager
 def override_to_local_variable(enable=True):
+    """
+    Returns:
+        a context where all variables will be created as local.
+    """
     if enable:
         def custom_getter(getter, name, *args, **kwargs):
@@ -55,7 +59,16 @@ def override_to_local_variable(enable=True):
 # https://github.com/tensorflow/benchmarks/blob/48cbef14a592e02a14beee8e9aef3ad22cadaed1/scripts/tf_cnn_benchmarks/variable_mgr_util.py#L192-L218
 class LeastLoadedDeviceSetter(object):
-    """ Helper class to assign variables on the least loaded ps-device."""
+    """
+    Helper class to assign variables on the least loaded ps-device.
+
+    Usage:
+
+    .. code-block:: python
+
+        with tf.device(LeastLoadedDeviceSetter(...)):
+            ...
+    """
     def __init__(self, worker_device, ps_devices):
         """
         Args:
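The additions here are docstrings for existing behavior. For context, override_to_local_variable is built on TF 1.x's custom_getter mechanism: every tf.get_variable() call inside the scope is routed through the getter, which can rewrite the creation kwargs first. A rough, self-contained sketch of that pattern (the kwarg rewrite below is a simplified stand-in for the real _replace_global_by_local helper):

# Rough sketch (TF 1.x) of the custom_getter pattern behind
# override_to_local_variable. The collections rewrite is simplified.
from contextlib import contextmanager
import tensorflow as tf

@contextmanager
def local_variables_sketch():
    def custom_getter(getter, name, *args, **kwargs):
        # Simplified: create the variable only in LOCAL_VARIABLES.
        kwargs['collections'] = [tf.GraphKeys.LOCAL_VARIABLES]
        return getter(name, *args, **kwargs)

    with tf.variable_scope(tf.get_variable_scope(), custom_getter=custom_getter):
        yield

# Usage, in the spirit of the docstring added above:
#     with local_variables_sketch():
#         v = tf.get_variable('v', shape=[10])   # lands in LOCAL_VARIABLES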
tensorpack/input_source/input_source.py

@@ -46,6 +46,8 @@ def _make_feeds(placeholders, datapoint):
     elif isinstance(datapoint, dict):
         ret = {p: datapoint[p.op.name] for p in placeholders}
         return ret
+    else:
+        raise TypeError("Got a datapoint of type {}!".format(type(datapoint)))


 class PlaceholderInput(InputSource):
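With the added else branch, _make_feeds now fails loudly on unsupported datapoint types instead of silently returning nothing. The dict branch it guards is what lets the dict-style dataflow above feed the graph: each placeholder is matched to the datapoint entry whose key equals the placeholder's op name. A small sketch (TF 1.x, hypothetical placeholders and dummy data):

# Sketch of the dict branch of _make_feeds(): datapoint keys must match
# placeholder op names.
import numpy as np
import tensorflow as tf

image = tf.placeholder(tf.float32, shape=[None, None, 3], name='image')
gt_labels = tf.placeholder(tf.int64, shape=[None], name='gt_labels')
placeholders = [image, gt_labels]

datapoint = {
    'image': np.zeros((600, 800, 3), dtype='float32'),
    'gt_labels': np.zeros((7,), dtype='int64'),
}
feeds = {p: datapoint[p.op.name] for p in placeholders}
# feeds can be passed directly as sess.run(fetches, feed_dict=feeds).
# A list/tuple datapoint is still zipped positionally; any other type
# now raises the TypeError added in this commit.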