Commit 3db9af2b authored by Yuxin Wu's avatar Yuxin Wu

flownet inference for a seq of images

parent ca096908
......@@ -37,10 +37,13 @@ wget http://models.tensorpack.com/OpticalFlow/flownet2-c.npz
```bash
python flownet2.py \
--left left.png --right right.png \
--images frame0.png frame1.png frame2.png
--load flownet2.npz --model flownet2
```
This command will show predictions for all the consecutive pairs one by one.
Press any key to visualize the next prediction.
3. Evaluate AEE (Average Endpoing Error) on Sintel dataset:
```
......
......@@ -14,15 +14,13 @@ import flownet_models as models
from helper import Flow
def apply(model, model_path, left, right, ground_truth=None):
left = cv2.imread(left)
right = cv2.imread(right)
def apply(model, model_path, images, ground_truth=None):
left = cv2.imread(images[0])
h, w = left.shape[:2]
newh = (h // 64) * 64
neww = (w // 64) * 64
aug = imgaug.CenterCrop((newh, neww))
left, right = aug.augment(left), aug.augment(right)
left = aug.augment(left)
predict_func = OfflinePredictor(PredictConfig(
model=model(height=newh, width=neww),
......@@ -30,6 +28,9 @@ def apply(model, model_path, left, right, ground_truth=None):
input_names=['left', 'right'],
output_names=['prediction']))
for right in images[1:]:
right = aug.augment(cv2.imread(right))
left_input, right_input = [x.astype('float32').transpose(2, 0, 1)[None, ...]
for x in [left, right]]
output = predict_func(left_input, right_input)[0].transpose(0, 2, 3, 1)
......@@ -45,6 +46,8 @@ def apply(model, model_path, left, right, ground_truth=None):
cv2.imwrite('flow_prediction.png', img)
cv2.waitKey(0)
left = right
class SintelData(DataFlow):
......@@ -118,8 +121,8 @@ if __name__ == '__main__':
parser.add_argument('--load', help='path to the model', required=True)
parser.add_argument('--model', help='model',
choices=['flownet2', 'flownet2-s', 'flownet2-c'], required=True)
parser.add_argument('--left', help='input')
parser.add_argument('--right', help='input')
parser.add_argument('--images', nargs="+",
help='a list of equally-sized images. FlowNet will be applied to all consecutive pairs')
parser.add_argument('--gt', help='path to ground truth flow')
parser.add_argument('--sintel_path', help='path to sintel dataset')
args = parser.parse_args()
......@@ -131,4 +134,5 @@ if __name__ == '__main__':
if args.sintel_path:
inference(model, args.load, args.sintel_path)
else:
apply(model, args.load, args.left, args.right, args.gt)
assert len(args.images) >= 2
apply(model, args.load, args.images, args.gt)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment