TorchSSL已弃用并不再维护。请参考**USB,这是TorchSSL的升级版本。使用USB训练只需TorchSSL训练时间的12.5%**,并产生更好的结果。
TorchSSL代码的后续修改和合并的拉取请求将不再重新运行以更新结果。
TorchSSL
一个基于Pytorch的半监督学习工具箱。这也是发表在NeurIPS 2021上的FlexMatch: 使用课程伪标签提升半监督学习的官方实现。[arXiv] [知乎文章] [视频]
新闻和更新
2023年1月31日
- 我们添加了[FreeMatch]和[SoftMatch]的代码。请注意,freematch.py指的是没有SAF的FreeMatch,而freematch_entropy.py指的是同时具有SAT和SAF的FreeMatch。训练日志也可以在https://1drv.ms/u/s!AlpW9hcyb0KvmyCfsCjGvhDXG5Nb?e=Xc6amH 找到
2022年8月17日
- TorchSSL(本仓库)不再维护和更新。我们已经创建/更新了一个更全面的半监督学习代码库和基准 - USB。它基于TorchSSL构建,但使用更灵活、更易扩展,包含跨计算机视觉、自然语言处理和音频处理的数据集。
2021年2月15日
- 日志和模型权重已共享!我们注意到一些模型权重缺失。我们将尝试在未来上传缺失的模型权重。
- BestAcc的结果已更新!我对CIFAR-10和SVHN使用单个P100,对STL-10使用单个P40,对CIFAR-100使用单个V100-32G。
介绍
TorchSSL是一个基于PyTorch的全能半监督学习(SSL)工具包。目前,我们实现了9种流行的SSL算法,以实现公平比较并促进SSL算法的发展。
支持的算法: 除了全监督(作为基线)外,TorchSSL还支持以下流行算法:
- PiModel (NeurIPS 2015) [1]
- MeanTeacher (NeurIPS 2017) [2]
- PseudoLabel (ICML 2013) [3]
- VAT (虚拟对抗训练, TPAMI 2018) [4]
- MixMatch (NeurIPS 2019) [5]
- UDA (无监督数据增强, NeurIPS 2020) [6]
- ReMixMatch (ICLR 2019) [7]
- FixMatch (NeurIPS 2020) [8]
- FlexMatch (NeurIPS 2021) [9]
- FreeMatch (ICLR 2023) [10]
- SoftMatch (ICLR 2023) [11]
此外,我们为Pseudo-Label(Flex-Pseudo-Label)和UDA(Flex-UDA)实现了我们的课程伪标签(CPL)方法。
支持的数据集: TorchSSL目前支持SSL研究中的5个流行数据集:
- CIFAR-10
- CIFAR-100
- STL-10
- SVHN
- ImageNet
主要结果
结果是最佳准确率及其标准误差。在结果中,数据集行下的"40"、"250"、"1000"等表示不同的标记样本数量(例如,CIFAR-10中的"40"表示每个类别只有4个标记样本)。我们对所有实验使用随机种子0,1,2。所有配置都包含在config/
文件夹下。您可以在自己的研究中直接引用这些结果。
请注意,FullySupervised结果来自使用数据集中所有训练数据训练模型,不考虑表中标注的标签数量。
CIFAR-10和CIFAR-100
CIFAR-10 | CIFAR100 | ||||||
---|---|---|---|---|---|---|---|
40 | 250 | 4000 | 400 | 2500 | 10000 | ||
FullySupervised | 95.38±0.05 | 95.39±0.04 | 95.38±0.05 | 80.7±0.09 | 80.7±0.09 | 80.73±0.05 | |
PiModel [1] | 25.66±1.76 | 53.76±1.29 | 86.87±0.59 | 13.04±0.8 | 41.2±0.66 | 63.35±0.0 | |
PseudoLabel [3] | 25.39±0.26 | 53.51±2.2 | 84.92±0.19 | 12.55±0.85 | 42.26±0.28 | 63.45±0.24 | |
PseudoLabel_Flex [9] | 26.26±1.96 | 53.86±1.81 | 85.25±0.19 | 14.28±0.46 | 43.88±0.51 | 64.4±0.15 | |
MeanTeacher [2] | 29.91±1.6 | 62.54±3.3 | 91.9±0.21 | 18.89±1.44 | 54.83±1.06 | 68.25±0.23 | |
VAT [4] | 25.34±2.12 | 58.97±1.79 | 89.49±0.12 | 14.8±1.4 | 53.16±0.79 | 67.86±0.19 | |
MixMatch [5] | 63.81±6.48 | 86.37±0.59 | 93.34±0.26 | 32.41±0.66 | 60.24±0.48 | 72.22±0.29 | |
ReMixMatch [7] | 90.12±1.03 | 93.7±0.05 | 95.16±0.01 | 57.25±1.05 | 73.97±0.35 | 79.98±0.27 | |
UDA [6] | 89.38±3.75 | 94.84±0.06 | 95.71±0.07 | 53.61±1.59 | 72.27±0.21 | 77.51±0.23 | |
UDA_Flex [9] | 94.56±0.52 | 94.98±0.07 | 95.76±0.06 | 54.83±1.88 | 72.92±0.15 | 78.09±0.1 | |
FixMatch [8] | 92.53±0.28 | 95.14±0.05 | 95.79±0.08 | 53.58±0.82 | 71.97±0.16 | 77.8±0.12 | |
FlexMatch [9] | 95.03±0.06 | 95.02±0.09 | 95.81±0.01 | 60.06±1.62 | 73.51±0.2 | 78.1±0.15 |
STL-10和SVHN
STL-10 | SVHN | ||||||
---|---|---|---|---|---|---|---|
40 | 250 | 1000 | 40 | 250 | 1000 | ||
FullySupervised | None | None | None | 97.87±0.02 | 97.87±0.01 | 97.86±0.01 | |
PiModel [1] | 25.69±0.85 | 44.87±1.5 | 67.22±0.4 | 32.52±0.95 | 86.7±1.12 | 92.84±0.11 | |
PseudoLabel [3] | 25.32±0.99 | 44.55±2.43 | 67.36±0.71 | 35.39±5.6 | 84.41±0.95 | 90.6±0.32 | |
PseudoLabel_Flex [9] | 26.58±2.19 | 47.94±2.5 | 67.95±0.37 | 36.79±3.64 | 79.58±2.11 | 87.95±0.54 | |
MeanTeacher [2] | 28.28±1.45 | 43.51±2.75 | 66.1±1.37 | 63.91±3.98 | 96.55±0.03 | 96.73±0.05 | |
VAT [4] | 25.26±0.38 | 43.58±1.97 | 62.05±1.12 | 25.25±3.38 | 95.67±0.12 | 95.89±0.2 | |
MixMatch [5] | 45.07±0.96 | 65.48±0.32 | 78.3±0.68 | 69.4±8.39 | 95.44±0.32 | 96.31±0.37 | |
ReMixMatch [7] | 67.88±6.24 | 87.51±1.28 | 93.26±0.14 | 75.96±9.13 | 93.64±0.22 | 94.84±0.31 | |
UDA [6] | 62.58±8.44 | 90.28±1.15 | 93.36±0.17 | 94.88±4.27 | 98.08±0.05 | 98.11±0.01 | |
UDA_Flex [9] | 70.47±2.1 | 90.97±0.45 | 93.9±0.25 | 96.58±1.51 | 97.34±0.83 | 97.98±0.05 | |
FixMatch [8] | 64.03±4.14 | 90.19±1.04 | 93.75±0.33 | 96.19±1.18 | 97.98±0.02 | 98.04±0.03 | |
FlexMatch [9] | 70.85±4.16 | 91.77±0.39 | 94.23±0.18 | 91.81±3.2 | 93.41±2.29 | 93.28±0.3 |
ImageNet
10万标签 | ||
---|---|---|
top-1 | top-5 | |
FixMatch [8] | 56.34 | 78.20 |
FlexMatch [9] | 58.15 | 80.52 |
日志和权重
您可以在这里下载共享的日志和权重。
https://1drv.ms/u/s!AlpW9hcyb0KvmyCfsCjGvhDXG5Nb?e=Xc6amH
使用方法
在运行或修改代码之前,您需要:
- 将此仓库克隆到您的机器上。
- 确保已安装Anaconda或Miniconda。
- 运行
conda env create -f environment.yml
进行环境初始化。
运行实验
使用TorchSSL进行实验非常方便。例如,如果您想运行FlexMatch算法:
- 根据需要修改
config/flexmatch/flexmatch.yaml
中的配置文件 - 运行
python flexmatch.py --c config/flexmatch/flexmatch.yaml
自定义
如果您想编写自己的算法,请按以下步骤操作:
- 为您的算法创建一个目录,例如
SSL
,在其中编写您自己的模型文件SSl/SSL.py
。 - 在
SSL.py
中编写训练文件 - 在
config/SSL/SSL.yaml
中编写配置文件
引用TorchSSL
如果您认为这个工具包或结果对您和您的研究有帮助,请引用我们的论文:
@article{wang2023freematch,
title={FreeMatch: Self-adaptive Thresholding for Semi-supervised Learning},
author={Wang, Yidong and Chen, Hao and Heng, Qiang and Hou, Wenxin and Fan, Yue and and Wu, Zhen and Wang, Jindong and Savvides, Marios and Shinozaki, Takahiro and Raj, Bhiksha and Schiele, Bernt and Xie, Xing},
booktitle={International Conference on Learning Representations (ICLR)},
year={2023}
}
@article{chen2023softmatch,
title={SoftMatch: Addressing the Quantity-Quality Trade-off in Semi-supervised Learning},
author={Chen, Hao and Tao, Ran and Fan, Yue and Wang, Yidong and Wang, Jindong and Schiele, Bernt and Xie, Xing and Raj, Bhiksha and Savvides, Marios},
booktitle={International Conference on Learning Representations (ICLR)},
year={2023}
}
@article{zhang2021flexmatch,
title={FlexMatch: Boosting Semi-supervised Learning with Curriculum Pseudo Labeling},
author={Zhang, Bowen and Wang, Yidong and Hou, Wenxin and Wu, Hao and Wang, Jindong and Okumura, Manabu and Shinozaki, Takahiro},
booktitle={Neural Information Processing Systems (NeurIPS)},
year={2021}
}
维护者
王一东1,陈浩2,范越3,吴昊1,张博文1,侯文欣1,4,陈宇豪5,王晋东4
东京工业大学1
卡内基梅隆大学2
马克斯·普朗克信息学研究所3
微软亚洲研究院4
旷视科技5
贡献
- 欢迎您就错误、问题和建议提出issue。
- 如果您想加入TorchSSL团队,请发送电子邮件给王一东(646842131@qq.com; yidongwang37@gmail.com)以获取更多信息。我们计划添加更多SSL算法,并将TorchSSL从计算机视觉扩展到自然语言处理和语音领域。
声明
对于ImageNet数据集: 请从官方网站下载ImageNet 2014数据集(与2012年相同)(链接:https://image-net.org/challenges/LSVRC/2012/2012-downloads.php)
将训练集和验证集提取到子文件夹中(不使用测试集),并分别将它们放在train/
和val/
下。每个子文件夹将代表一个类别。
注意:官方验证集未按类别压缩到子文件夹中,您可能想使用:https://github.com/jiweibo/ImageNet/blob/master/valprep.sh,这是一个用于准备文件结构的不错脚本。
参考文献
[1] Antti Rasmus, Harri Valpola, Mikko Honkala, Mathias Berglund, and Tapani Raiko. 使用梯度网络进行半监督学习。NeurIPS,第3546-3554页,2015年。
[2] Antti Tarvainen and Harri Valpola. 平均教师是更好的角色模型:权重平均一致性目标改善半监督深度学习结果。NeurIPS,第1195-1204页,2017年。
[3] Dong-Hyun Lee等人。伪标签:深度神经网络的简单高效半监督学习方法。ICML表示学习挑战研讨会,第3卷,2013年。
[4] Takeru Miyato, Shin-ichi Maeda, Masanori Koyama, and Shin Ishii. 虚拟对抗训练:监督和半监督学习的正则化方法。IEEE TPAMI,41(8):1979-1993,2018年。
[5] David Berthelot, Nicholas Carlini, Ian Goodfellow, Nicolas Papernot, Avital Oliver, and Colin Raffel. Mixmatch:半监督学习的整体方法。NeurIPS,第5050-5060页,2019年。
[6] Qizhe Xie, Zihang Dai, Eduard Hovy, Thang Luong, and Quoc Le. 无监督数据增强用于一致性训练。NeurIPS,33,2020年。
[7] David Berthelot, Nicholas Carlini, Ekin D Cubuk, Alex Kurakin, Kihyuk Sohn, Han Zhang, and Colin Raffel. Remixmatch:具有分布匹配和增强锚定的半监督学习。ICLR,2019年。
[8] Kihyuk Sohn, David Berthelot, Nicholas Carlini, Zizhao Zhang, Han Zhang, Colin A Raffel, Ekin Dogus Cubuk, Alexey Kurakin, and Chun-Liang Li. Fixmatch:通过一致性和置信度简化半监督学习。NeurIPS,33,2020年。
[9] Bowen Zhang, Yidong Wang, Wenxin Hou, Hao wu, Jindong Wang, Okumura Manabu, and Shinozaki Takahiro. FlexMatch:通过课程伪标签提升半监督学习。NeurIPS,2021年。
[10] Yidong Wang, Hao Chen, Qiang Heng, Wenxin Hou, Yue Fan, Zhen Wu, Jindong Wang, Marios Savvides, Takahiro Shinozaki, Bhiksha Raj, Bernt Schiele, Xing Xie. FreeMatch:半监督学习的自适应阈值。ICLR,2023年。
[11] Hao Chen, Ran Tao, Yue Fan, Yidong Wang, Marios Savvides, Jindong Wang, Bhiksha Raj, Xing Xie, Bernt Schiele. SoftMatch:解决半监督学习中数量-质量权衡问题。ICLR,2023年。