增强记忆效应的对比学习
关键词: 长尾识别、自监督学习、记忆效应
ICML 2022
@inproceedings{zhou2022contrastive,
title={Contrastive Learning with Boosted Memorization},
author={Zhou, Zhihan and Yao, Jiangchao and Wang, Yan-Feng and Han, Bo and Zhang, Ya},
booktitle={International Conference on Machine Learning},
pages={27367--27377},
year={2022},
organization={PMLR}
}
摘要: 自监督学习在视觉和文本数据的表示学习中取得了巨大成功。然而,当前的方法主要在精心策划的数据集上进行验证,这些数据集并不展现真实世界的长尾分布。最近有尝试从损失角度或模型角度进行重平衡来考虑自监督长尾学习,类似于有监督长尾学习中的范式。然而,由于缺乏标签辅助,这些探索在尾部样本发现或启发式结构设计方面受到限制,未能显示预期的显著前景。与之前的工作不同,我们从另一个角度探索这个方向,即数据角度,并提出了一种新颖的增强对比学习(BCL)方法。具体而言,BCL利用深度神经网络的记忆效应来自动驱动对比学习中样本视图的信息差异,这在无标签背景下更有效地增强长尾学习。在一系列基准数据集上的大量实验表明,BCL相对于几种最先进的方法具有显著优势。
快速入门
环境
- Python (3.7.10)
- Pytorch (1.7.1)
- torchvision (0.8.2)
- CUDA
- Numpy
文件结构
完成准备工作后,整个项目应具有以下结构:
./Boosted-Contrastive-Learning
├── README.md
├── data # 数据集和数据增强
│ ├── memoboosted_cifar100.py
│ ├── cifar100.py
│ ├── augmentations.py
│ └── randaug.py
├── models # 模型和骨干网络
│ ├── simclr.py
│ ├── sdclr.py
│ ├── resnet.py
│ ├── resnet_prune_multibn.py
│ └── utils.py
├── losses # 损失函数
│ └── nt_xent.py
├── split # 数据划分
│ ├── cifar100
│ └── cifar100_imbSub_with_subsets
├── eval_cifar.py # 线性探测评估代码
├── test.py # 测试代码
├── train.py # 训练代码
├── train_sdclr.py # SDCLR训练代码
└── utils.py # 工具函数
代码预览
BCL的代码片段如下所示。
train_datasets = memoboosted_CIFAR100(train_idx_list, args, root=args.data_folder, train=True)
# 初始化动量损失
shadow = torch.zeros(dataset_total_num).cuda()
momentum_loss = torch.zeros(args.epochs,dataset_total_num).cuda()
shadow, momentum_loss = train(train_loader, model, optimizer, scheduler, epoch, log, shadow, momentum_loss, args=args)
train_datasets.update_momentum_weight(momentum_loss, epoch)
在训练阶段,跟踪动量损失。
if epoch>1:
new_average = (1.0 - args.momentum_loss_beta) * loss[batch_idx].clone().detach() + args.momentum_loss_beta * shadow[index[batch_idx]]
else:
new_average = loss[batch_idx].clone().detach()
shadow[index[batch_idx]] = new_average
momentum_loss[epoch-1,index[batch_idx]] = new_average
训练
要在CIFAR-100-LT上训练模型,只需运行:
- SimCLR
python train.py SimCLR --lr 0.5 --epochs 2000 --temperature 0.2 --weight_decay 5e-4 --data_folder ${data_folder} --trainSplit cifar100_imbSub_with_subsets/cifar100_split1_D_i.npy
- BCL-I
python train.py BCL_I --bcl --rand_k 1 --lr 0.5 --epochs 2000 --temperature 0.2 --weight_decay 5e-4 --data_folder ${data_folder} --trainSplit cifar100_imbSub_with_subsets/cifar100_split1_D_i.npy
- SDCLR
python train_sdclr.py SDCLR --lr 0.5 --epochs 2000 --temperature 0.2 --weight_decay 1e-4 --data_folder ${data_folder} --trainSplit cifar100_imbSub_with_subsets/cifar100_split1_D_i.npy
- BCL-D
python train_sdclr.py BCL_D --bcl --rand_k 2 --lr 0.5 --epochs 2000 --temperature 0.2 --weight_decay 1e-4 --data_folder ${data_folder} --trainSplit cifar100_imbSub_with_subsets/cifar100_split1_D_i.npy
预训练的检查点将保存在"checkpoints/"中。
评估
要评估预训练模型,只需运行:
- SimCLR, BCL-I
python test.py --checkpoint ${checkpoint_pretrain} --test_fullshot --test_100shot --test_50shot --data_folder ${data_folder}
- SDCLR, BCL-D
python test.py --checkpoint ${checkpoint_pretrain} --prune --test_fullshot --test_100shot --test_50shot --data_folder ${data_folder}
代码将输出全样本/100样本/50样本线性探测评估的结果。
结果和预训练模型
我们提供在"cifar100_split1_D_i.npy"上预训练的全样本/100样本/50样本结果(演示),以及相应的检查点权重。
方法 | 全样本 | 100样本 | 50样本 | 模型 |
---|---|---|---|---|
SimCLR | 50.7 | 46.3 | 42.4 | ResNet18 |
SDCLR | 55.0 | 49.7 | 45.6 | ResNet18 |
BCL-I | 55.7 | 50.1 | 45.8 | ResNet18 |
BCL-D | 58.7 | 52.6 | 48.7 | ResNet18 |
下载检查点后,您可以按照评估部分的说明运行评估。
扩展
实现自己的模型的步骤
- 将您的模型添加到./models并在train.py中加载模型。
- 在train.py中实现特定于您模型的函数(./losses)。
实现其他数据集的步骤
- 创建数据集的长尾划分并添加到./split。
- 实现数据集(例如memoboosted_cifar100.py)。
致谢
我们借鉴了SDCLR、RandAugment和W-MSE的部分代码。