e5-small-unsupervised 项目介绍
项目概述
e5-small-unsupervised 是一个基于弱监督对比学习预训练的文本嵌入模型。这个模型与其前身 e5-small 很相似,但没有经过监督式微调。e5-small-unsupervised 模型包含了12个层,其嵌入尺寸为384。模型的设计旨在提高文本相似度的计算效率,适用于多个自然语言处理任务。
使用方法
使用 e5-small-unsupervised 模型时,每个输入文本需要以"query: "或"passage: "开头。以下是使用 MS-MARCO passage ranking 数据集对查询和段落进行编码的示例代码:
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
input_texts = [
'query: how much protein should a female eat',
'query: summit define',
"passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day...",
"passage: Definition of summit for English Language Learners. : 1 the highest point of a mountain : ..."
]
tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-small-unsupervised')
model = AutoModel.from_pretrained('intfloat/e5-small-unsupervised')
batch_dict = tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt')
outputs = model(**batch_dict)
embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
embeddings = F.normalize(embeddings, p=2, dim=1)
scores = (embeddings[:2] @ embeddings[2:].T) * 100
print(scores.tolist())
训练细节
关于模型训练的更多细节,请参考以下论文:Text Embeddings by Weakly-Supervised Contrastive Pre-training。
基准测试评估
用户可以访问 unilm/e5 来获取在 BEIR 和 MTEB 基准测试上的评估结果。
支持 Sentence Transformers
e5-small-unsupervised 模型兼容于 Sentence Transformers 框架,以下示例展示了如何在此环境中使用该模型:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('intfloat/e5-small-unsupervised')
input_texts = [
'query: how much protein should a female eat',
'query: summit define',
"passage: As a general guideline, the CDC's average requirement of protein...",
"passage: Definition of summit for English Language Learners. : 1 the highest point..."
]
embeddings = model.encode(input_texts, normalize_embeddings=True)
常见问题
1. 输入文本中需要添加“query:”和“passage:”前缀吗?
是的,模型是在此基础上进行训练的,若不加这些前缀,可能会导致性能下降。
- 在不对称任务(如开放性问答中的段落检索)中,使用"query: "和"passage: "。
- 对于对称任务(如语义相似度、复述检索),使用"query: "前缀。
- 如果要把嵌入当作特征使用(如线性分类、聚类),使用"query: "前缀。
2. 为什么我复现的结果与模型卡中报告的结果稍有不同?
不同版本的 transformers 和 pytorch 可能导致细微但非零的性能差异。
限制
该模型仅对英文文本有效。长文本将被截断至最多512个标记。