在线测试时适应
这是一个基于PyTorch的开源在线测试时适应仓库。它是Robert A. Marsden和Mario Döbler的联合作品。这也是以下作品的官方仓库:
- 引入中间域以实现有效的测试时自训练
- 用于持续和渐进测试时适应的鲁棒平均教师 (CVPR2023)
- 通过权重集成、多样性加权和先验校正实现通用测试时适应 (WACV2024)
- 视觉语言模型的失去机会:视觉语言模型在线测试时适应的比较研究 (CVPR2024 MAT研讨会社区赛道)
引用
@article{marsden2022gradual,
title={Gradual test-time adaptation by self-training and style transfer},
author={Marsden, Robert A and D{\"o}bler, Mario and Yang, Bin},
journal={arXiv preprint arXiv:2208.07736},
year={2022}
}
@inproceedings{dobler2023robust,
title={Robust mean teacher for continual and gradual test-time adaptation},
author={D{\"o}bler, Mario and Marsden, Robert A and Yang, Bin},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={7704--7714},
year={2023}
}
@inproceedings{marsden2024universal,
title={Universal Test-time Adaptation through Weight Ensembling, Diversity Weighting, and Prior Correction},
author={Marsden, Robert A and D{\"o}bler, Mario and Yang, Bin},
booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},
pages={2555--2565},
year={2024}
}
@article{dobler2024lost,
title={A Lost Opportunity for Vision-Language Models: A Comparative Study of Online Test-time Adaptation for Vision-Language Models},
author={D{\"o}bler, Mario and Marsden, Robert A and Raichle, Tobias and Yang, Bin},
journal={arXiv preprint arXiv:2405.14977},
year={2024}
}
我们欢迎贡献!非常欢迎并感谢添加方法的拉取请求。
先决条件
要使用这个仓库,我们提供了一个conda环境。
conda update conda
conda env create -f environment.yml
conda activate tta
分类
特性
这个仓库包含了一系列不同的方法、数据集、模型和设置,我们在一个全面的基准测试中对其进行了评估(见下文)。我们还提供了一个关于如何将此仓库与CLIP类模型结合使用的教程,可以在这里找到。 以下是仓库主要特性的简要概述:
-
数据集
cifar10_c
CIFAR10-Ccifar100_c
CIFAR100-Cimagenet_c
ImageNet-Cimagenet_a
ImageNet-Aimagenet_r
ImageNet-Rimagenet_v2
ImageNet-V2imagenet_k
ImageNet-Sketchimagenet_d
ImageNet-Dimagenet_d109
domainnet126
DomainNet (清洗后)持续变化的损坏
CCC
-
模型
- 对于适应ImageNet变体,可以使用Torchvision或timm中所有可用的预训练模型。
- 对于损坏基准测试,可以使用RobustBench的预训练模型。
- 对于DomainNet-126基准测试,每个域都有一个预训练模型。
- 其他模型包括ResNet-26 GN。
- 还可以使用OpenCLIP提供的模型。
-
设置
reset_each_shift
适应一个域后重置模型状态。continual
在一系列域上训练模型,不知道域转移何时发生。gradual
在一系列逐渐增加/减少的域转移上训练模型,不知道域转移何时发生。mixed_domains
在一个长测试序列上训练模型,其中连续的测试样本可能来自不同的域。correlated
与持续设置相同,但每个域的样本进一步按类别标签排序。mixed_domains_correlated
混合域并按类别标签排序。- 也可以组合使用,如
gradual_correlated
或reset_each_shift_correlated
。
-
方法
-
混合精度训练
- 除了SAR和GTTA之外,几乎所有上述方法都可以使用混合精度进行训练。这大大加快了实验速度并减少了内存需求。但是,所有基准测试结果都是使用fp32生成的。
-
模块化设计
- 得益于模块化设计,添加新方法应该相当简单。
开始使用
要运行以下基准测试之一,需要下载相应的数据集。
- CIFAR10-to-CIFAR10-C:数据会自动下载。
- CIFAR100-to-CIFAR100-C:数据会自动下载。
- ImageNet-to-ImageNet-C:对于非源自由方法,下载ImageNet和ImageNet-C。
- ImageNet-to-ImageNet-A:对于非源自由方法,下载ImageNet和ImageNet-A。
- ImageNet-to-ImageNet-R:对于非源自由方法,下载ImageNet和ImageNet-R。
- ImageNet-to-ImageNet-V2:对于非源自由方法,下载ImageNet和ImageNet-V2。
- ImageNet-to-ImageNet-Sketch:对于非源自由方法,下载ImageNet和ImageNet-Sketch。
- ImageNet-to-ImageNet-D:对于非源自由方法,下载ImageNet。对于ImageNet-D,请参阅下面的DomainNet-126下载说明。ImageNet-D是通过符号链接创建的,首次使用时会进行设置。
- ImageNet-to-ImageNet-D109:参见下面的DomainNet-126说明。
- DomainNet-126:下载清理版本的6个分割。按照MME的做法,DomainNet-126只使用包含来自4个领域的126个类的子集。
- ImageNet-to-CCC:对于非源自由方法,下载ImageNet。CCC作为网络数据集集成,无需下载!请注意,它不能与相关等设置结合使用。
下载缺失的数据集后,您可能需要调整位于conf.py
文件中的根目录路径_C.DATA_DIR = "./data"
。对于各个数据集,目录名称在conf.py
中以字典形式指定(参见complete_data_dir_path
函数)。如果您的目录名称与映射字典中指定的不同,您可以简单地修改它们。
运行实验
我们为所有实验和方法提供了配置文件。只需使用相应的配置文件运行以下Python文件。
python test_time.py --cfg cfgs/[ccc/cifar10_c/cifar100_c/imagenet_c/imagenet_others/domainnet126]/[source/norm_test/norm_alpha/tent/memo/rpl/eta/eata/rdumb/sar/cotta/rotta/adacontrast/lame/gtta/rmt/roid/tpt].yaml
对于imagenet_others,需要传递CORRUPTION.DATASET
参数:
python test_time.py --cfg cfgs/imagenet_others/[source/norm_test/norm_alpha/tent/memo/rpl/eta/eata/rdumb/sar/cotta/rotta/adacontrast/lame/gtta/rmt/roid/tpt].yaml CORRUPTION.DATASET [imagenet_a/imagenet_r/imagenet_k/imagenet_v2/imagenet_d109]
例如,要运行ROID进行ImageNet-to-ImageNet-R基准测试,请运行以下命令。
python test_time.py --cfg cfgs/imagenet_others/roid.yaml CORRUPTION.DATASET imagenet_r
或者,您可以通过运行classification/scripts
子目录中的run.sh
来重现我们的实验。对于不同的设置,修改run.sh
中的setting
。
要运行不同的连续DomainNet-126序列,您必须传递MODEL.CKPT_PATH
参数。如果不指定CKPT_PATH
,将使用以real域作为源域的序列。这些检查点由AdaContrast提供,可以在这里下载。从结构上讲,最好将它们下载到./ckpt/domainnet126
目录中。
python test_time.py --cfg cfgs/domainnet126/rmt.yaml MODEL.CKPT_PATH ./ckpt/domainnet126/best_clipart_2020.pth
对于GTTA,我们提供了风格转换网络的检查点文件。这些检查点可在
Google-Drive(下载);
将zip文件解压到classification
子目录中。
更改配置
更改评估配置非常简单。例如,要在reset_each_shift
设置下使用ResNet-50和IMAGENET1K_V1
初始化在ImageNet-to-ImageNet-C上运行TENT,需要传递以下参数。
更多模型和初始化可以在这里(torchvision)或这里(timm)找到。
python test_time.py --cfg cfgs/imagenet_c/tent.yaml MODEL.ARCH resnet50 MODEL.WEIGHTS IMAGENET1K_V1 SETTING reset_each_shift
对于ImageNet-C,robustbench提供的默认图像列表每个域考虑5000个样本
(参见这里)。如果你有兴趣在全部
50,000个测试样本上运行实验,只需设置CORRUPTION.NUM_EX 50000
,即
python test_time.py --cfg cfgs/imagenet_c/roid.yaml CORRUPTION.NUM_EX 50000
混合精度
我们支持大多数方法使用损失缩放的自动混合精度更新。
默认情况下混合精度设置为false。要激活混合精度,设置参数MIXED_PRECISION True
。
基准测试
我们在这里提供了每种方法使用不同模型和设置的详细结果, 基准测试会定期更新,随着新方法、数据集或设置添加到仓库中。 关于设置或模型的更多信息也可以在我们的论文中找到。
致谢
- Robustbench 官方
- CoTTA 官方
- TENT 官方
- AdaContrast 官方
- EATA 官方
- LAME 官方
- MEMO 官方
- RoTTA 官方
- SAR 官方
- RDumb 官方
- CMF 官方
- DeYO 官方
- TPT 官方
分割
要运行基于CarlaTTA的实验,你首先需要下载如下提供的数据集分割。同样,你可能需要在conf.py
中更改数据目录_C.DATA_DIR = "./data"
。此外,你需要下载预训练的源检查点(下载)并将zip文件解压到segmentation
子目录中。
例如,要运行GTTA,使用cfgs
目录中提供的配置文件并运行:
python test_time.py --cfg cfgs/gtta.yaml
你也可以通过设置LIST_NAME_TEST
来更改测试序列:
- day2night:
day_night_1200.txt
- clear2fog:
clear_fog_1200.txt
- clear2rain:
clear_rain_1200.txt
- dynamic:
dynamic_1200.txt
- highway:
town04_dynamic_1200.txt
如果你选择highway作为测试序列,你需要更改源列表和相应的检查点路径。
python test_time.py --cfg cfgs/gtta.yaml LIST_NAME_SRC clear_highway_train.txt LIST_NAME_TEST town04_dynamic_1200.txt CKPT_PATH_SEG ./ckpt/clear_highway/ckpt_seg.pth CKPT_PATH_ADAIN_DEC = ./ckpt/clear_highway/ckpt_adain.pth
CarlaTTA
我们在Google-Drive上提供了CarlaTTA的不同数据集作为单独的zip文件:
- clear 下载
- day2night 下载
- clear2fog 下载
- clear2rain 下载
- dynamic 下载
- dynamic-slow 下载
- clear-highway 下载
- highway 下载