Commit 671b64cc authored by Yuxin Wu's avatar Yuxin Wu

[FasterRCNN] put training preprocessing together

parent 9229285e
...@@ -11,7 +11,7 @@ from tensorpack.utils import logger ...@@ -11,7 +11,7 @@ from tensorpack.utils import logger
from tensorpack.utils.argtools import memoized, log_once from tensorpack.utils.argtools import memoized, log_once
from tensorpack.dataflow import ( from tensorpack.dataflow import (
ProxyDataFlow, MapData, imgaug, TestDataSpeed, ProxyDataFlow, MapData, imgaug, TestDataSpeed,
AugmentImageComponents, MapDataComponent) MapDataComponent)
import tensorpack.utils.viz as tpviz import tensorpack.utils.viz as tpviz
from tensorpack.utils.viz import interactive_imshow from tensorpack.utils.viz import interactive_imshow
...@@ -192,32 +192,6 @@ def get_rpn_anchor_input(im, boxes, klass, is_crowd): ...@@ -192,32 +192,6 @@ def get_rpn_anchor_input(im, boxes, klass, is_crowd):
return featuremap_labels, featuremap_boxes return featuremap_labels, featuremap_boxes
def read_and_augment_images(ds):
def mapf(dp):
fname = dp[0]
im = cv2.imread(fname, cv2.IMREAD_COLOR)
assert im is not None, fname
dp[0] = im.astype('float32')
# assume floatbox as input
assert dp[1].dtype == np.float32
dp[1] = box_to_point8(dp[1])
dp.append(fname)
return dp
ds = MapData(ds, mapf)
augs = [CustomResize(config.SHORT_EDGE_SIZE, config.MAX_SIZE),
imgaug.Flip(horiz=True)]
ds = AugmentImageComponents(ds, augs, index=(0,), coords_index=(1,))
def unmapf(points):
boxes = point8_to_box(points)
return boxes
ds = MapDataComponent(ds, unmapf, 1)
return ds
def get_train_dataflow(): def get_train_dataflow():
imgs = COCODetection.load_many(config.BASEDIR, config.TRAIN_DATASET) imgs = COCODetection.load_many(config.BASEDIR, config.TRAIN_DATASET)
# Valid training images should have at least one fg box. # Valid training images should have at least one fg box.
...@@ -228,25 +202,39 @@ def get_train_dataflow(): ...@@ -228,25 +202,39 @@ def get_train_dataflow():
imgs, imgs,
['file_name', 'boxes', 'class', 'is_crowd'], # we need this four keys only ['file_name', 'boxes', 'class', 'is_crowd'], # we need this four keys only
shuffle=True) shuffle=True)
ds = read_and_augment_images(ds)
def add_anchor_to_dp(dp): aug = imgaug.AugmentorList(
im, boxes, klass, is_crowd, fname = dp [CustomResize(config.SHORT_EDGE_SIZE, config.MAX_SIZE),
imgaug.Flip(horiz=True)])
def preprocess(dp):
fname, boxes, klass, is_crowd = dp
im = cv2.imread(fname, cv2.IMREAD_COLOR)
assert im is not None, fname
im = im.astype('float32')
# assume floatbox as input
assert boxes.dtype == np.float32
# augmentation:
im, params = aug.augment_return_params(im)
points = box_to_point8(boxes)
points = aug.augment_coords(points, params)
boxes = point8_to_box(points)
# rpn anchor:
try: try:
fm_labels, fm_boxes = get_rpn_anchor_input(im, boxes, klass, is_crowd) fm_labels, fm_boxes = get_rpn_anchor_input(im, boxes, klass, is_crowd)
boxes = boxes[is_crowd == 0] # skip crowd boxes in training target boxes = boxes[is_crowd == 0] # skip crowd boxes in training target
klass = klass[is_crowd == 0] klass = klass[is_crowd == 0]
if not len(boxes): if not len(boxes):
raise MalformedData("No valid gt_boxes!") raise MalformedData("No valid gt_boxes!")
except MalformedData as e: except MalformedData as e:
log_once("Input {} is invalid for training: {}".format(fname, str(e)), 'warn') log_once("Input {} is invalid for training: {}".format(fname, str(e)), 'warn')
return None return None
return [im, fm_labels, fm_boxes, boxes, klass] return im, fm_labels, fm_boxes, boxes, klass
ds = MapData(ds, add_anchor_to_dp) ds = MapData(ds, preprocess)
return ds return ds
...@@ -265,7 +253,7 @@ def get_eval_dataflow(): ...@@ -265,7 +253,7 @@ def get_eval_dataflow():
if __name__ == '__main__': if __name__ == '__main__':
from tensorpack.dataflow import PrintData from tensorpack.dataflow import PrintData
ds = get_train_dataflow('/datasets01/COCO/060817') ds = get_train_dataflow()
ds = PrintData(ds, 100) ds = PrintData(ds, 100)
TestDataSpeed(ds, 50000).start() TestDataSpeed(ds, 50000).start()
ds.reset_state() ds.reset_state()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment