项目介绍:all_datasets_v3_mpnet-base
项目背景
all_datasets_v3_mpnet-base是一个利用句子嵌入技术将句子和段落映射到768维密集向量空间的模型。这种技术可以用于句子聚类或语义搜索等任务。该项目旨在通过自监督对比学习目标,在非常大规模的句子级数据集上训练句子嵌入模型。我们使用了预训练的microsoft/mpnet-base
模型,并在包含10亿对句子的数据库上进行了微调。
在Hugging Face组织的使用JAX/Flax进行NLP和计算机视觉的社区周中,我们开发了这一模型。项目得到了谷歌Flax、JAX和云团队成员关于高效深度学习框架的支持,以及七个TPU v3-8的硬件设施。
模型用途
该模型设计为句子与短段落编码器,输入文本将转化为包含语义信息的向量,这种句子向量可以用于信息检索、聚类或句子相似性任务。对于长度超过128个词块的文本,默认会进行截断处理。
训练过程
预训练
我们使用了microsoft/mpnet-base
的预训练模型,详细的预训练过程可以参考对应模型文档。
微调
微调过程中,我们使用了对比学习目标。具体而言,我们计算批次中每个可能的句子对之间的余弦相似度,然后将其与真实对比进行交叉熵损失处理。
超参数
模型是在TPU v3-8上训练的,总共进行了92万步,批次大小为512(每个TPU核心64)。我们使用500步的学习率热身,序列长度限制为128个token,并采用AdamW优化器,学习率为2e-5。完整的训练脚本可以在当前仓库中的train_script.py
中找到。
训练数据
我们通过多数据集的结合来微调模型,句子对总数量超过10亿。每个数据集的抽样采用加权概率进行,配置详细信息见data_config.json
文件。目前使用的一些主要数据集及其对应论文和训练样本数如下:
数据集 | 论文 | 训练样本数 |
---|---|---|
Reddit评论(2015-2018) | 论文 | 726,484,430 |
S2ORC 摘要引用对 | 论文 | 116,288,806 |
WikiAnswers 重复问题对 | 论文 | 77,427,422 |
PAQ 问答对 | 论文 | 64,371,441 |
... | ... | ... |
总计 | - | 1,124,818,467 |
使用方法
使用 Sentence-Transformers
安装sentence-transformers
库之后,通过以下方式加载并使用模型:
pip install -U sentence-transformers
from sentence_transformers import SentenceTransformer
sentences = ["这是一个例句", "每个句子被转换"]
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v1')
embeddings = model.encode(sentences)
print(embeddings)
使用 HuggingFace Transformers
无需sentence-transformers
库,也可以使用该模型。首先,将输入通过transformer模型,然后在上下文化词嵌入上应用合适的池化操作:
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
# 平均池化 - 考虑注意力掩码正确求均值
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0] # model_output的第一个元素包含所有token嵌入
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
# 我们想要句子嵌入的句子
sentences = ['这是一个例句', '每个句子被转换']
# 从HuggingFace Hub加载模型
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v1')
model = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v1')
# 对句子进行分词
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
# 计算token嵌入
with torch.no_grad():
model_output = model(**encoded_input)
# 执行池化
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
# 归一化嵌入
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
print("句子嵌入:")
print(sentence_embeddings)
评估结果
有关该模型的自动化评估,请参见句子嵌入基准。