SimCSE: 句子嵌入的简单对比学习
本仓库包含我们的论文SimCSE: 句子嵌入的简单对比学习的代码和预训练模型。
**************************** 更新 ****************************
- 8/31: 我们的论文已被EMNLP接受!请查看我们的更新版论文(包括更新的数据和基线)。
- 5/12: 我们用新的超参数和更好的性能更新了我们的无监督模型。
- 5/10: 我们发布了我们的句子嵌入工具和演示代码。
- 4/23: 我们发布了我们的训练代码。
- 4/20: 我们发布了我们的模型检查点和评估代码。
- 4/18: 我们发布了我们的论文。查看一下吧!
快速链接
概述
我们提出了一个简单的对比学习框架,可以处理未标记和标记数据。无监督的SimCSE仅在对比学习框架中输入一个句子并预测自身,只有标准的dropout用作噪声。我们的监督SimCSE通过使用NLI数据集中的标注对,将“包含”对作为正样本,将“矛盾”对作为困难负样本,融合到对比学习中。下图展示了我们的模型。
开始使用
我们基于SimCSE模型提供了一个易于使用的句子嵌入工具(详细用法见我们的Wiki)。要使用该工具,首先从PyPI安装simcse
包
pip install simcse
或者直接从我们代码中安装
python setup.py install
请注意,如果您想启用GPU编码,您应该安装支持CUDA的正确版本的PyTorch。安装说明见PyTorch官方网站。
安装包后,只需两行代码即可加载我们的模型
from simcse import SimCSE
model = SimCSE("princeton-nlp/sup-simcse-bert-base-uncased")
有关可用模型的完整列表,请参见模型列表。
然后您可以使用我们的模型将句子编码为嵌入
embeddings = model.encode("A woman is reading.")
计算两组句子之间的余弦相似度
sentences_a = ['A woman is reading.', 'A man is playing a guitar.']
sentences_b = ['He plays guitar.', 'A woman is making a photo.']
similarities = model.similarity(sentences_a, sentences_b)
或者为一组句子构建索引并在其中搜索
sentences = ['A woman is reading.', 'A man is playing a guitar.']
model.build_index(sentences)
results = model.search("He plays guitar.")
我们还支持faiss,一个高效的相似度搜索库。只需按照此处的说明安装包,simcse
将自动使用faiss
进行高效搜索。
警告: 我们发现faiss
不太支持英伟达AMPERE GPU(3090和A100)。在这种情况下,您应更换其他GPU或安装faiss
包的CPU版本。
我们还提供了一个易于构建的演示网站,展示如何在句子检索中使用SimCSE。代码基于DensePhrases的仓库和演示(非常感谢DensePhrases的作者们)。
模型列表
我们发布的模型如下。您可以使用simcse
包或使用HuggingFace's Transformers导入这些模型。
请注意,采用新的超参数集后,结果略好于我们在当前版本的论文中报告的结果(超参数见训练部分)。
命名规则: unsup
和sup
分别表示“无监督”(在维基百科语料库上训练)和“有监督”(在NLI数据集上训练)。
与Huggingface一起使用SimCSE
除了使用我们提供的句子嵌入工具,您还可以通过HuggingFace的transformers
轻松导入我们的模型:
import torch
from scipy.spatial.distance import cosine
from transformers import AutoModel, AutoTokenizer
# 导入我们的模型。该包将自动下载模型
tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/sup-simcse-bert-base-uncased")
model = AutoModel.from_pretrained("princeton-nlp/sup-simcse-bert-base-uncased")
# 对输入文本进行分词
texts = [
"There's a kid on a skateboard.",
"A kid is skateboarding.",
"A kid is inside the house."
]
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
# 获取嵌入
with torch.no_grad():
embeddings = model(**inputs, output_hidden_states=True, return_dict=True).pooler_output
# 计算余弦相似度
# 余弦相似度在[-1, 1]之间。越高表示越相似
cosine_sim_0_1 = 1 - cosine(embeddings[0], embeddings[1])
cosine_sim_0_2 = 1 - cosine(embeddings[0], embeddings[2])
print("句子\"%s\"和\"%s\"的余弦相似度为: %.3f" % (texts[0], texts[1], cosine_sim_0_1))
print("句子\"%s\"和\"%s\"的余弦相似度为: %.3f" % (texts[0], texts[2], cosine_sim_0_2))
如果直接使用HuggingFace的API加载模型时遇到任何问题,也可以从上表手动下载模型并使用model = AutoModel.from_pretrained({PATH TO THE DOWNLOAD MODEL})
。
训练SimCSE
在以下部分中,我们描述了如何使用我们的代码训练SimCSE模型。
要求
首先,请按照官方指南的说明安装PyTorch。为了忠实地再现我们的结果,请使用与您的平台/CUDA版本相对应的1.7.1
版本。更高版本的PyTorch也应能正常工作。例如,如果您使用Linux和CUDA11(如何检查CUDA版本),请通过以下命令安装PyTorch,
pip install torch==1.7.1+cu110 -f https://download.pytorch.org/whl/torch_stable.html
如果您使用CUDA <11
或 CPU,请通过以下命令安装PyTorch,
pip install torch==1.7.1
然后运行以下脚本来安装其余的依赖项,
pip install -r requirements.txt
评估
我们的句子嵌入评估代码基于修改后的SentEval。它在语义文本相似度(STS)任务和下游迁移任务上评估句子嵌入。对于STS任务,我们的评估采用“全”设置,并报告Spearman相关性。评估详情见我们的论文(附录B)。
评估之前,请运行以下命令下载评估数据集
cd SentEval/data/downstream/
bash download_dataset.sh
然后回到根目录,您可以使用我们的评估代码评估任何基于transformers
的预训练模型。例如,
python evaluation.py \
--model_name_or_path princeton-nlp/sup-simcse-bert-base-uncased \
--pooler cls \
--task_set sts \
--mode test
代码会以表格形式输出结果:
------ test ------
+-------+-------+-------+-------+-------+--------------+-----------------+-------+
| STS12 | STS13 | STS14 | STS15 | STS16 | STSBenchmark | SICKRelatedness | Avg. |
+-------+-------+-------+-------+-------+--------------+-----------------+-------+
| 75.30 | 84.67 | 80.19 | 85.40 | 80.82 | 84.26 | 80.39 | 81.58 |
+-------+-------+-------+-------+-------+--------------+-----------------+-------+
评估脚本的参数如下:
--model_name_or_path
:基于transformers
的预训练检查点的名称或路径。你可以直接使用上表中的模型,例如princeton-nlp/sup-simcse-bert-base-uncased
。--pooler
:池化方法。我们目前支持cls
(默认):使用[CLS]
令牌的表示。在表示之后应用一个线性+激活层(这是标准的BERT实现)。如果你使用监督SimCSE,就应该选择这个选项。cls_before_pooler
:使用没有额外线性+激活的[CLS]
令牌的表示。如果你使用无监督SimCSE,就应该选择这个选项。avg
:最后一层的平均嵌入。如果你使用SBERT/SRoBERTa的检查点(论文),应该选择这个选项。avg_top2
:最后两层的平均嵌入。avg_first_last
:第一层和最后一层的平均嵌入。如果你使用原生BERT或RoBERTa,这是最好的选择。
--mode
:评估模式test
(默认):默认测试模式。为了忠实地复现我们的结果,应该选择这个选项。dev
:报告开发集结果。注意在STS任务中,只有STS-B
和SICK-R
有开发集,所以我们只报告这些数据的结果。它还使用传输任务的快速模式,因此运行时间比test
模式短得多(尽管数值稍低)。fasttest
:与test
相同,但采用快速模式,因此运行时间短得多,但报告的数值可能较低(仅适用于传输任务)。
--task_set
:评估时使用的任务集(如果设置,将覆盖--tasks
)sts
(默认):在STS任务上进行评估,包括STS 12~16
,STS-B
和SICK-R
。这是评估句子嵌入质量最常用的任务集。transfer
:评估传输任务。full
:在STS和传输任务上都进行评估。na
:手动通过--tasks
设置任务。
--tasks
:指定要评估的特定数据集。如果--task_set
不为na
,此值将被覆盖。有关任务的完整列表,请参见代码。
训练
数据
对于无监督SimCSE,我们从英文维基百科中抽取了100万句子;对于监督SimCSE,我们使用了SNLI和MNLI数据集。你可以运行data/download_wiki.sh
和data/download_nli.sh
来下载这两个数据集。
训练脚本
我们为无监督和监督SimCSE提供了示例训练脚本。在run_unsup_example.sh
中,我们提供了一个单GPU(或CPU)的无监督版本示例,而在run_sup_example.sh
中,我们给出了一个多GPU的监督版本示例。两个脚本都调用train.py
进行训练。我们在下面解释这些参数:
--train_file
:训练文件路径。我们支持"txt"文件(每行一个句子)和"csv"文件(2列:无硬负样本的配对数据;3列:每对数据有一个对应的硬负样本)。你可以使用我们提供的维基百科或NLI数据,也可以使用你自己格式相同的数据。--model_name_or_path
:用于开始训练的预训练检查点。目前我们支持基于BERT的模型(例如bert-base-uncased
,bert-large-uncased
等)和基于RoBERTa的模型(例如RoBERTa-base
,RoBERTa-large
等)。--temp
:对比损失的温度参数。--pooler_type
:池化方法。这与评估部分中的--pooler_type
相同。--mlp_only_train
:我们发现,对于无监督SimCSE,训练时使用MLP层但测试时不使用效果更好。训练无监督SimCSE模型时应使用这个参数。--hard_negative_weight
:如果使用硬负样本(即训练文件有3列),这是权重的对数。例如,如果权重为1,则此参数应设置为0(默认值)。--do_mlm
:是否使用MLM辅助目标。如果为True:--mlm_weight
:MLM目标的权重。--mlm_probability
:掩码的概率。
所有其他参数都是Huggingface的transformers
训练参数。常用的一些参数有:--output_dir
,--learning_rate
,--per_device_train_batch_size
。在我们的示例脚本中,我们还设置了在STS-B开发集上评估模型(需要按照评估部分下载数据集)并保存最佳检查点。
在论文中的结果,我们使用了Nvidia 3090 GPU和CUDA 11。使用不同类型的设备或不同版本的CUDA/其他软件可能导致性能略有不同。
超参数
我们使用以下超参数来训练SimCSE:
无监督 BERT | 无监督 RoBERTa | 监督 | |
---|---|---|---|
批大小 | 64 | 512 | 512 |
学习率(base) | 3e-5 | 1e-5 | 5e-5 |
学习率(large) | 1e-5 | 3e-5 | 1e-5 |
转换模型
我们保存的检查点与Huggingface的预训练检查点稍有不同。运行python simcse_to_huggingface.py --path {PATH_TO_CHECKPOINT_FOLDER}
进行转换。之后,你可以通过我们的评估代码进行评估,或者直接使用现成的模型。
有问必答
如果你对代码或论文有任何问题,可以随时给Tianyu(tianyug@cs.princeton.edu
)和Xingcheng(yxc18@mails.tsinghua.edu.cn
)发邮件。如果在使用代码时遇到任何问题或想报告一个错误,可以发一个issue。请尽量详细描述问题,以便我们更好更快地帮助你!
引用
如果你在工作中使用SimCSE,请引用我们的论文:
@inproceedings{gao2021simcse,
title={{SimCSE}: Simple Contrastive Learning of Sentence Embeddings},
author={Gao, Tianyu and Yao, Xingcheng and Chen, Danqi},
booktitle={Empirical Methods in Natural Language Processing (EMNLP)},
year={2021}
}
SimCSE的扩展
我们感谢社区对扩展SimCSE所作的努力!
- 苏剑林提供了中文版本的SimCSE。
- AK391在Huggingface Spaces上与Gradio集成。参见演示:
- Nils Reimers实现了基于
sentence-transformers
的SimCSE训练代码。