Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
d3f11e3f
Commit
d3f11e3f
authored
Aug 07, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[MaskRCNN] multi-GPU validation
parent
9dba9893
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
67 additions
and
20 deletions
+67
-20
examples/FasterRCNN/config.py
examples/FasterRCNN/config.py
+12
-1
examples/FasterRCNN/data.py
examples/FasterRCNN/data.py
+12
-5
examples/FasterRCNN/eval.py
examples/FasterRCNN/eval.py
+10
-3
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+33
-11
No files found.
examples/FasterRCNN/config.py
View file @
d3f11e3f
...
@@ -11,7 +11,13 @@ __all__ = ['config', 'finalize_configs']
...
@@ -11,7 +11,13 @@ __all__ = ['config', 'finalize_configs']
class
AttrDict
():
class
AttrDict
():
_freezed
=
False
""" Avoid accidental creation of new hierarchies. """
def
__getattr__
(
self
,
name
):
def
__getattr__
(
self
,
name
):
if
self
.
_freezed
:
raise
AttributeError
(
name
)
ret
=
AttrDict
()
ret
=
AttrDict
()
setattr
(
self
,
name
,
ret
)
setattr
(
self
,
name
,
ret
)
return
ret
return
ret
...
@@ -24,7 +30,7 @@ class AttrDict():
...
@@ -24,7 +30,7 @@ class AttrDict():
def
to_dict
(
self
):
def
to_dict
(
self
):
"""Convert to a nested dict. """
"""Convert to a nested dict. """
return
{
k
:
v
.
to_dict
()
if
isinstance
(
v
,
AttrDict
)
else
v
return
{
k
:
v
.
to_dict
()
if
isinstance
(
v
,
AttrDict
)
else
v
for
k
,
v
in
self
.
__dict__
.
items
()}
for
k
,
v
in
self
.
__dict__
.
items
()
if
not
k
.
startswith
(
'_'
)
}
def
update_args
(
self
,
args
):
def
update_args
(
self
,
args
):
"""Update from command line args. """
"""Update from command line args. """
...
@@ -43,6 +49,9 @@ class AttrDict():
...
@@ -43,6 +49,9 @@ class AttrDict():
v
=
eval
(
v
)
v
=
eval
(
v
)
setattr
(
dic
,
key
,
v
)
setattr
(
dic
,
key
,
v
)
def
freeze
(
self
):
self
.
_freezed
=
True
# avoid silent bugs
# avoid silent bugs
def
__eq__
(
self
,
_
):
def
__eq__
(
self
,
_
):
raise
NotImplementedError
()
raise
NotImplementedError
()
...
@@ -94,6 +103,7 @@ _C.TRAIN.STEPS_PER_EPOCH = 500
...
@@ -94,6 +103,7 @@ _C.TRAIN.STEPS_PER_EPOCH = 500
# Otherwise the actual steps to decrease learning rate are computed from the schedule.
# Otherwise the actual steps to decrease learning rate are computed from the schedule.
# LR_SCHEDULE = [120000, 160000, 180000] # "1x" schedule in detectron
# LR_SCHEDULE = [120000, 160000, 180000] # "1x" schedule in detectron
_C
.
TRAIN
.
LR_SCHEDULE
=
[
240000
,
320000
,
360000
]
# "2x" schedule in detectron
_C
.
TRAIN
.
LR_SCHEDULE
=
[
240000
,
320000
,
360000
]
# "2x" schedule in detectron
_C
.
TRAIN
.
NUM_EVALS
=
20
# number of evaluations to run during training
# preprocessing --------------------
# preprocessing --------------------
# Alternative old (worse & faster) setting: 600, 1024
# Alternative old (worse & faster) setting: 600, 1024
...
@@ -208,4 +218,5 @@ def finalize_configs(is_training):
...
@@ -208,4 +218,5 @@ def finalize_configs(is_training):
# autotune is too slow for inference
# autotune is too slow for inference
os
.
environ
[
'TF_CUDNN_USE_AUTOTUNE'
]
=
'0'
os
.
environ
[
'TF_CUDNN_USE_AUTOTUNE'
]
=
'0'
_C
.
freeze
()
logger
.
info
(
"Config: ------------------------------------------\n"
+
str
(
_C
))
logger
.
info
(
"Config: ------------------------------------------\n"
+
str
(
_C
))
examples/FasterRCNN/data.py
View file @
d3f11e3f
...
@@ -9,7 +9,7 @@ import itertools
...
@@ -9,7 +9,7 @@ import itertools
from
tensorpack.utils.argtools
import
memoized
,
log_once
from
tensorpack.utils.argtools
import
memoized
,
log_once
from
tensorpack.dataflow
import
(
from
tensorpack.dataflow
import
(
imgaug
,
TestDataSpeed
,
imgaug
,
TestDataSpeed
,
PrefetchDataZMQ
,
MultiProcessMapDataZMQ
,
MultiThreadMapData
,
MultiProcessMapDataZMQ
,
MultiThreadMapData
,
MapDataComponent
,
DataFromList
)
MapDataComponent
,
DataFromList
)
from
tensorpack.utils
import
logger
from
tensorpack.utils
import
logger
# import tensorpack.utils.viz as tpviz
# import tensorpack.utils.viz as tpviz
...
@@ -381,18 +381,25 @@ def get_train_dataflow():
...
@@ -381,18 +381,25 @@ def get_train_dataflow():
return
ds
return
ds
def
get_eval_dataflow
():
def
get_eval_dataflow
(
shard
=
0
,
num_shards
=
1
):
"""
Args:
shard, num_shards: to get subset of evaluation data
"""
imgs
=
COCODetection
.
load_many
(
cfg
.
DATA
.
BASEDIR
,
cfg
.
DATA
.
VAL
,
add_gt
=
False
)
imgs
=
COCODetection
.
load_many
(
cfg
.
DATA
.
BASEDIR
,
cfg
.
DATA
.
VAL
,
add_gt
=
False
)
num_imgs
=
len
(
imgs
)
img_per_shard
=
num_imgs
//
num_shards
img_range
=
(
shard
*
img_per_shard
,
(
shard
+
1
)
*
img_per_shard
if
shard
+
1
<
num_shards
else
num_imgs
)
# no filter for training
# no filter for training
ds
=
DataFromListOfDict
(
imgs
,
[
'file_name'
,
'id'
])
ds
=
DataFromListOfDict
(
imgs
[
img_range
[
0
]:
img_range
[
1
]]
,
[
'file_name'
,
'id'
])
def
f
(
fname
):
def
f
(
fname
):
im
=
cv2
.
imread
(
fname
,
cv2
.
IMREAD_COLOR
)
im
=
cv2
.
imread
(
fname
,
cv2
.
IMREAD_COLOR
)
assert
im
is
not
None
,
fname
assert
im
is
not
None
,
fname
return
im
return
im
ds
=
MapDataComponent
(
ds
,
f
,
0
)
ds
=
MapDataComponent
(
ds
,
f
,
0
)
if
cfg
.
TRAINER
!=
'horovod'
:
# Evaluation itself may be multi-threaded, therefore don't add prefetch here.
ds
=
PrefetchDataZMQ
(
ds
,
1
)
return
ds
return
ds
...
...
examples/FasterRCNN/eval.py
View file @
d3f11e3f
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
import
tqdm
import
tqdm
import
os
import
os
from
collections
import
namedtuple
from
collections
import
namedtuple
from
contextlib
import
ExitStack
import
numpy
as
np
import
numpy
as
np
import
cv2
import
cv2
...
@@ -90,18 +91,24 @@ def detect_one_image(img, model_func):
...
@@ -90,18 +91,24 @@ def detect_one_image(img, model_func):
return
results
return
results
def
eval_coco
(
df
,
detect_func
):
def
eval_coco
(
df
,
detect_func
,
tqdm_bar
=
None
):
"""
"""
Args:
Args:
df: a DataFlow which produces (image, image_id)
df: a DataFlow which produces (image, image_id)
detect_func: a callable, takes [image] and returns [DetectionResult]
detect_func: a callable, takes [image] and returns [DetectionResult]
tqdm_bar: a tqdm object to be shared among multiple evaluation instances. If None,
will create a new one.
Returns:
Returns:
list of dict, to be dumped to COCO json format
list of dict, to be dumped to COCO json format
"""
"""
df
.
reset_state
()
df
.
reset_state
()
all_results
=
[]
all_results
=
[]
with
tqdm
.
tqdm
(
total
=
df
.
size
(),
**
get_tqdm_kwargs
())
as
pbar
:
# tqdm is not quite thread-safe: https://github.com/tqdm/tqdm/issues/323
with
ExitStack
()
as
stack
:
if
tqdm_bar
is
None
:
tqdm_bar
=
stack
.
enter_context
(
tqdm
.
tqdm
(
total
=
df
.
size
(),
**
get_tqdm_kwargs
()))
for
img
,
img_id
in
df
.
get_data
():
for
img
,
img_id
in
df
.
get_data
():
results
=
detect_func
(
img
)
results
=
detect_func
(
img
)
for
r
in
results
:
for
r
in
results
:
...
@@ -124,7 +131,7 @@ def eval_coco(df, detect_func):
...
@@ -124,7 +131,7 @@ def eval_coco(df, detect_func):
rle
[
'counts'
]
=
rle
[
'counts'
]
.
decode
(
'ascii'
)
rle
[
'counts'
]
=
rle
[
'counts'
]
.
decode
(
'ascii'
)
res
[
'segmentation'
]
=
rle
res
[
'segmentation'
]
=
rle
all_results
.
append
(
res
)
all_results
.
append
(
res
)
pbar
.
update
(
1
)
tqdm_bar
.
update
(
1
)
return
all_results
return
all_results
...
...
examples/FasterRCNN/train.py
View file @
d3f11e3f
...
@@ -12,6 +12,7 @@ import numpy as np
...
@@ -12,6 +12,7 @@ import numpy as np
import
json
import
json
import
six
import
six
import
tensorflow
as
tf
import
tensorflow
as
tf
from
concurrent.futures
import
ThreadPoolExecutor
try
:
try
:
import
horovod.tensorflow
as
hvd
import
horovod.tensorflow
as
hvd
except
ImportError
:
except
ImportError
:
...
@@ -466,33 +467,54 @@ def predict(pred_func, input_file):
...
@@ -466,33 +467,54 @@ def predict(pred_func, input_file):
class
EvalCallback
(
Callback
):
class
EvalCallback
(
Callback
):
"""
A callback that runs COCO evaluation once in a while.
It supports multi-GPU evaluation if TRAINER=='replicated' and single-GPU evaluation if TRAINER=='horovod'
"""
def
__init__
(
self
,
in_names
,
out_names
):
def
__init__
(
self
,
in_names
,
out_names
):
self
.
_in_names
,
self
.
_out_names
=
in_names
,
out_names
self
.
_in_names
,
self
.
_out_names
=
in_names
,
out_names
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
self
.
pred
=
self
.
trainer
.
get_predictor
(
self
.
_in_names
,
self
.
_out_names
)
num_gpu
=
cfg
.
TRAIN
.
NUM_GPUS
self
.
df
=
get_eval_dataflow
()
# Use two predictor threads per GPU to get better throughput
self
.
num_predictor
=
1
if
cfg
.
TRAINER
==
'horovod'
else
num_gpu
*
2
self
.
predictors
=
[
self
.
_build_coco_predictor
(
k
%
num_gpu
)
for
k
in
range
(
self
.
num_predictor
)]
self
.
dataflows
=
[
get_eval_dataflow
(
shard
=
k
,
num_shards
=
self
.
num_predictor
)
for
k
in
range
(
self
.
num_predictor
)]
def
_build_coco_predictor
(
self
,
idx
):
graph_func
=
self
.
trainer
.
get_predictor
(
self
.
_in_names
,
self
.
_out_names
,
device
=
idx
)
return
lambda
img
:
detect_one_image
(
img
,
graph_func
)
def
_before_train
(
self
):
def
_before_train
(
self
):
EVAL_TIMES
=
5
# eval 5 times during training
num_eval
=
cfg
.
TRAIN
.
NUM_EVALS
interval
=
self
.
trainer
.
max_epoch
//
(
EVAL_TIMES
+
1
)
interval
=
max
(
self
.
trainer
.
max_epoch
//
(
num_eval
+
1
),
1
)
self
.
epochs_to_eval
=
set
([
interval
*
k
for
k
in
range
(
1
,
EVAL_TIMES
+
1
)])
self
.
epochs_to_eval
=
set
([
interval
*
k
for
k
in
range
(
1
,
num_eval
+
1
)])
self
.
epochs_to_eval
.
add
(
self
.
trainer
.
max_epoch
)
self
.
epochs_to_eval
.
add
(
self
.
trainer
.
max_epoch
)
logger
.
info
(
"[EvalCallback] Will evaluate at epoch "
+
str
(
sorted
(
self
.
epochs_to_eval
)))
if
len
(
self
.
epochs_to_eval
)
<
15
:
logger
.
info
(
"[EvalCallback] Will evaluate at epoch "
+
str
(
sorted
(
self
.
epochs_to_eval
)))
else
:
logger
.
info
(
"[EvalCallback] Will evaluate every {} epochs"
.
format
(
interval
))
def
_eval
(
self
):
def
_eval
(
self
):
all_results
=
eval_coco
(
self
.
df
,
lambda
img
:
detect_one_image
(
img
,
self
.
pred
))
with
ThreadPoolExecutor
(
max_workers
=
self
.
num_predictor
,
thread_name_prefix
=
'EvalWorker'
)
as
executor
,
\
tqdm
.
tqdm
(
total
=
sum
([
df
.
size
()
for
df
in
self
.
dataflows
]))
as
pbar
:
futures
=
[]
for
dataflow
,
pred
in
zip
(
self
.
dataflows
,
self
.
predictors
):
futures
.
append
(
executor
.
submit
(
eval_coco
,
dataflow
,
pred
,
pbar
))
all_results
=
list
(
itertools
.
chain
(
*
[
fut
.
result
()
for
fut
in
futures
]))
output_file
=
os
.
path
.
join
(
output_file
=
os
.
path
.
join
(
logger
.
get_logger_dir
(),
'outputs{}.json'
.
format
(
self
.
global_step
))
logger
.
get_logger_dir
(),
'outputs{}.json'
.
format
(
self
.
global_step
))
with
open
(
output_file
,
'w'
)
as
f
:
with
open
(
output_file
,
'w'
)
as
f
:
json
.
dump
(
all_results
,
f
)
json
.
dump
(
all_results
,
f
)
try
:
try
:
scores
=
print_evaluation_scores
(
output_file
)
scores
=
print_evaluation_scores
(
output_file
)
for
k
,
v
in
scores
.
items
():
self
.
trainer
.
monitors
.
put_scalar
(
k
,
v
)
except
Exception
:
except
Exception
:
logger
.
exception
(
"Exception in COCO evaluation."
)
logger
.
exception
(
"Exception in COCO evaluation."
)
scores
=
{}
for
k
,
v
in
scores
.
items
():
self
.
trainer
.
monitors
.
put_scalar
(
k
,
v
)
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
if
self
.
epoch_num
in
self
.
epochs_to_eval
:
if
self
.
epoch_num
in
self
.
epochs_to_eval
:
...
@@ -558,7 +580,7 @@ if __name__ == '__main__':
...
@@ -558,7 +580,7 @@ if __name__ == '__main__':
init_lr
=
cfg
.
TRAIN
.
BASE_LR
*
0.33
*
(
8.
/
cfg
.
TRAIN
.
NUM_GPUS
)
init_lr
=
cfg
.
TRAIN
.
BASE_LR
*
0.33
*
(
8.
/
cfg
.
TRAIN
.
NUM_GPUS
)
warmup_schedule
=
[(
0
,
init_lr
),
(
cfg
.
TRAIN
.
WARMUP
,
cfg
.
TRAIN
.
BASE_LR
)]
warmup_schedule
=
[(
0
,
init_lr
),
(
cfg
.
TRAIN
.
WARMUP
,
cfg
.
TRAIN
.
BASE_LR
)]
warmup_end_epoch
=
cfg
.
TRAIN
.
WARMUP
*
1.
/
stepnum
warmup_end_epoch
=
cfg
.
TRAIN
.
WARMUP
*
1.
/
stepnum
lr_schedule
=
[(
int
(
np
.
ceil
(
warmup_end_epoch
)),
warmup_schedule
[
-
1
][
1
]
)]
lr_schedule
=
[(
int
(
np
.
ceil
(
warmup_end_epoch
)),
cfg
.
TRAIN
.
BASE_LR
)]
factor
=
8.
/
cfg
.
TRAIN
.
NUM_GPUS
factor
=
8.
/
cfg
.
TRAIN
.
NUM_GPUS
for
idx
,
steps
in
enumerate
(
cfg
.
TRAIN
.
LR_SCHEDULE
[:
-
1
]):
for
idx
,
steps
in
enumerate
(
cfg
.
TRAIN
.
LR_SCHEDULE
[:
-
1
]):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment