ESANet:用于室内场景分析的高效RGB-D语义分割
您可能还想看看我们的后续工作EMSANet(多任务方法、更好的语义分割结果,以及更清晰、更易扩展的代码库)
本仓库包含了我们论文"用于室内场景分析的高效RGB-D语义分割"的代码(IEEE Xplore,arXiv)。
我们精心设计的网络架构能够在NVIDIA Jetson AGX Xavier上实现实时语义分割,因此非常适合作为移动机器人复杂系统中实时场景分析的通用初始处理步骤:
我们的方法也可以应用于户外场景,如Cityscapes数据集:
本仓库包含了训练和评估我们网络的代码。此外,我们还提供了将模型转换为ONNX和TensorRT的代码,以及测量推理时间的代码。
许可和引用
源代码以BSD 3-Clause许可发布,详见许可文件。
如果您使用源代码或网络权重,请引用以下论文:
Seichter, D., Köhler, M., Lewandowski, B., Wengefeld T., Gross, H.-M. Efficient RGB-D Semantic Segmentation for Indoor Scene Analysis in IEEE International Conference on Robotics and Automation (ICRA), pp. 13525-13531, 2021.
@inproceedings{esanet2021icra,
title={Efficient RGB-D Semantic Segmentation for Indoor Scene Analysis},
author={Seichter, Daniel and K{\"o}hler, Mona and Lewandowski, Benjamin and Wengefeld, Tim and Gross, Horst-Michael},
booktitle={IEEE International Conference on Robotics and Automation (ICRA)},
year={2021},
volume={},
number={},
pages={13525-13531}
}
@article{esanet2020arXiv,
title={Efficient RGB-D Semantic Segmentation for Indoor Scene Analysis},
author={Seichter, Daniel and K{\"o}hler, Mona and Lewandowski, Benjamin and Wengefeld, Tim and Gross, Horst-Michael},
journal={arXiv preprint arXiv:2011.06961},
year={2020}
}
请注意,预印本已被接受在IEEE机器人与自动化国际会议(ICRA)上发表。
设置
-
克隆仓库:
git clone https://github.com/TUI-NICR/ESANet.git cd /path/to/this/repository
-
设置包含所有依赖项的 Anaconda 环境:
# 从 YAML 文件创建 conda 环境 conda env create -f rgbd_segmentation.yaml # 激活环境 conda activate rgbd_segmentation
-
数据准备(训练/评估/数据集推理):
我们在 NYUv2、 SUNRGB-D 和 Cityscapes 上训练了我们的网络。 编码器在 ImageNet 上进行了预训练。 此外,我们还在合成数据集 SceneNet RGB-D 上预训练了我们的最佳模型。src/datasets
文件夹包含了准备 NYUv2、SunRGB-D、Cityscapes、SceneNet RGB-D 用于训练和评估的代码。 请按照各数据集的说明进行操作,并将创建的数据集存储在./datasets
中。 对于 ImageNet,我们使用了 TensorFlowDatasets(参见imagenet_pretraining.py
)。 -
预训练模型(评估):
我们提供了在 NYUv2、SunRGBD 和 Cityscapes 上选定的 ESANet-R34-NBt1D(使用 ResNet34 NBt1D 主干网络)的权重:数据集 模型 mIoU FPS* 链接 NYUv2(测试) ESANet-R34-NBt1D 50.30 29.7 下载 ESANet-R34-NBt1D(预训练 SceneNet) 51.58 29.7 下载 SUNRGB-D(测试) ESANet-R34-NBt1D 48.17 29.7** 下载 ESANet-R34-NBt1D(预训练 SceneNet) 48.04 29.7** 下载 Cityscapes(验证半分辨率) ESANet-R34-NBt1D 75.22 23.4 下载 Cityscapes(验证全分辨率) ESANet-R34-NBt1D 80.09 6.2 下载 下载并解压模型到
./trained_models
。*我们报告的 FPS 是在 NVIDIA Jetson AGX Xavier(Jetpack 4.4、TensorRT 7.1、Float16)上的结果。
**注意,我们在论文中只报告了 NYUv2 的推理时间,因为它比 SUNRGB-D 的类别更多。 因此,SUNRGB-D 的 FPS 可能略高(37 vs. 40 类)。
内容
以下是不同任务的子部分:
- 评估:重现我们论文中报告的结果。
- 数据集推理:将训练好的模型应用于数据集样本。
- 样本推理:将训练好的模型应用于
./samples
中的样本。 - 时间推理:使用 TensorRT 在 NVIDIA Jetson AGX Xavier 上进行推理时间测试。
- 训练:训练新的 ESANet 模型。
评估
要重现我们论文中报告的 mIoU,请使用 eval.py
。
请注意,正确构建模型取决于模型训练所使用的相应数据集。在运行
eval.py
时不传递额外的模型参数,默认会在NYUv2或SUNRGB-D上评估我们的ESANet-R34-NBt1D。对于Cityscapes,参数有所不同。您会在网络权重文件旁边找到一个argsv_*.txt
文件,列出了所需的参数。
示例:
-
要评估在NYUv2上训练的ESANet-R34-NBt1D,运行:
python eval.py \ --dataset nyuv2 \ --dataset_dir ./datasets/nyuv2 \ --ckpt_path ./trained_models/nyuv2/r34_NBt1D.pth # 相机:kv1 mIoU:50.30 # 所有相机,mIoU:50.30
-
要评估在SUNRGB-D上训练的ESANet-R34-NBt1D,运行:
python eval.py \ --dataset sunrgbd \ --dataset_dir ./datasets/sunrgbd \ --ckpt_path ./trained_models/sunrgbd/r34_NBt1D.pth # 相机:realsense mIoU:32.42 # 相机:kv2 mIoU:46.28 # 相机:kv1 mIoU:53.39 # 相机:xtion mIoU:41.93 # 所有相机,mIoU:48.17
-
要评估在Cityscapes上训练的ESANet-R34-NBt1D,运行:
# 半分辨率(1024x512) python eval.py \ --dataset cityscapes-with-depth \ --dataset_dir ./datasets/cityscapes \ --ckpt_path ./trained_models/cityscapes/r34_NBt1D_half.pth \ --height 512 \ --width 1024 \ --raw_depth \ --context_module appm-1-2-4-8 # 相机:camera1 mIoU:75.22 # 所有相机,mIoU:75.22 # 全分辨率(2048x1024) # 注意,模型是在半分辨率下创建和训练的,只有评估是在全分辨率下进行的 python eval.py \ --dataset cityscapes-with-depth \ --dataset_dir ./datasets/cityscapes \ --ckpt_path ./trained_models/cityscapes/r34_NBt1D_full.pth \ --height 512 \ --width 1024 \ --raw_depth \ --context_module appm-1-2-4-8 \ --valid_full_res # 相机:camera1 mIoU:80.09 # 所有相机,mIoU:80.09
推理
我们提供了对样本输入图像(inference_samples.py
)和从我们使用的数据集中抽取的样本(inference_dataset.py
)进行推理的脚本。
请注意,正确构建模型取决于模型训练所使用的相应数据集。在运行
eval.py
时不传递额外的模型参数,默认会在NYUv2或SUNRGB-D上评估我们的ESANet-R34-NBt1D。对于Cityscapes,参数有所不同。您会在网络权重文件旁边找到一个argsv_*.txt
文件,列出了Cityscapes所需的参数。
数据集推理
使用inference_dataset.py
将训练好的模型应用于从我们使用的数据集中抽取的样本:
示例:要将在SUNRGB-D上训练的ESANet-R34-NBt1D应用于SUNRGB-D的样本,运行:
# 注意,整个第一批次都会被可视化,所以较大的批次大小会导致图中的图像较小
python inference_dataset.py \
--dataset sunrgbd \
--dataset_dir ./datasets/sunrgbd \
--ckpt_path ./trained_models/sunrgbd/r34_NBt1D_scenenet.pth \
--batch_size 4
样本推理
使用inference_samples.py
将训练好的模型应用于./samples
中给出的样本。
注意,需要数据集参数来确定正确的预处理和类别颜色。但是,您不需要准备相应的数据集。此外,根据给定的深度图像和用于训练的数据集,可能需要额外的深度缩放。 示例:
-
要将我们在SUNRGB-D上训练的ESANet-R34-NBt1D应用于样本,请运行:
python inference_samples.py \ --dataset sunrgbd \ --ckpt_path ./trained_models/sunrgbd/r34_NBt1D.pth \ --depth_scale 1 \ --raw_depth
-
要将我们在NYUv2上训练的ESANet-R34-NBt1D应用于样本,请运行:
python inference_samples.py \ --dataset nyuv2 \ --ckpt_path ./trained_models/nyuv2/r34_NBt1D.pth \ --depth_scale 0.1 \ --raw_depth
推理时间
我们在配备Jetpack 4.4的NVIDIA Jetson AGX Xavier上计时推理(TensorRT 7.1.3, PyTorch 1.4.0)。
在配备Jetpack 4.4的NVIDIA Jetson AGX Xavier上复现计时还需要:
- 来自NVIDIA论坛的PyTorch 1.4.0 wheel文件
- NVIDIA TensorRT开源软件(使用
onnx2trt
将onnx模型转换为TensorRT引擎) requirements_jetson.txt
中列出的依赖项:pip3 install -r requirements_jetson.txt --user
随后,您可以运行inference_time.sh
来复现ESANet的报告时间。
可以使用inference_time_whole_model.py
计算单个模型的推理时间。
示例:要复现我们在NYUv2上训练的ESANet-R34-NBt1D的计时,请运行:
python3 inference_time_whole_model.py \
--dataset nyuv2 \
--no_time_pytorch \
--no_time_onnxruntime \
--trt_floatx 16
注意,早于4.4版本的Jetpack可能完全失败或由于上采样处理不同而导致输出结果偏差。
要复现我们论文中比较的其他模型的计时,请按照src/models/external_code中给出的说明进行操作。
训练
使用train.py
在NYUv2、SUNRGB-D、Cityscapes或SceneNet RGB-D上训练ESANet(或按照提供的数据集实现来实现您自己的数据集)。
参数默认为使用我们论文中的超参数在NYUv2上训练ESANet-R34-NBt1D。因此,这些参数可以省略,但为清晰起见在此列出。
请注意,训练ESANet-R34-NBt1D需要编码器主干ResNet-34 NBt1D的预训练权重。您可以从链接下载我们在ImageNet上的预训练权重。否则,您可以使用
imagenet_pretraining.py
创建自己的预训练权重。 示例:
-
在NYUv2上训练我们的ESANet-R34-NBt1D(除数据集参数外,也适用于SUNRGB-D):
# 可以自行指定所有参数 python train.py \ --dataset nyuv2 \ --dataset_dir ./datasets/nyuv2 \ --pretrained_dir ./trained_models/imagenet \ --results_dir ./results \ --height 480 \ --width 640 \ --batch_size 8 \ --batch_size_valid 24 \ --lr 0.01 \ --optimizer SGD \ --class_weighting median_frequency \ --encoder resnet34 \ --encoder_block NonBottleneck1D \ --nr_decoder_blocks 3 \ --modality rgbd \ --encoder_decoder_fusion add \ --context_module ppm \ --decoder_channels_mode decreasing \ --fuse_depth_in_rgb_encoder SE-add \ --upsampling learned-3x3-zeropad # 或使用默认参数 python train.py \ --dataset nyuv2 \ --dataset_dir ./datasets/nyuv2 \ --pretrained_dir ./trained_models/imagenet \ --results_dir ./results
-
在Cityscapes上训练我们的ESANet-R34-NBt1D:
# 注意部分参数有所不同 python train.py \ --dataset cityscapes-with-depth \ --dataset_dir ./datasets/cityscapes \ --pretrained_dir ./trained_models/imagenet \ --results_dir ./results \ --raw_depth \ --he_init \ --aug_scale_min 0.5 \ --aug_scale_max 2.0 \ --valid_full_res \ --height 512 \ --width 1024 \ --batch_size 8 \ --batch_size_valid 16 \ --lr 1e-4 \ --optimizer Adam \ --class_weighting None \ --encoder resnet34 \ --encoder_block NonBottleneck1D \ --nr_decoder_blocks 3 \ --modality rgbd \ --encoder_decoder_fusion add \ --context_module appm-1-2-4-8 \ --decoder_channels_mode decreasing \ --fuse_depth_in_rgb_encoder SE-add \ --upsampling learned-3x3-zeropad
如需更多信息,请使用 python train.py --help
或查看 src/args.py
。
要分析模型结构,请使用相同参数运行
model_to_onnx.py
以导出ONNX模型文件,该文件可以使用 Netron 进行可视化展示。