Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
754e17fc
Commit
754e17fc
authored
Jan 05, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[MaskRCNN] some renames to avoid the name of "COCO"
parent
cc63dee7
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
81 additions
and
79 deletions
+81
-79
examples/FasterRCNN/NOTES.md
examples/FasterRCNN/NOTES.md
+7
-6
examples/FasterRCNN/config.py
examples/FasterRCNN/config.py
+2
-2
examples/FasterRCNN/data.py
examples/FasterRCNN/data.py
+26
-1
examples/FasterRCNN/dataset.py
examples/FasterRCNN/dataset.py
+13
-32
examples/FasterRCNN/eval.py
examples/FasterRCNN/eval.py
+14
-13
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+19
-25
No files found.
examples/FasterRCNN/NOTES.md
View file @
754e17fc
### File Structure
This is a minimal implementation that simply contains these files:
+
coco.py: load COCO data
+
data.py: prepare data for training
+
dataset.py: load and evaluate COCO dataset
+
data.py: prepare data for training
& inference
+
common.py: common data preparation utilities
+
basemodel.py: implement backbones
+
model_box.py: implement box-related symbolic functions
+
model_{fpn,rpn,frcnn,mrcnn,cascade}.py: implement FPN,RPN,Fast-/Mask-/Cascade-RCNN models.
+
train.py: main
training
script
+
train.py: main
entry
script
+
utils/: third-party helper functions
+
eval.py: evaluation utilities
+
viz.py: visualization utilities
...
...
@@ -16,9 +16,10 @@ This is a minimal implementation that simply contains these files:
Data:
1.
It's easy to train on your own data. Just replace
`COCODetection.load_many`
in
`data.py`
by your own loader.
Also remember to change
`DATA.NUM_CATEGORY`
and
`DATA.CLASS_NAMES`
in the config.
The current evaluation code is also COCO-specific, and you may need to change it to use your data and metrics.
1.
It's easy to train on your own data.
If your data is not in COCO format, you can just rewrite all the methods of
`DetectionDataset`
following its documents in
`dataset.py`
.
You'll implement the logic to load your dataset and evaluate predictions.
2.
You can easily add more augmentations such as rotation, but be careful how a box should be
augmented. The code now will always use the minimal axis-aligned bounding box of the 4 corners,
...
...
examples/FasterRCNN/config.py
View file @
754e17fc
...
...
@@ -80,14 +80,14 @@ _C.MODE_MASK = True # FasterRCNN or MaskRCNN
_C
.
MODE_FPN
=
False
# dataset -----------------------
_C
.
DATA
.
BASEDIR
=
'/path/to/your/
COCO
/DIR'
_C
.
DATA
.
BASEDIR
=
'/path/to/your/
DATA
/DIR'
# All TRAIN dataset will be concatenated for training.
_C
.
DATA
.
TRAIN
=
[
'train2014'
,
'valminusminival2014'
]
# i.e. trainval35k, AKA train2017
# Each VAL dataset will be evaluated separately (instead of concatenated)
_C
.
DATA
.
VAL
=
(
'minival2014'
,
)
# AKA val2017
# This two config will be populated later by the dataset loader:
_C
.
DATA
.
NUM_CATEGORY
=
0
# without the background class (e.g., 80 for COCO)
_C
.
DATA
.
CLASS_NAMES
=
[]
# NUM_CLASS (NUM_CATEGORY+1) strings, the first is "BG".
# For COCO, this list will be populated later by the COCO data loader.
# basemodel ----------------------
_C
.
BACKBONE
.
WEIGHTS
=
''
# /path/to/weights.npz
...
...
examples/FasterRCNN/data.py
View file @
754e17fc
...
...
@@ -4,6 +4,8 @@
import
copy
import
numpy
as
np
import
cv2
from
tabulate
import
tabulate
from
termcolor
import
colored
from
tensorpack.dataflow
import
(
DataFromList
,
MapDataComponent
,
MultiProcessMapDataZMQ
,
MultiThreadMapData
,
TestDataSpeed
,
imgaug
)
...
...
@@ -13,7 +15,7 @@ from tensorpack.utils.argtools import log_once, memoized
from
common
import
(
CustomResize
,
DataFromListOfDict
,
box_to_point8
,
filter_boxes_inside_shape
,
point8_to_box
,
segmentation_to_mask
)
from
config
import
config
as
cfg
from
coco
import
DetectionDataset
from
dataset
import
DetectionDataset
from
utils.generate_anchors
import
generate_anchors
from
utils.np_box_ops
import
area
as
np_area
from
utils.np_box_ops
import
ioa
as
np_ioa
...
...
@@ -46,6 +48,28 @@ class MalformedData(BaseException):
pass
def
print_class_histogram
(
roidbs
):
"""
Args:
roidbs (list[dict]): the same format as the output of `load_training_roidbs`.
"""
dataset
=
DetectionDataset
()
hist_bins
=
np
.
arange
(
dataset
.
num_classes
+
1
)
# Histogram of ground-truth objects
gt_hist
=
np
.
zeros
((
dataset
.
num_classes
,),
dtype
=
np
.
int
)
for
entry
in
roidbs
:
# filter crowd?
gt_inds
=
np
.
where
(
(
entry
[
'class'
]
>
0
)
&
(
entry
[
'is_crowd'
]
==
0
))[
0
]
gt_classes
=
entry
[
'class'
][
gt_inds
]
gt_hist
+=
np
.
histogram
(
gt_classes
,
bins
=
hist_bins
)[
0
]
data
=
[[
dataset
.
class_names
[
i
],
v
]
for
i
,
v
in
enumerate
(
gt_hist
)]
data
.
append
([
'total'
,
sum
([
x
[
1
]
for
x
in
data
])])
table
=
tabulate
(
data
,
headers
=
[
'class'
,
'#box'
],
tablefmt
=
'pipe'
)
logger
.
info
(
"Ground-Truth Boxes:
\n
"
+
colored
(
table
,
'cyan'
))
@
memoized
def
get_all_anchors
(
stride
=
None
,
sizes
=
None
):
"""
...
...
@@ -281,6 +305,7 @@ def get_train_dataflow():
"""
roidbs
=
DetectionDataset
()
.
load_training_roidbs
(
cfg
.
DATA
.
TRAIN
)
print_class_histogram
(
roidbs
)
# Valid training images should have at least one fg box.
# But this filter shall not be applied for testing.
...
...
examples/FasterRCNN/
coco
.py
→
examples/FasterRCNN/
dataset
.py
View file @
754e17fc
...
...
@@ -5,8 +5,6 @@ import numpy as np
import
os
import
tqdm
import
json
from
tabulate
import
tabulate
from
termcolor
import
colored
from
tensorpack.utils
import
logger
from
tensorpack.utils.argtools
import
log_once
...
...
@@ -29,6 +27,9 @@ class COCODetection(object):
Mapping from the incontinuous COCO category id to an id in [1, #category]
"""
class_names
=
[
"person"
,
"bicycle"
,
"car"
,
"motorcycle"
,
"airplane"
,
"bus"
,
"train"
,
"truck"
,
"boat"
,
"traffic light"
,
"fire hydrant"
,
"stop sign"
,
"parking meter"
,
"bench"
,
"bird"
,
"cat"
,
"dog"
,
"horse"
,
"sheep"
,
"cow"
,
"elephant"
,
"bear"
,
"zebra"
,
"giraffe"
,
"backpack"
,
"umbrella"
,
"handbag"
,
"tie"
,
"suitcase"
,
"frisbee"
,
"skis"
,
"snowboard"
,
"sports ball"
,
"kite"
,
"baseball bat"
,
"baseball glove"
,
"skateboard"
,
"surfboard"
,
"tennis racket"
,
"bottle"
,
"wine glass"
,
"cup"
,
"fork"
,
"knife"
,
"spoon"
,
"bowl"
,
"banana"
,
"apple"
,
"sandwich"
,
"orange"
,
"broccoli"
,
"carrot"
,
"hot dog"
,
"pizza"
,
"donut"
,
"cake"
,
"chair"
,
"couch"
,
"potted plant"
,
"bed"
,
"dining table"
,
"toilet"
,
"tv"
,
"laptop"
,
"mouse"
,
"remote"
,
"keyboard"
,
"cell phone"
,
"microwave"
,
"oven"
,
"toaster"
,
"sink"
,
"refrigerator"
,
"book"
,
"clock"
,
"vase"
,
"scissors"
,
"teddy bear"
,
"hair drier"
,
"toothbrush"
]
# noqa
def
__init__
(
self
,
basedir
,
name
):
self
.
name
=
name
self
.
_imgdir
=
os
.
path
.
realpath
(
os
.
path
.
join
(
...
...
@@ -182,15 +183,9 @@ class COCODetection(object):
class
DetectionDataset
(
object
):
"""
A singleton to load datasets, evaluate results, and provide metadata.
"""
_instance
=
None
def
__new__
(
cls
):
if
not
isinstance
(
cls
.
_instance
,
cls
):
cls
.
_instance
=
object
.
__new__
(
cls
)
return
cls
.
_instance
To use your own dataset that's not in COCO format, rewrite all methods of this class.
"""
def
__init__
(
self
):
"""
This function is responsible for setting the dataset-specific
...
...
@@ -198,8 +193,7 @@ class DetectionDataset(object):
"""
self
.
num_category
=
cfg
.
DATA
.
NUM_CATEGORY
=
80
self
.
num_classes
=
self
.
num_category
+
1
self
.
class_names
=
cfg
.
DATA
.
CLASS_NAMES
=
[
"BG"
,
"person"
,
"bicycle"
,
"car"
,
"motorcycle"
,
"airplane"
,
"bus"
,
"train"
,
"truck"
,
"boat"
,
"traffic light"
,
"fire hydrant"
,
"stop sign"
,
"parking meter"
,
"bench"
,
"bird"
,
"cat"
,
"dog"
,
"horse"
,
"sheep"
,
"cow"
,
"elephant"
,
"bear"
,
"zebra"
,
"giraffe"
,
"backpack"
,
"umbrella"
,
"handbag"
,
"tie"
,
"suitcase"
,
"frisbee"
,
"skis"
,
"snowboard"
,
"sports ball"
,
"kite"
,
"baseball bat"
,
"baseball glove"
,
"skateboard"
,
"surfboard"
,
"tennis racket"
,
"bottle"
,
"wine glass"
,
"cup"
,
"fork"
,
"knife"
,
"spoon"
,
"bowl"
,
"banana"
,
"apple"
,
"sandwich"
,
"orange"
,
"broccoli"
,
"carrot"
,
"hot dog"
,
"pizza"
,
"donut"
,
"cake"
,
"chair"
,
"couch"
,
"potted plant"
,
"bed"
,
"dining table"
,
"toilet"
,
"tv"
,
"laptop"
,
"mouse"
,
"remote"
,
"keyboard"
,
"cell phone"
,
"microwave"
,
"oven"
,
"toaster"
,
"sink"
,
"refrigerator"
,
"book"
,
"clock"
,
"vase"
,
"scissors"
,
"teddy bear"
,
"hair drier"
,
"toothbrush"
]
# noqa
self
.
class_names
=
cfg
.
DATA
.
CLASS_NAMES
=
[
"BG"
]
+
COCODetection
.
class_names
assert
len
(
self
.
class_names
)
==
self
.
num_classes
def
load_training_roidbs
(
self
,
names
):
...
...
@@ -284,29 +278,16 @@ class DetectionDataset(object):
else
:
return
{}
def
print_class_histogram
(
self
,
roidbs
):
"""
Args:
roidbs (list[dict]): the same format as the output of `load_training_roidbs`.
"""
hist_bins
=
np
.
arange
(
self
.
num_classes
+
1
)
# Histogram of ground-truth objects
gt_hist
=
np
.
zeros
((
self
.
num_classes
,),
dtype
=
np
.
int
)
for
entry
in
roidbs
:
# filter crowd?
gt_inds
=
np
.
where
(
(
entry
[
'class'
]
>
0
)
&
(
entry
[
'is_crowd'
]
==
0
))[
0
]
gt_classes
=
entry
[
'class'
][
gt_inds
]
gt_hist
+=
np
.
histogram
(
gt_classes
,
bins
=
hist_bins
)[
0
]
data
=
[[
self
.
class_names
[
i
],
v
]
for
i
,
v
in
enumerate
(
gt_hist
)]
data
.
append
([
'total'
,
sum
([
x
[
1
]
for
x
in
data
])])
table
=
tabulate
(
data
,
headers
=
[
'class'
,
'#box'
],
tablefmt
=
'pipe'
)
logger
.
info
(
"Ground-Truth Boxes:
\n
"
+
colored
(
table
,
'cyan'
))
# code for singleton:
_instance
=
None
def
__new__
(
cls
):
if
not
isinstance
(
cls
.
_instance
,
cls
):
cls
.
_instance
=
object
.
__new__
(
cls
)
return
cls
.
_instance
if
__name__
==
'__main__'
:
c
=
COCODetection
(
cfg
.
DATA
.
BASEDIR
,
'train2014'
)
gt_boxes
=
c
.
load
(
add_gt
=
True
,
add_mask
=
True
)
print
(
"#Images:"
,
len
(
gt_boxes
))
DetectionDataset
()
.
print_class_histogram
(
gt_boxes
)
examples/FasterRCNN/eval.py
View file @
754e17fc
...
...
@@ -54,15 +54,15 @@ def paste_mask(box, mask, shape):
return
ret
def
detect_one
_image
(
img
,
model_func
):
def
predict
_image
(
img
,
model_func
):
"""
Run detection on one image, using the TF callable.
This function should handle the preprocessing internally.
Args:
img: an image
model_func: a callable from
TF model,
takes image and returns (boxes, probs, labels, [masks])
model_func: a callable from
the TF model.
It
takes image and returns (boxes, probs, labels, [masks])
Returns:
[DetectionResult]
...
...
@@ -90,11 +90,12 @@ def detect_one_image(img, model_func):
return
results
def
eval_coco
(
df
,
detect
_func
,
tqdm_bar
=
None
):
def
predict_dataflow
(
df
,
model
_func
,
tqdm_bar
=
None
):
"""
Args:
df: a DataFlow which produces (image, image_id)
detect_func: a callable, takes [image] and returns [DetectionResult]
model_func: a callable from the TF model.
It takes image and returns (boxes, probs, labels, [masks])
tqdm_bar: a tqdm object to be shared among multiple evaluation instances. If None,
will create a new one.
...
...
@@ -110,7 +111,7 @@ def eval_coco(df, detect_func, tqdm_bar=None):
tqdm_bar
=
stack
.
enter_context
(
tqdm
.
tqdm
(
total
=
df
.
size
(),
**
get_tqdm_kwargs
()))
for
img
,
img_id
in
df
:
results
=
detect_func
(
img
)
results
=
predict_image
(
img
,
model_func
)
for
r
in
results
:
res
=
{
'image_id'
:
img_id
,
...
...
@@ -130,24 +131,24 @@ def eval_coco(df, detect_func, tqdm_bar=None):
return
all_results
def
multithread_
eval_coco
(
dataflows
,
detect
_funcs
):
def
multithread_
predict_dataflow
(
dataflows
,
model
_funcs
):
"""
Running multiple `
eval_coco
` in multiple threads, and aggregate the results.
Running multiple `
predict_dataflow
` in multiple threads, and aggregate the results.
Args:
dataflows: a list of DataFlow to be used in :func:`
eval_coco
`
detect_funcs: a list of callable to be used in :func:`eval_coco
`
dataflows: a list of DataFlow to be used in :func:`
predict_dataflow
`
model_funcs: a list of callable to be used in :func:`predict_dataflow
`
Returns:
list of dict, in the format used by
`DetectionDataset.eval_or_save_inference_results`
"""
num_worker
=
len
(
dataflows
)
assert
len
(
dataflows
)
==
len
(
detect
_funcs
)
assert
len
(
dataflows
)
==
len
(
model
_funcs
)
with
ThreadPoolExecutor
(
max_workers
=
num_worker
,
thread_name_prefix
=
'EvalWorker'
)
as
executor
,
\
tqdm
.
tqdm
(
total
=
sum
([
df
.
size
()
for
df
in
dataflows
]))
as
pbar
:
futures
=
[]
for
dataflow
,
pred
in
zip
(
dataflows
,
detect
_funcs
):
futures
.
append
(
executor
.
submit
(
eval_coco
,
dataflow
,
pred
,
pbar
))
for
dataflow
,
pred
in
zip
(
dataflows
,
model
_funcs
):
futures
.
append
(
executor
.
submit
(
predict_dataflow
,
dataflow
,
pred
,
pbar
))
all_results
=
list
(
itertools
.
chain
(
*
[
fut
.
result
()
for
fut
in
futures
]))
return
all_results
examples/FasterRCNN/train.py
View file @
754e17fc
...
...
@@ -22,11 +22,10 @@ from tensorpack.tfutils.summary import add_moving_summary
import
model_frcnn
import
model_mrcnn
from
basemodel
import
image_preprocess
,
resnet_c4_backbone
,
resnet_conv5
,
resnet_fpn_backbone
from
coco
import
DetectionDataset
from
config
import
config
as
cfg
from
config
import
finalize_configs
from
dataset
import
DetectionDataset
from
config
import
finalize_configs
,
config
as
cfg
from
data
import
get_all_anchors
,
get_all_anchors_fpn
,
get_eval_dataflow
,
get_train_dataflow
from
eval
import
DetectionResult
,
detect_one_image
,
eval_coco
,
multithread_eval_coco
from
eval
import
DetectionResult
,
predict_image
,
predict_dataflow
,
multithread_predict_dataflow
from
model_box
import
RPNAnchors
,
clip_boxes
,
crop_and_resize
,
roi_align
from
model_cascade
import
CascadeRCNNHead
from
model_fpn
import
fpn_model
,
generate_fpn_proposals
,
multilevel_roi_align
,
multilevel_rpn_losses
...
...
@@ -323,7 +322,7 @@ class ResNetFPNModel(DetectionModel):
return
[]
def
visualize
(
model
,
model_path
,
nr_visualize
=
100
,
output_dir
=
'output'
):
def
do_
visualize
(
model
,
model_path
,
nr_visualize
=
100
,
output_dir
=
'output'
):
"""
Visualize some intermediate results (proposals, raw predictions) inside the pipeline.
"""
...
...
@@ -375,31 +374,27 @@ def visualize(model, model_path, nr_visualize=100, output_dir='output'):
pbar
.
update
()
def
offline
_evaluate
(
pred_config
,
output_file
):
def
do
_evaluate
(
pred_config
,
output_file
):
num_gpu
=
cfg
.
TRAIN
.
NUM_GPUS
graph_funcs
=
MultiTowerOfflinePredictor
(
pred_config
,
list
(
range
(
num_gpu
)))
.
get_predictors
()
predictors
=
[]
for
k
in
range
(
num_gpu
):
predictors
.
append
(
lambda
img
,
pred
=
graph_funcs
[
k
]:
detect_one_image
(
img
,
pred
))
for
dataset
in
cfg
.
DATA
.
VAL
:
logger
.
info
(
"Evaluating {} ..."
.
format
(
dataset
))
dataflows
=
[
get_eval_dataflow
(
dataset
,
shard
=
k
,
num_shards
=
num_gpu
)
for
k
in
range
(
num_gpu
)]
if
num_gpu
>
1
:
all_results
=
multithread_
eval_coco
(
dataflows
,
predictor
s
)
all_results
=
multithread_
predict_dataflow
(
dataflows
,
graph_func
s
)
else
:
all_results
=
eval_coco
(
dataflows
[
0
],
predictor
s
[
0
])
all_results
=
predict_dataflow
(
dataflows
[
0
],
graph_func
s
[
0
])
output
=
output_file
+
'-'
+
dataset
DetectionDataset
()
.
eval_or_save_inference_results
(
all_results
,
dataset
,
output
)
def
predict
(
pred_func
,
input_file
):
def
do_
predict
(
pred_func
,
input_file
):
img
=
cv2
.
imread
(
input_file
,
cv2
.
IMREAD_COLOR
)
results
=
detect_one
_image
(
img
,
pred_func
)
results
=
predict
_image
(
img
,
pred_func
)
final
=
draw_final_outputs
(
img
,
results
)
viz
=
np
.
concatenate
((
img
,
final
),
axis
=
1
)
cv2
.
imwrite
(
"output.png"
,
viz
)
...
...
@@ -427,7 +422,7 @@ class EvalCallback(Callback):
# Use two predictor threads per GPU to get better throughput
self
.
num_predictor
=
num_gpu
if
buggy_tf
else
num_gpu
*
2
self
.
predictors
=
[
self
.
_build_
coco_
predictor
(
k
%
num_gpu
)
for
k
in
range
(
self
.
num_predictor
)]
self
.
predictors
=
[
self
.
_build_predictor
(
k
%
num_gpu
)
for
k
in
range
(
self
.
num_predictor
)]
self
.
dataflows
=
[
get_eval_dataflow
(
self
.
_eval_dataset
,
shard
=
k
,
num_shards
=
self
.
num_predictor
)
for
k
in
range
(
self
.
num_predictor
)]
...
...
@@ -436,15 +431,14 @@ class EvalCallback(Callback):
# Alternatively, can eval on all ranks and use allgather, but allgather sometimes hangs
self
.
_horovod_run_eval
=
hvd
.
rank
()
==
hvd
.
local_rank
()
if
self
.
_horovod_run_eval
:
self
.
predictor
=
self
.
_build_
coco_
predictor
(
0
)
self
.
predictor
=
self
.
_build_predictor
(
0
)
self
.
dataflow
=
get_eval_dataflow
(
self
.
_eval_dataset
,
shard
=
hvd
.
local_rank
(),
num_shards
=
hvd
.
local_size
())
self
.
barrier
=
hvd
.
allreduce
(
tf
.
random_normal
(
shape
=
[
1
]))
def
_build_coco_predictor
(
self
,
idx
):
graph_func
=
self
.
trainer
.
get_predictor
(
self
.
_in_names
,
self
.
_out_names
,
device
=
idx
)
return
lambda
img
:
detect_one_image
(
img
,
graph_func
)
def
_build_predictor
(
self
,
idx
):
return
self
.
trainer
.
get_predictor
(
self
.
_in_names
,
self
.
_out_names
,
device
=
idx
)
def
_before_train
(
self
):
eval_period
=
cfg
.
TRAIN
.
EVAL_PERIOD
...
...
@@ -459,14 +453,14 @@ class EvalCallback(Callback):
def
_eval
(
self
):
logdir
=
args
.
logdir
if
cfg
.
TRAINER
==
'replicated'
:
all_results
=
multithread_
eval_coco
(
self
.
dataflows
,
self
.
predictors
)
all_results
=
multithread_
predict_dataflow
(
self
.
dataflows
,
self
.
predictors
)
else
:
filenames
=
[
os
.
path
.
join
(
logdir
,
'outputs{}-part{}.json'
.
format
(
self
.
global_step
,
rank
)
)
for
rank
in
range
(
hvd
.
local_size
())]
if
self
.
_horovod_run_eval
:
local_results
=
eval_coco
(
self
.
dataflow
,
self
.
predictor
)
local_results
=
predict_dataflow
(
self
.
dataflow
,
self
.
predictor
)
fname
=
filenames
[
hvd
.
local_rank
()]
with
open
(
fname
,
'w'
)
as
f
:
json
.
dump
(
local_results
,
f
)
...
...
@@ -499,7 +493,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--load'
,
help
=
'load a model for evaluation or training. Can overwrite BACKBONE.WEIGHTS'
)
parser
.
add_argument
(
'--logdir'
,
help
=
'log directory'
,
default
=
'train_log/maskrcnn'
)
parser
.
add_argument
(
'--visualize'
,
action
=
'store_true'
,
help
=
'visualize intermediate results'
)
parser
.
add_argument
(
'--evaluate'
,
help
=
"Run evaluation
on COCO
. "
parser
.
add_argument
(
'--evaluate'
,
help
=
"Run evaluation. "
"This argument is the path to the output json evaluation file"
)
parser
.
add_argument
(
'--predict'
,
help
=
"Run prediction on a given image. "
"This argument is the path to the input image file"
)
...
...
@@ -526,7 +520,7 @@ if __name__ == '__main__':
cfg
.
TEST
.
RESULT_SCORE_THRESH
=
cfg
.
TEST
.
RESULT_SCORE_THRESH_VIS
if
args
.
visualize
:
visualize
(
MODEL
,
args
.
load
)
do_
visualize
(
MODEL
,
args
.
load
)
else
:
predcfg
=
PredictConfig
(
model
=
MODEL
,
...
...
@@ -534,10 +528,10 @@ if __name__ == '__main__':
input_names
=
MODEL
.
get_inference_tensor_names
()[
0
],
output_names
=
MODEL
.
get_inference_tensor_names
()[
1
])
if
args
.
predict
:
predict
(
OfflinePredictor
(
predcfg
),
args
.
predict
)
do_
predict
(
OfflinePredictor
(
predcfg
),
args
.
predict
)
elif
args
.
evaluate
:
assert
args
.
evaluate
.
endswith
(
'.json'
),
args
.
evaluate
offline
_evaluate
(
predcfg
,
args
.
evaluate
)
do
_evaluate
(
predcfg
,
args
.
evaluate
)
else
:
is_horovod
=
cfg
.
TRAINER
==
'horovod'
if
is_horovod
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment