GMFlow
论文的官方 PyTorch 实现:
GMFlow:通过全局匹配学习光流, CVPR 2022, 口头报告
作者: 徐浩飞, 张静, 蔡剑飞, Hamid Rezatofighi, 陶大程
2022年11月15日更新:查看我们的新工作:统一流、立体视觉和深度估计以及代码:unimatch,用于将GMFlow扩展到立体视觉和深度任务。更多预训练的GMFlow模型提供不同的速度-精度权衡也已发布。查看我们的Colab和HuggingFace演示,在浏览器中体验GMFlow!
GMFlow的视频介绍(中文)现已在哔哩哔哩上线!
我们通过将光流重新表述为全局匹配问题来简化光流估计流程。
亮点
-
灵活和模块化设计
我们将端到端光流框架分解为五个组件:
特征提取、特征增强、特征匹配、流传播和流细化。
通过组合不同的组件,可以轻松构建定制的光流模型。
-
高精度
仅经过一次细化,GMFlow在具有挑战性的Sintel基准测试中就优于经过31次细化的RAFT。
-
高效率
基本的GMFlow模型(无细化)在Sintel数据(436x1024)上运行时间为57毫秒(V100)或26毫秒(A100)。
由于GMFlow不需要大量的顺序计算,因此在高端GPU(如A100)上获得比RAFT更多的加速。
GMFlow还简化了反向流计算,无需两次前向网络。双向流可用于通过前向-后向一致性检查进行遮挡检测。
安装
我们的代码基于pytorch 1.9.0、CUDA 10.2和python 3.8。更高版本的pytorch也应该可以正常工作。
我们建议使用conda进行安装:
conda env create -f environment.yml
conda activate gmflow
演示
所有预训练模型可以从Google Drive下载。
你可以在一系列图像上运行训练好的模型并可视化结果:
CUDA_VISIBLE_DEVICES=0 python main.py \
--inference_dir demo/sintel_market_1 \
--output_path output/gmflow-norefine-sintel_market_1 \
--resume pretrained/gmflow_sintel-0c07dcb3.pth
你还可以通过启用--pred_bidir_flow
来预测双向流,并使用--fwd_bwd_consistency_check
进行前向-后向一致性检查。更多示例可以在scripts/demo.sh中找到。
数据集
用于训练和评估GMFlow的数据集如下:
默认情况下,数据加载器datasets.py假设数据集位于datasets
文件夹中,并按以下方式组织:
datasets
├── FlyingChairs_release
│ └── data
├── FlyingThings3D
│ ├── frames_cleanpass
│ ├── frames_finalpass
│ └── optical_flow
├── HD1K
│ ├── hd1k_challenge
│ ├── hd1k_flow_gt
│ ├── hd1k_flow_uncertainty
│ └── hd1k_input
├── KITTI
│ ├── testing
│ └── training
├── Sintel
│ ├── test
│ └── training
建议将数据集根目录符号链接到datasets
:
ln -s $YOUR_DATASET_ROOT datasets
否则,你可能需要更改datasets.py中的相应路径。
评估
你可以通过运行以下命令来评估训练好的GMFlow模型:
CUDA_VISIBLE_DEVICES=0 python main.py --eval --val_dataset things sintel --resume pretrained/gmflow_things-e9887eda.pth
更多评估脚本可以在scripts/evaluate.sh中找到。
对于提交到Sintel和KITTI在线测试集,你可以运行scripts/submission.sh。
训练
所有在FlyingChairs、FlyingThings3D、Sintel和KITTI数据集上的训练脚本可以在scripts/train_gmflow.sh和scripts/train_gmflow_with_refine.sh中找到。
请注意,基本的GMFlow模型(无细化)可以在4个16GB V100 GPU上训练。对于带细化的GMFlow训练,默认情况下需要8个16GB V100或4个32GB V100或4个40GB A100 GPU。你可能需要根据你的硬件调整批量大小和训练迭代次数。
我们支持使用tensorboard来监控和可视化训练过程。你可以首先启动一个tensorboard会话:
tensorboard --logdir checkpoints
然后在浏览器中访问http://localhost:6006。
引用
如果你发现我们的工作对你的研究有用,请考虑引用我们的论文:
@inproceedings{xu2022gmflow,
title={GMFlow: Learning Optical Flow via Global Matching},
author={Xu, Haofei and Zhang, Jing and Cai, Jianfei and Rezatofighi, Hamid and Tao, Dacheng},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={8121-8130},
year={2022}
}
致谢
如果没有依赖一些优秀的代码库,这个项目是不可能完成的:RAFT、LoFTR、DETR、Swin、mmdetection和Detectron2。我们感谢原作者的出色工作。