重新思考表面法线估计的归纳偏置
论文的官方实现
重新思考表面法线估计的归纳偏置
CVPR 2024 [口头报告]
Gwangbin Bae 和 Andrew J. Davison
[论文PDF] [arXiv] [YouTube] [项目页面]
摘要
尽管对精确表面法线估计模型的需求日益增长,但现有方法仍使用通用的密集预测模型,采用与其他任务相同的归纳偏置。在本文中,我们讨论了表面法线估计所需的归纳偏置,并提出**(1)利用每个像素的射线方向和(2)通过学习相邻表面法线之间的相对旋转来编码它们之间的关系**。所提出的方法可以为具有任意分辨率和纵横比的具有挑战性的真实图像生成清晰但分段平滑的预测结果。与最近基于ViT的最先进模型相比,我们的方法表现出更强的泛化能力,尽管在规模小几个数量级的数据集上进行训练。
入门指南
我们提供四个步骤的说明(点击"▸"展开)。例如,如果你只想在一些图像上测试DSINE,可以在步骤1后停止。这将最大限度地减少安装/下载量。
步骤1. 在一些图像上测试DSINE(需要最少的依赖项)
首先安装依赖项。
conda create --name DSINE python=3.10
conda activate DSINE
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
python -m pip install geffnet
然后,从此链接下载模型权重,并将其保存在projects/dsine/checkpoints/
下。请注意,它应保持与Google Drive相同的文件夹结构。例如,checkpoints/exp001_cvpr2024/dsine.pt
(在Google Drive中)是我们的最佳模型。它应该被保存为projects/dsine/checkpoints/exp001_cvpr2024/dsine.pt
。相应的配置文件是projects/dsine/experiments/exp001_cvpr2024/dsine.txt
。
checkpoints/exp002_kappa/
(在Google Drive中)下的模型也可以估计不确定性。
然后,移动到projects/dsine/
文件夹,运行
python test_minimal.py ./experiments/exp001_cvpr2024/dsine.txt
这将为projects/dsine/samples/img/
下的图像生成预测。结果将保存在projects/dsine/samples/output/
下。
我们的模型假设已知相机内参,但提供近似内参仍然可以得到良好的结果。对于projects/dsine/samples/img/
中的一些图像,相应的相机内参(fx、fy、cx、cy - 假设透视相机无畸变)以.txt
文件提供。如果不存在这样的文件,内参将被近似,假设60°视场。
步骤2. 在基准数据集上测试DSINE并运行实时演示
安装额外的依赖项。
python -m pip install tensorboard
python -m pip install opencv-python
python -m pip install matplotlib
python -m pip install pyrealsense2 # 仅用于使用realsense相机的演示
python -m pip install vidgear # 仅用于YouTube视频的演示
python -m pip install yt_dlp # 仅用于YouTube视频的演示
python -m pip install mss # 仅用于屏幕捕捉的演示
从此链接下载评估数据集(dsine_eval.zip
)。
**注意:**下载数据集即表示您同意每个数据集各自的许可协议。每个数据集的链接可以在相应的readme.txt
中找到。
如果你查看projects/__init__.py
,有一个名为DATASET_DIR
和EXPERIMENT_DIR
的变量:
DATASET_DIR
是存储数据集的位置。例如,dsine_eval
数据集(从上面的链接下载)应该保存在DATASET_DIR/dsine_eval
下。更新这个变量。EXPERIMENT_DIR
是保存实验(例如模型权重、日志等)的位置。更新这个变量。
然后,移动到projects/dsine/
文件夹,运行:
# 在六个评估数据集上获取基准性能
python test.py ./experiments/exp001_cvpr2024/dsine.txt --mode benchmark
# 在六个评估数据集上获取基准性能(带可视化)
# 结果将保存在EXPERIMENT_DIR/dsine/exp001_cvpr2024/dsine/test/下
python test.py ./experiments/exp001_cvpr2024/dsine.txt --mode benchmark --visualize
# 为`projects/dsine/samples/img/`中的图像生成预测
python test.py ./experiments/exp001_cvpr2024/dsine.txt --mode samples
# 测量您设备上的吞吐量(推理速度)
python test.py ./experiments/exp001_cvpr2024/dsine.txt --mode throughput
你也可以运行实时演示:
# 捕捉你的屏幕并进行预测
python test.py ./experiments/exp001_cvpr2024/dsine.txt --mode screen
# 使用网络摄像头的演示
python test.py ./experiments/exp001_cvpr2024/dsine.txt --mode webcam
# 使用realsense相机的演示
python test.py ./experiments/exp001_cvpr2024/dsine.txt --mode rs
# 在YouTube视频上的演示(替换为不同的链接)
python test.py ./experiments/exp001_cvpr2024/dsine.txt --mode https://www.youtube.com/watch?v=X-iEq8hWd6k
对于每个输入选项,都有一些额外的参数。有关更多信息,请参见projects/dsine/test.py
。
你也可以尝试构建自己的实时演示。更多信息请参见这个笔记本。
步骤3. 训练DSINE
在projects/dsine/
中运行:
python train.py ./experiments/exp000_test/test.txt
然后执行tensorboard --logdir EXPERIMENT_DIR/dsine/exp000_test/test/log
来打开tensorboard。
这将在NYUv2数据集的训练集上训练模型,该数据集应位于DATASET_DIR/dsine_eval/nyuv2/train/
下。这里只有795张图像,性能不会很好。要获得更好的结果,你需要:
(1) 创建自定义数据加载器
我们正在检查是否可以发布整个训练数据集(约400GB)。在发布之前,你可以尝试构建自定义数据加载器。你需要定义一个
get_sample(args, sample_path, info)
函数,并在data/datasets
中提供数据分割。查看其他数据集是如何定义/提供的。你还需要更新projects/baseline_normal/dataloader.py
,以便可以使用新定义的get_sample
函数。
(2) 生成GT表面法线(可选)
如果你的数据集没有提供地面真实表面法线图,你可以尝试从地面真实深度图生成它们。更多信息请参见这个笔记本。
(3) 自定义数据增强
如果你使用合成图像,你需要正确的数据增强函数集来最小化合成到真实的域间差距。我们提供了广泛的增强函数,但超参数未经过微调,你可以通过微调它们来获得潜在的更好结果。更多信息请参见这个笔记本。
步骤4. 开始你自己的表面法线估计项目
如果你想开始自己的表面法线估计项目,可以非常容易地做到。
首先,看看projects/baseline_normal
。这是一个你可以尝试不同CNN架构而不用担心相机内参和旋转估计的地方。你可以尝试流行的架构如U-Net,并尝试不同的骨干网络。在这个文件夹中,你可以运行:
python train.py ./experiments/exp000_test/test.txt
特定项目的config
在projects/baseline_normal/config.py
中定义。所有项目共享的默认配置在projects/__init__.py
中。
数据加载器在projects/baseline_normal/dataloader.py
中。我们在dsine
项目中使用相同的数据加载器,所以我们没有projects/dsine/dataloader.py
。
损失函数定义在 projects/baseline_normal/losses.py
中。这些是用于在您自己的项目中构建自定义损失函数的基础模块。例如,在 DSINE 项目中,我们生成了一系列预测结果,损失函数是对每个预测计算的损失的加权和。您可以在 projects/dsine/losses.py
中看到具体实现方式。
您可以通过复制 projects/dsine
文件夹来创建 projects/NEW_PROJECT_NAME
,从而开始一个新项目。然后,更新 config.py
和 losses.py
。
最后,您应该修改 train.py
和 test.py
。对于在不同项目中应该有所不同的部分,我们做了如下标记:
#↓↓↓↓
#注意:前向传播
img = data_dict['img'].to(device)
intrins = data_dict['intrins'].to(device)
...
pred_list = model(img, intrins=intrins, mode='test')
norm_out = pred_list[-1]
#↑↑↑↑
搜索箭头(↓↓↓↓/↑↑↑↑)以查看在不同项目中需要修改的地方。
上述测试命令(例如获取基准性能和运行实时演示)应适用于所有项目。
附加说明
如果您想为此仓库做出贡献,请提交拉取请求并按以下格式添加说明。
使用 torch hub 预测法线(由 hugoycj 贡献)
注意:以下代码已过时,应进行修改(因为文件夹结构已更改)。
import torch
import cv2
import numpy as np
# 从 torch hub 加载法线预测模型
normal_predictor = torch.hub.load("hugoycj/DSINE-hub", "DSINE", trust_repo=True)
# 使用 OpenCV 加载输入图像
image = cv2.imread(args.input, cv2.IMREAD_COLOR)
h, w = image.shape[:2]
# 使用模型从输入图像推断法线图
with torch.inference_mode():
normal = normal_predictor.infer_cv2(image)[0] # 输出形状:(H, W, 3)
normal = (normal + 1) / 2 # 将值转换到 [0, 1] 范围内
# 将法线图转换为可显示格式
normal = (normal * 255).cpu().numpy().astype(np.uint8).transpose(1, 2, 0)
normal = cv2.cvtColor(normal, cv2.COLOR_RGB2BGR)
# 将输出法线图保存到文件
cv2.imwrite(args.output, normal)
如果网络无法获取权重,您可以使用本地权重进行 torch hub 加载,如下所示:
normal_predictor = torch.hub.load("hugoycj/DSINE-hub", "DSINE", local_file_path='./checkpoints/dsine.pt', trust_repo=True)
生成地面真实表面法线
我们提供了用于从地面真实深度图生成地面真实表面法线的代码。有关更多信息,请参阅此笔记本。关于坐标系
我们使用右手坐标系,其中 (X, Y, Z) = (右, 下, 前)。需要注意的一个重要点是,地面真实法线和我们的预测结果都是外向法线。例如,对于面向相机的正面平行墙,法线应为 (0, 0, 1),而不是 (0, 0, -1)。如果您需要使用内向法线,请执行normals = -normals
。
分享您的模型权重
如果您希望分享您的模型权重,请通过提供相应的配置文件和权重链接来提交拉取请求。引用
如果您在研究中发现我们的工作有用,请考虑引用我们的论文:
@inproceedings{bae2024dsine,
title = {Rethinking Inductive Biases for Surface Normal Estimation},
author = {Gwangbin Bae and Andrew J. Davison},
booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2024}
}
如果您使用的模型还估计不确定性,请同时引用以下论文,其中我们介绍了损失函数:
@InProceedings{bae2021eesnu,
title = {Estimating and Exploiting the Aleatoric Uncertainty in Surface Normal Estimation}
author = {Gwangbin Bae and Ignas Budvytis and Roberto Cipolla},
booktitle = {International Conference on Computer Vision (ICCV)},
year = {2021}
}