AdaFace:用于人脸识别的自适应质量边界
AdaFace:用于人脸识别的自适应质量边界的官方GitHub仓库。 该论文(https://arxiv.org/abs/2204.00964)在CVPR 2022会议上进行了口头报告。
摘要:低质量人脸数据集的识别具有挑战性,因为面部特征被遮挡和降级。基于边界的损失函数的进展提高了嵌入空间中人脸的可区分性。此外,先前的研究探讨了自适应损失的效果,以赋予错误分类(困难)样本更多的重要性。在本工作中,我们在损失函数中引入了另一个自适应方面,即图像质量。我们认为,强调错误分类样本的策略应根据其图像质量进行调整。具体来说,容易和困难样本的相对重要性应基于样本的图像质量。我们提出了一种新的损失函数,根据样本的图像质量强调不同难度的样本。我们的方法通过用特征范数近似图像质量,以自适应边界函数的形式实现这一点。大量实验表明,我们的方法AdaFace在四个数据集(IJB-B、IJC-C、IJB-S和TinyFace)上改进了人脸识别性能,超越了现有最先进的技术(SoTA)。
@inproceedings{kim2022adaface,
title={AdaFace: Quality Adaptive Margin for Face Recognition},
author={Kim, Minchul and Jain, Anil K and Liu, Xiaoming},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2022}
}
仓库更新
- CVLFace是AdaFace的新官方仓库(支持各种架构,如ViT、SWIN-ViT、KP-RPE等)。
- 在CVLFace中添加了PartialFC AdaFace实现(以及更多与人脸识别相关的功能,用于开展研究)。
- 测试了与Pytorch Lightning 1.8的兼容性。
- 上传了5分钟视频演示。
- 添加了直接使用InsightFace数据集(train.rec)文件进行训练的选项,无需提取图像。
新闻
- 您还可以查看我们的新论文
KP-RPE: KeyPoint Relative Position Encoding for Face Recognition
论文 视频 代码,用于面部特征点辅助的人脸识别。简而言之:使用面部特征点进行对齐鲁棒性的人脸识别。(在TinyFace、IJB-S上达到SoTA)。 - 您还可以查看我们的新论文
Cluster and Aggregate (CAFace, NeurIPS2022)
链接,用于基于视频的人脸识别。简而言之:长视频探针的人脸识别。
5分钟视频演示
- https://www.youtube.com/watch?v=NfHzn6epAHM
- 该演讲在CVPR 2022会议期间进行。 感谢所有在口头报告和海报展示环节对论文表示兴趣的人。
AdaFace和ArcFace在低质量图像上的演示比较
该演示展示了AdaFace和ArcFace在实时视频上的比较。
为了展示模型在低质量图像上的表现,我们展示了原始
、模糊+
和模糊++
设置,其中
模糊++
表示heavily模糊。
带有颜色框的数字显示了实时图像与最接近匹配的图库图像之间的余弦相似度。
底部的统计数据显示了模糊++
设置下真阳性匹配的累计计数。
AdaFace具有较高的真阳性率。
它还表明它不太容易出现误报(红色)错误,这在ArcFace中有时会观察到。
使用方法
import torch
from head import AdaFace
# 典型的512维输入
B = 5
embbedings = torch.randn((B, 512)).float() # 潜在编码
norms = torch.norm(embbedings, 2, -1, keepdim=True)
normalized_embedding = embbedings / norms
labels = torch.randint(70722, (B,))
# 实例化AdaFace
adaface = AdaFace(embedding_size=512,
classnum=70722,
m=0.4,
h=0.333,
s=64.,
t_alpha=0.01,)
# 计算损失
cosine_with_margin = adaface(normalized_embedding, norms, labels)
loss = torch.nn.CrossEntropyLoss()(cosine_with_margin, labels)
安装
conda create --name adaface pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=10.2 -c pytorch
conda activate adaface
conda install scikit-image matplotlib pandas scikit-learn
pip install -r requirements.txt
训练(准备数据集和训练脚本)
- 请参阅README_TRAIN.md
- [重要]请注意,我们的实现假设模型的输入是
BGR
颜色通道,如cv2
包中所使用的。InsightFace模型假设RGB
颜色通道,如PIL
包中所使用的。因此,我们所有的评估代码都使用cv2
包的BGR
颜色通道。
预训练模型
请注意,我们的预训练模型以BGR颜色通道作为输入。 这与InsightFace发布的使用RGB颜色通道的模型不同。
架构 | 数据集 | 链接 |
---|---|---|
R18 | CASIA-WebFace | 谷歌云盘 |
R18 | VGGFace2 | 谷歌云盘 |
R18 | WebFace4M | 谷歌云盘 |
R50 | CASIA-WebFace | 谷歌云盘 |
R50 | WebFace4M | 谷歌云盘 |
R50 | MS1MV2 | 谷歌云盘 |
R100 | MS1MV2 | 谷歌云盘 |
R100 | MS1MV3 | 谷歌云盘 |
R100 | WebFace4M | 谷歌云盘 |
R100 | WebFace12M | 谷歌云盘 |
推理
使用提供的样本图像的示例
AdaFace接受经过预处理的输入图像。 预处理步骤包括:
- 使用面部关键点对齐(使用MTCNN)
- 裁剪为112x112x3大小,颜色通道为BGR顺序。
我们提供了执行预处理步骤的代码。 要使用预训练的AdaFace模型进行推理,请按以下步骤操作:
-
下载预训练的AdaFace模型并将其放置在
pretrained/
目录中 -
要在以下3张图像上使用预训练的AdaFace,请运行
python inference.py
图像1 | 图像2 | 图像3 |
---|---|---|
相似度得分结果应为:
tensor([[ 1.0000, 0.7334, -0.0655],
[ 0.7334, 1.0000, -0.0277],
[-0.0655, -0.0277, 1.0000]], grad_fn=<MmBackward0>)
通用推理指南
简而言之,推理代码如下所示:
from face_alignment import align
from inference import load_pretrained_model, to_input
model = load_pretrained_model('ir_50')
path = '图像路径'
aligned_rgb_img = align.get_aligned_face(path)
bgr_input = to_input(aligned_rgb_img)
feature, _ = model(bgr_input)
- 请注意,AdaFace模型是一个普通的PyTorch模型,它接受
bgr_input
,这是一个112x112x3大小的torch张量,具有BGR颜色通道,其值使用mean=0.5
和std=0.5
进行归一化,如to_input()所示。 - 当预处理步骤出现错误时,可能是因为MTCNN无法在图像中找到人脸。请参考issues/28进行讨论。
验证
高质量图像验证集(LFW、CFPFP、CPLFW、CALFW、AGEDB)
要使用预训练模型在5个高质量图像验证集上进行评估,请参考:
bash validation_hq/eval_5valsets.sh
架构 | 数据集 | 方法 | LFW | CFPFP | CPLFW | CALFW | AGEDB | 平均 |
---|---|---|---|---|---|---|---|---|
R18 | CASIA-WebFace | AdaFace | 0.9913 | 0.9259 | 0.8700 | 0.9265 | 0.9272 | 0.9282 |
R18 | VGGFace2 | AdaFace | 0.9947 | 0.9713 | 0.9172 | 0.9390 | 0.9407 | 0.9526 |
R18 | WebFace4M | AdaFace | 0.9953 | 0.9726 | 0.9228 | 0.9552 | 0.9647 | 0.9621 |
R50 | CASIA-WebFace | AdaFace | 0.9942 | 0.9641 | 0.8997 | 0.9323 | 0.9438 | 0.9468 |
R50 | MS1MV2 | AdaFace | 0.9982 | 0.9786 | 0.9283 | 0.9607 | 0.9785 | 0.9688 |
R50 | WebFace4M | AdaFace | 0.9978 | 0.9897 | 0.9417 | 0.9598 | 0.9778 | 0.9734 |
R100 | MS1MV2 | AdaFace | 0.9982 | 0.9849 | 0.9353 | 0.9608 | 0.9805 | 0.9719 |
R100 | MS1MV3 | AdaFace | 0.9978 | 0.9891 | 0.9393 | 0.9602 | 0.9817 | 0.9736 |
R100 | WebFace4M | AdaFace | 0.9980 | 0.9917 | 0.9463 | 0.9605 | 0.9790 | 0.9751 |
R100 | WebFace12M | AdaFace | 0.9982 | 0.9926 | 0.9457 | 0.9612 | 0.9800 | 0.9755 |
与其他方法的比较
架构 | 数据集 | 方法 | LFW | CFPFP | CPLFW | CALFW | AGEDB | 平均 |
---|---|---|---|---|---|---|---|---|
R50 | CASIA-WebFace | AdaFace | 0.9942 | 0.9641 | 0.8997 | 0.9323 | 0.9438 | 0.9468 |
R50 | CASIA-WebFace | (ArcFace) | 0.9945 | 0.9521 | NA | NA | 0.9490 | NA |
R100 | MS1MV2 | AdaFace | 0.9982 | 0.9849 | 0.9353 | 0.9608 | 0.9805 | 0.9719 |
R100 | MS1MV2 | (ArcFace) | 0.9982 | NA | 0.9208 | 0.9545 | NA | NA |
混合质量场景(IJBB、IJBC数据集)
对于IJBB和IJBC验证,请参考:
cd validation_mixed
bash eval_ijb.sh
架构 | 数据集 | 方法 | IJBB TAR@FAR=0.01% | IJBC TAR@FAR=0.01% |
---|---|---|---|---|
R18 | VGG2 | AdaFace | 90.67 | 92.95 |
R18 | WebFace4M | AdaFace | 93.03 | 94.99 |
R50 | WebFace4M | AdaFace | 95.44 | 96.98 |
R50 | MS1MV2 | AdaFace | 94.82 | 96.27 |
R100 | MS1MV2 | AdaFace | 95.67 | 96.89 |
R100 | MS1MV3 | AdaFace | 95.84 | 97.09 |
R100 | WebFace4M | AdaFace | 96.03 | 97.39 |
R100 | WebFace12M | AdaFace | 96.41 | 97.66 |
与其他方法的比较
- 其他方法的数据来自各自的论文。
架构 | 数据集 | 方法 | 会议 | IJBB TAR@FAR=0.01% | IJBC TAR@FAR=0.01% |
---|---|---|---|---|---|
R100 | MS1MV2 | AdaFace | CVPR22 | 95.67 | 96.89 |
R100 | MS1MV2 | (MagFace) | CVPR21 | 94.51 | 95.97 |
R100 | MS1MV2 | (SCF-ArcFace) | CVPR21 | 94.74 | 96.09 |
R100 | MS1MV2 | (BroadFace) | ECCV20 | 94.97 | 96.38 |
R100 | MS1MV2 | (CurricularFace) | CVPR20 | 94.80 | 96.10 |
R100 | MS1MV2 | (MV-Softmax) | AAAI20 | 93.60 | 95.20 |
R100 | MS1MV2 | (AFRN) | ICCV19 | 88.50 | 93.00 |
R100 | MS1MV2 | (ArcFace) | CVPR19 | 94.25 | 96.03 |
R100 | MS1MV2 | (CosFace) | CVPR18 | 94.80 | 96.37 |
架构 | 数据集 | 方法 | IJBC TAR@FAR=0.01% |
---|---|---|---|
R100 | WebFace4M | AdaFace | 97.39 |
R100 | WebFace4M | (CosFace) | 96.86 |
R100 | WebFace4M | (ArcFace) | 96.77 |
R100 | WebFace4M | (CurricularFace) | 97.02 |
架构 | 数据集 | 方法 | IJBC TAR@FAR=0.01% |
---|---|---|---|
R100 | WebFace12M | AdaFace | 97.66 |
R100 | WebFace12M | (CosFace) | 97.41 |
R100 | WebFace12M | (ArcFace) | 97.47 |
R100 | WebFace12M | (CurricularFace) | 97.51 |
低质量场景 (IJBS)
对于IJBB、IJBC验证,请参考
cd validation_lq
python validate_IJB_S.py
与其他方法的比较
监控到单张 | 监控到簿册 | 监控到监控 | 小脸 | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
架构 | 方法 | 数据集 | 排名1 | 排名5 | 1% | 排名1 | 排名5 | 1% | 排名1 | 排名5 | 1% | 排名1 | 排名5 |
R100 | AdaFace | WebFace12M | 71.35 | 76.24 | 59.39 | 71.93 | 76.56 | 59.37 | 36.71 | 50.03 | 4.62 | 72.29 | 74.97 |
R100 | AdaFace | WebFace4M | 70.42 | 75.29 | 58.27 | 70.93 | 76.11 | 58.02 | 35.05 | 48.22 | 4.96 | 72.02 | 74.52 |
R100 | AdaFace | MS1MV2 | 65.26 | 70.53 | 51.66 | 66.27 | 71.61 | 50.87 | 23.74 | 37.47 | 2.50 | 68.21 | 71.54 |
R100 | (CurricularFace) | MS1MV2 | 62.43 | 68.68 | 47.68 | 63.81 | 69.74 | 47.57 | 19.54 | 32.80 | 2.53 | 63.68 | 67.65 |
R100 | (URL) | MS1MV2 | 58.94 | 65.48 | 37.57 | 61.98 | 67.12 | 42.73 | NA | NA | NA | 63.89 | 68.67 |
R100 | (ArcFace) | MS1MV2 | 57.35 | 64.42 | 41.85 | 57.36 | 64.95 | 41.23 | NA | NA | NA | NA | NA |
R100 | (PFE) | MS1MV2 | 50.16 | 58.33 | 31.88 | 53.60 | 61.75 | 35.99 | 9.20 | 20.82 | 0.84 | NA | NA |
- 监控到单张:比较监控视频(探针)与单张注册图像(画廊)的协议
- 监控到簿册:比较监控视频(探针)与所有注册图像(画廊)的协议
- 监控到监控:比较监控视频(探针)与监控视频(画廊)的协议