2023年11月更新:SePiCo被选为 :trophy: ESI高被引论文!!
2023年2月15日更新:发布Cityscapes → Dark Zurich的代码。
2023年1月14日更新:🥳 我们很高兴地宣布SePiCo已被TPAMI接收并将在即将出版的一期中发表。
2022年9月24日更新:所有检查点均已可用。
2022年9月4日更新:代码发布。
2022年4月20日更新:SePiCo的ArXiv版本已发布。
概述
在这项工作中,我们提出了语义引导的像素对比学习(SePiCo),这是一种新颖的单阶段适应框架,它突出了单个像素的语义概念,以促进跨域类别区分性和类别平衡的像素嵌入空间的学习,最终提升自训练方法的性能。
安装
本代码使用Python 3.8.5
和PyTorch 1.7.1
在CUDA 11.0
上实现。
要尝试这个项目,建议先设置一个虚拟环境:
# 创建并激活环境
conda create --name sepico -y python=3.8.5
conda activate sepico
# 为新的Python环境安装正确的pip和依赖项
conda install -y ipython pip
然后,可以通过以下方式安装依赖项:
# 安装所需的包
pip install -r requirements.txt
# 安装mmcv-full,此命令在本地编译mmcv,可能需要一些时间
pip install mmcv-full==1.3.7 # 需要先安装其他包
或者,可以使用官方预构建的包更快地安装mmcv-full
,例如:
# 另一种安装mmcv-full的方法,更快
pip install mmcv-full==1.3.7 -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.7.0/index.html
现在环境已经完全准备好了。
数据集准备
下载数据集
- GTAV: 从这里下载所有压缩的图像及其压缩的标签,并将它们解压到自定义目录。
- Cityscapes: 从这里下载leftImg8bit_trainvaltest.zip和gtFine_trainvaltest.zip,并将它们解压到自定义目录。
- Dark Zurich: 从这里下载Dark_Zurich_train_anon.zip、Dark_Zurich_val_anon.zip和Dark_Zurich_test_anon_withoutGt.zip,并将它们解压到自定义目录。
设置数据集
创建所需数据集的符号链接:
ln -s /path/to/gta5/dataset data/gta
ln -s /path/to/cityscapes/dataset data/cityscapes
ln -s /path/to/dark_zurich/dataset data/dark_zurich
进行预处理以将标签ID转换为训练ID并收集数据集统计信息:
python tools/convert_datasets/gta.py data/gta --nproc 8
python tools/convert_datasets/cityscapes.py data/cityscapes --nproc 8
最终,数据结构应该如下所示:
SePiCo
├── ...
├── data
│ ├── cityscapes
│ │ ├── gtFine
│ │ ├── leftImg8bit
│ ├── dark_zurich
│ │ ├── corresp
│ │ ├── gt
│ │ ├── rgb_anon
│ ├── gta
│ │ ├── images
│ │ ├── labels
├── ...
模型库
我们通过Google Drive和百度网盘(访问码:pico
)提供了两个域适应语义分割任务的预训练模型。
GTAV → Cityscapes (基于DeepLab-v2)
变体 | 模型名称 | mIoU | 检查点下载 |
---|---|---|---|
DistCL | sepico_distcl_gta2city_dlv2.pth | 61.0 | Google / 百度 (提取码: pico ) |
BankCL | sepico_bankcl_gta2city_dlv2.pth | 59.8 | Google / 百度 (提取码: pico ) |
ProtoCL | sepico_protocl_gta2city_dlv2.pth | 58.8 | Google / 百度 (提取码: pico ) |
GTAV → Cityscapes (基于DAFormer)
变体 | 模型名称 | mIoU | 检查点下载 |
---|---|---|---|
DistCL | sepico_distcl_gta2city_daformer.pth | 70.3 | Google / 百度 (提取码: pico ) |
BankCL | sepico_bankcl_gta2city_daformer.pth | 68.7 | Google / 百度 (提取码: pico ) |
ProtoCL | sepico_protocl_gta2city_daformer.pth | 68.5 | Google / 百度 (提取码: pico ) |
SYNTHIA → Cityscapes (基于DeepLab-v2)
变体 | 模型名称 | mIoU | 检查点下载 |
---|---|---|---|
DistCL | sepico_distcl_syn2city_dlv2.pth | 58.1 | Google / 百度 (提取码: pico ) |
BankCL | sepico_bankcl_syn2city_dlv2.pth | 57.4 | Google / 百度 (提取码: pico ) |
ProtoCL | sepico_protocl_syn2city_dlv2.pth | 56.8 | Google / 百度 (提取码: pico ) |
SYNTHIA → Cityscapes (基于DAFormer)
变体 | 模型名称 | mIoU | 检查点下载 |
---|---|---|---|
DistCL | sepico_distcl_syn2city_daformer.pth | 64.3 | 谷歌 / 百度 (提取码: pico ) |
BankCL | sepico_bankcl_syn2city_daformer.pth | 63.3 | 谷歌 / 百度 (提取码: pico ) |
ProtoCL | sepico_protocl_syn2city_daformer.pth | 62.9 | 谷歌 / 百度 (提取码: pico ) |
Cityscapes → Dark Zurich (基于DeepLab-v2)
变体 | 模型名称 | mIoU | 检查点下载 |
---|---|---|---|
DistCL | sepico_distcl_city2dark_dlv2.pth | 45.4 | 谷歌 / 百度 (提取码: pico ) |
BankCL | sepico_bankcl_city2dark_dlv2.pth | 44.1 | 谷歌 / 百度 (提取码: pico ) |
ProtoCL | sepico_protocl_city2dark_dlv2.pth | 42.6 | 谷歌 / 百度 (提取码: pico ) |
Cityscapes → Dark Zurich (基于DAFormer)
变体 | 模型名称 | mIoU | 检查点下载 |
---|---|---|---|
DistCL | sepico_distcl_city2dark_daformer.pth | 54.2 | 谷歌 / 百度 (提取码: pico ) |
BankCL | sepico_distcl_city2dark_daformer.pth | 53.3 | 谷歌 / 百度 (提取码: pico ) |
ProtoCL | sepico_distcl_city2dark_daformer.pth | 52.7 | 谷歌 / 百度 (提取码: pico ) |
我们训练的模型(sepico_distcl_city2dark_daformer.pth)也在Nighttime Driving和BDD100k-night测试集上进行了泛化性能测试。
方法 | 模型名称 | Dark Zurich-test | Nighttime Driving | BDD100k-night | 检查点下载 |
---|---|---|---|---|---|
SePiCo | sepico_distcl_city2dark_daformer.pth | 54.2 | 56.9 | 40.6 | 谷歌 / 百度 (提取码: pico ) |
SePiCo评估
在Cityscapes上评估
要在Cityscapes上评估预训练模型,请按如下方式运行:
python -m tools.test /path/to/config /path/to/checkpoint --eval mIoU
示例
例如,如果您将sepico_distcl_gta2city_dlv2.pth
及其配置JSON文件sepico_distcl_gta2city_dlv2.json
下载到文件夹./checkpoints/sepico_distcl_gta2city_dlv2/
中,那么评估脚本应该如下所示:
python -m tools.test ./checkpoints/sepico_distcl_gta2city_dlv2/sepico_distcl_gta2city_dlv2.json ./checkpoints/sepico_distcl_gta2city_dlv2/sepico_distcl_gta2city_dlv2.pth --eval mIoU
在Dark Zurich上评估
要在Dark Zurich上进行评估,请按如下方式获取标签预测结果,然后将其提交到官方的测试服务器。
在本地获取测试集的标签预测:
python -m tools.test /path/to/config /path/to/checkpoint --format-only --eval-options imgfile_prefix=/path/to/labelTrainIds
示例
例如,如果你将 `sepico_distcl_city2dark_daformer.pth` 及其配置 JSON 文件 `sepico_distcl_city2dark_daformer.json` 下载到文件夹 `./checkpoints/sepico_distcl_city2dark_daformer/` 中,那么评估脚本应该如下所示:python -m tools.test ./checkpoints/sepico_distcl_city2dark_daformer/sepico_distcl_city2dark_daformer.json ./checkpoints/sepico_distcl_city2dark_daformer/sepico_distcl_city2dark_daformer.pth --format-only --eval-options imgfile_prefix=dark_test/distcl_daformer/labelTrainIds
请注意,测试服务器只接受具有以下目录结构的提交:
submit.zip
├── confidence
├── labelTrainIds
├── labelTrainIds_invalid
因此,我们需要手动构建 confidence
和 labelTrainIds_invalid
目录(因为它们对 SePiCo 评估并非必需)。
以下是我们的参考做法(请参考上面的示例中的目录名):
cd dark_test/distcl_daformer
cp -r labelTrainIds labelTrainIds_invalid
cp -r labelTrainIds confidence
zip -q -r sepico_distcl_city2dark_daformer.zip labelTrainIds labelTrainIds_invalid confidence
# 现在将 sepico_distcl_city2dark_daformer.zip 提交到测试服务器以获取结果。
SePiCo 训练
首先,从这里下载 SegFormer 官方在 ImageNet-1k 上预训练的 MiT-B5 权重(即 mit_b5.pth
),并将其放入新文件夹 ./pretrained
中。
训练入口在 run_experiments.py
。要查看特定任务的设置,请查看 experiments.py
以获取更多详细信息。通常,训练脚本如下:
python run_experiments.py --exp <exp_id>
任务 1~6 在 GTAV → Cityscapes 上运行,<exp_id>
和任务的映射关系如下:
<exp_id> | 变体 | 骨干网络 | 特征 |
---|---|---|---|
1 | DistCL | ResNet-101 | layer-4 |
2 | BankCL | ResNet-101 | layer-4 |
3 | ProtoCL | ResNet-101 | layer-4 |
4 | DistCL | MiT-B5 | all-fusion |
5 | BankCL | MiT-B5 | all-fusion |
6 | ProtoCL | MiT-B5 | all-fusion |
任务 7~8 在 Cityscapes → Dark Zurich 上运行,<exp_id>
和任务的映射关系如下:
<exp_id> | 变体 | 骨干网络 | 特征 |
---|---|---|---|
7 | DistCL | ResNet-101 | layer-4 |
8 | DistCL | MiT-B5 | all-fusion |
训练完成后,可以按照 SePiCo 评估 进行模型测试。请注意,训练结果位于 ./work_dirs
中。配置文件名应类似:220827_1906_dlv2_proj_r101v1c_sepico_DistCL-reg-w1.0-start-iter3000-tau100.0-l3-w1.0_rcs0.01_cpl_self_adamw_6e-05_pmT_poly10warm_1x2_40k_gta2cs_seed76_4cc9a.json
,模型文件后缀为 .pth
。
代码理解提示
- 类平衡裁剪(CBC)策略在 mmseg/models/utils/ours_transforms.py 中以
RandomCrop
类的形式实现。 - 投影头可以在 mmseg/models/decode_heads/proj_head.py 中找到。
- 用于特征存储的语义原型在 mmseg/models/utils/proto_estimator.py 中实现,其中包括三种变体的原型。详细用法请参考
mmseg/models/uda/sepico.py
。 - 与我们框架的三种变体相对应的损失函数,以及正则化项,在 mmseg/models/losses/contrastive_loss.py 中实现。
致谢
本项目基于以下开源项目。我们感谢这些项目的作者公开源代码。
- MMSegmentation(Apache 许可证 2.0,许可详情)
- SegFormer(NVIDIA 源代码许可证,许可详情)
- DAFormer(Apache 许可证 2.0,许可详情)
- DACS(MIT 许可证,许可详情)
- DANNet(Apache 许可证 2.0,许可详情)
引用
如果您觉得我们的工作有帮助,请为本仓库点星🌟并引用📑我们的论文。感谢您的支持!
@article{xie2023sepico,
title={Sepico: Semantic-guided pixel contrast for domain adaptive semantic segmentation},
author={Xie, Binhui and Li, Shuang and Li, Mingjia and Liu, Chi Harold and Huang, Gao and Wang, Guoren},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
year={2023},
publisher={IEEE}
}
联系
如需帮助或与 SePiCo 相关的问题,或报告 bug,请开启一个 [GitHub Issues],或随时联系 binhuixie@bit.edu.cn。