Introduction to multi-qa-mpnet-base-dot-v1
Background
multi-qa-mpnet-base-dot-v1 is a sentence-embedding model that maps sentences and paragraphs to a 768-dimensional dense vector space and is designed specifically for semantic search. Its training data comprises 215 million question-answer pairs drawn from a variety of sources. The model was developed during Hugging Face's "Community week using JAX/Flax for NLP & CV" event, with the goal of building a high-quality sentence-embedding model.
Usage
Using Sentence-Transformers
Install the required library:
pip install -U sentence-transformers
Example code for using the model:
from sentence_transformers import SentenceTransformer, util

query = "How many people live in London?"
docs = ["Around 9 Million people live in London", "London is known for its financial district"]

# Load the model
model = SentenceTransformer('sentence-transformers/multi-qa-mpnet-base-dot-v1')

# Encode the query and the documents
query_emb = model.encode(query)
doc_emb = model.encode(docs)

# Compute dot-product scores between the query and each document
scores = util.dot_score(query_emb, doc_emb)[0].cpu().tolist()

# Pair each document with its score and sort by decreasing score
doc_score_pairs = list(zip(docs, scores))
doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)

for doc, score in doc_score_pairs:
    print(score, doc)
Using Hugging Face Transformers
Using the model with Hugging Face Transformers:
from transformers import AutoTokenizer, AutoModel
import torch

# CLS pooling: take the embedding of the first ([CLS]) token
def cls_pooling(model_output):
    return model_output.last_hidden_state[:, 0]

# Encode text(s) into embeddings
def encode(texts):
    # Tokenize, padding and truncating to the model's maximum input length
    encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        model_output = model(**encoded_input, return_dict=True)
    embeddings = cls_pooling(model_output)
    return embeddings

query = "How many people live in London?"
docs = ["Around 9 Million people live in London", "London is known for its financial district"]

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/multi-qa-mpnet-base-dot-v1")
model = AutoModel.from_pretrained("sentence-transformers/multi-qa-mpnet-base-dot-v1")

query_emb = encode(query)
doc_emb = encode(docs)

# Dot-product scores between the query and each document
scores = torch.mm(query_emb, doc_emb.transpose(0, 1))[0].cpu().tolist()

doc_score_pairs = list(zip(docs, scores))
doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)

for doc, score in doc_score_pairs:
    print(score, doc)
Technical Details
- Embedding dimension: 768
- Embeddings normalized: no
- Pooling method: CLS pooling
- Suitable score function: dot product (e.g. util.dot_score)
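Because the embeddings are not normalized, dot-product scores are sensitive to vector length and are not interchangeable with cosine similarity. A minimal pure-Python sketch illustrating the difference (the vectors are made-up examples, not real model outputs):

```python
import math

def dot(u, v):
    # Dot product of two vectors
    return sum(x * y for x, y in zip(u, v))

def cosine(u, v):
    # Cosine similarity: dot product of the length-normalized vectors
    return dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)))

a = [3.0, 4.0]   # norm 5
b = [6.0, 8.0]   # norm 10, same direction as a
print(dot(a, b))     # 50.0 -- grows with vector length
print(cosine(a, b))  # 1.0  -- depends only on direction
```

Since this model was trained with a dot-product objective, dot product is the score function to use with its embeddings, as the examples above do.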
Intended Use
The multi-qa-mpnet-base-dot-v1 model is intended for semantic search: it encodes queries/questions and text passages in a dense vector space in order to find documents relevant to a given passage. Note that input text is limited to 512 word pieces; anything beyond that length is truncated, so the model is not suitable for very long texts.
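A common workaround for the 512-word-piece limit is to split long documents into shorter, possibly overlapping windows and embed each window separately. A minimal sketch, assuming you already have token ids from the tokenizer (the function name and stride value are illustrative, not part of the model's API):

```python
def chunk_token_ids(ids, max_len=512, stride=256):
    # Split a long token-id sequence into overlapping windows of at most
    # max_len tokens, so that no single window exceeds the model's limit.
    # Consecutive windows start `stride` tokens apart, giving an overlap
    # of max_len - stride tokens so content at chunk boundaries is kept.
    chunks = []
    start = 0
    while start < len(ids):
        chunks.append(ids[start:start + max_len])
        if start + max_len >= len(ids):
            break
        start += stride
    return chunks
```

Each window can then be encoded on its own, and a document's score taken as, for example, the maximum score over its windows.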
Training
The model was trained with a contrastive learning objective on large sentence-level datasets. Training tutorials and scripts are available in the project repository. The model was initialized from the pretrained mpnet-base model and then fine-tuned on multiple datasets to make it suitable for a broad range of applications. The training data includes WikiAnswers, PAQ, Stack Exchange, MS MARCO, and other sources, totaling more than 215 million training pairs.
Through these efforts, multi-qa-mpnet-base-dot-v1 achieves strong performance on semantic search tasks.
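To make the contrastive objective concrete: with in-batch negatives, each (question, answer) pair in a batch treats the other answers in the same batch as negatives, and the loss is the cross-entropy of the dot-product scores against the paired answer. A simplified pure-Python sketch of this idea (not the actual training code, which lives in the project repository):

```python
import math

def in_batch_contrastive_loss(q_embs, d_embs):
    # For each query i, its paired document i is the positive and all other
    # documents in the batch serve as negatives. The loss is the negative
    # log-softmax probability of the positive under dot-product scores,
    # averaged over the batch.
    losses = []
    for i, q in enumerate(q_embs):
        scores = [sum(a * b for a, b in zip(q, d)) for d in d_embs]
        log_z = math.log(sum(math.exp(s) for s in scores))
        losses.append(log_z - scores[i])  # -log softmax(scores)[i]
    return sum(losses) / len(losses)
```

Minimizing this loss pushes each question embedding toward its paired answer (high dot product) and away from the other answers in the batch.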