OLM RoBERTa/BERT 2022年12月版项目介绍
OLM RoBERTa/BERT 2022年12月版是原始BERT和RoBERTa的更新版本。在保留原模型优势的同时,其在标准基准测试中表现得更为出色。尽管与原始RoBERTa相比也有一些不同的表现,但在许多标准基准测试中,性能差距并不大。该模型的训练数据来源于清理后的2022年12月Common Crawl和Wikipedia数据集。
OLM项目的目标是不断训练和发布与最新动态保持一致并且在标准语言模型性能上可以媲美静态模型的模型。这对于模型能够即时获取COVID或总统大选等事件的最新信息十分重要。
预期用途
这个模型可以用于掩码语言建模,不过主要用于微调并应用在下游任务,比如序列分类、标注分类或问答系统。
如何使用
用户可以通过以下代码实现掩码语言建模的功能:
from transformers import pipeline
unmasker = pipeline('fill-mask', model='olm/olm-roberta-base-dec-2022')
unmasker("Hello I'm a <mask> model.")
此外,还可以通过以下代码获取给定文本的特征:
from transformers import AutoTokenizer, RobertaModel
tokenizer = AutoTokenizer.from_pretrained('olm/olm-roberta-base-dec-2022')
model = RobertaModel.from_pretrained("olm/olm-roberta-base-dec-2022")
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt')
output = model(**encoded_input)
数据集
该模型和分词模型的训练采用的是2022年12月清理后的Common Crawl和Wikipedia数据集。相关的数据集可以在此处获得。这些数据集是通过Hugging Face的相关仓库创建的。
训练
该模型的训练流程遵循OLM BERT/RoBERTa的指导方针,具体可查看这个仓库。
评估结果
该模型在调优GLUE任务后取得如下结果:
任务 | 指标 | 原始BERT | OLM RoBERTa Dec 2022(我们的模型) |
---|---|---|---|
cola | mcc | 0.5889 | 0.28067 |
sst2 | acc | 0.9181 | 0.9275 |
mrpc | acc/f1 | 0.9182/0.8923 | 0.8662/0.9033 |
stsb | pear/spear | 0.8822/0.8794 | 0.8870/0.8857 |
qqp | acc/f1 | 0.9071/0.8748 | 0.9097/0.8791 |
mnli | acc/acc_mm | 0.8400/0.8410 | 0.8576/0.8621 |
qnli | acc | 0.9075 | 0.9192 |
rte | acc | 0.6296 | 0.6390 |
wnli | acc | 0.4000 | 0.4648 |
以上结果是使用Hugging Face的run_glue.py
脚本运行得到的,训练和评估采用了默认的微调超参数,并且结果取自五个训练种子的平均值。这些结果来自GLUE开发集,可能与测试集结果稍有不同。