Project Icon

self-rag

通过自反学习使语言模型实现按需检索、生成和评估的框架

Self-RAG是一种创新框架,通过自反学习使语言模型实现按需检索、生成和评估。该方法预测反思标记,支持多次检索或跳过检索,并从多角度评估生成内容。这不仅提高了模型输出的事实性和质量,还保持了语言模型的通用性能。

SELF-RAG: 通过自我反思学习检索、生成和评判

这包括原始实现的SELF-RAG: 通过自我反思学习检索、生成和评判(ICLR 2024,口头报告前1%),作者为Akari Asai、Zeqiu Wu、Yizhong Wang、Avirup Sil和Hannaneh Hajishirzi。

网站 | 7B模型 | 13B模型 | 论文 | 训练数据 | Twitter摘要 | 更新

Self-RAG(右图)是一个新的框架,用于训练任意语言模型学习检索、生成和评判,以提高生成内容的事实性和质量,同时不影响大型语言模型的多功能性。

与广泛采用的检索增强生成(RAG;左图)方法不同,Self-RAG根据需求进行检索(例如,可以多次检索或完全跳过检索),针对不同的查询,并通过预测反思标记作为生成的组成部分,从多个细粒度方面对自身生成进行评判。我们进行分段束搜索,以选择能最大化多样化偏好效用的输出。

如果您发现我们的代码、数据、模型或论文有用,请引用以下论文:

@inproceedings{
asai2024selfrag,
author={Asai, Akari and Wu, Zeqiu and Wang, Yizhong and Sil, Avirup and Hajishirzi, Hannaneh},
title={Self-{RAG}: Learning to Retrieve, Generate, and Critique through Self-Reflection},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=hSyW5go0v8}
}

更新

  • 2023年10月:首次发布代码、模型和论文。

内容

  1. 安装
  2. 快速开始
  3. 检索器设置
  4. 训练
  5. 推理
  6. 基线
  7. 常见问题
  8. 联系方式

安装

通过运行以下命令安装依赖的Python库。

pip install -r requirements.txt

请使用最新版本的vllm,因为旧版本可能无法通过SamplingParam设置skip_special_tokens,这是由(这个PR)添加的。

您也可以通过运行以下命令创建conda环境。

conda env create -f environment.yml

快速开始

您可以从HuggingFace Hub下载Self-RAG。对于推理,我们建议使用vllm,因为它可以显著加快推理速度。

from vllm import LLM, SamplingParams

model = LLM("selfrag/selfrag_llama2_7b", download_dir="/gscratch/h2lab/akari/model_cache", dtype="half")
sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)

def format_prompt(input, paragraph=None):
  prompt = "### 指令:\n{0}\n\n### 回复:\n".format(input)
  if paragraph is not None:
    prompt += "[检索]<段落>{0}</段落>".format(paragraph)
  return prompt

query_1 = "找出不同项:twitter、instagram、whatsapp。"
query_2 = "你能告诉我美洲驼和羊驼的区别吗?"
queries = [query_1, query_2]

# 对于不需要检索的查询
preds = model.generate([format_prompt(query) for query in queries], sampling_params)
for pred in preds:
  print("模型预测: {0}".format(pred.outputs[0].text))

输出:

模型预测: Twitter、Instagram和WhatsApp都是社交媒体平台。[无检索]WhatsApp是不同项,因为它是一个消息应用,而Twitter和Instagram主要用于分享照片和视频。[效用:5]</s>
模型预测: 好的![检索]<段落><段落>

如您所见,在第一个查询中,当不需要检索时,Self-RAG开始生成回答而不进行检索。另一方面,对于第二个查询,Self-RAG输出了[检索]标记,因为这个问题需要更细粒度的事实依据。

对于需要事实依据的查询,您可以插入一个段落。Self-RAG可以在生成过程中随时检索和插入段落,只要它们被上下文标记特殊标记<段落></段落>包围,就能识别它们。

# 对于需要事实依据的查询
prompt = format_prompt("你能告诉我美洲驼和羊驼的区别吗?", "羊驼(Lama pacos)是南美洲骆驼科哺乳动物的一种。它与美洲驼相似,常常被混淆。羊驼比美洲驼小得多,与美洲驼不同的是,它们不是被培育为工作动物,而是专门为了它们的纤维而被培育。")
preds = model.generate([prompt], sampling_params)
print([pred.outputs[0].text for pred in preds])
# ['[相关]羊驼比美洲驼小得多,与美洲驼不同的是,它们不是被培育为工作动物,而是专门为了它们的纤维而被培育。[完全支持][效用:5]</s>']

Self-RAG找到相关的插入文档,并生成完全由证据支持的答案。

使用在线检索模型进行评估

您也可以按需进行检索并与Self-RAG一起使用。由于在完整的英文维基百科上运行检索需要大量RAM和多个GPU,我们为演示目的创建了一个只包含维基百科文章介绍段落的子集。

首先,请下载语料库和嵌入(共9GB)。

git clone git@github.com:AkariAsai/self-rag.git
cd retrieval_lm
bash download_demo_corpus.sh

如果脚本不起作用,您可以从Google DriveHF数据集下载数据。 然后,您可以在retrieval_lm下运行脚本。我们在1个RTX 6000(24GB)和100G RAM上测试了该脚本(但应该可以在更小的RAM上运行)。

from passage_retrieval import Retriever
retriever = Retriever({})
retriever.setup_retriever_demo("facebook/contriever-msmarco", "enwiki_2020_intro_only/enwiki_2020_dec_intro_only.jsonl", "enwiki_2020_intro_only/enwiki_dec_2020_contriever_intro/*",  n_docs=5, save_or_load_index=False)
retrieved_documents = retriever.search_document_demo(query_3, 5)
prompts = [format_prompt(query_3, doc["title"] +"\n"+ doc["text"]) for doc in retrieved_documents]
preds = model.generate(prompts, sampling_params)
top_doc = retriever.search_document_demo(query_3, 1)[0]
print("参考: {0}\n模型预测: {1}".format(top_doc["title"] + "\n" + top_doc["text"], preds[0].outputs[0].text))

输出:

参考: 过拟合
  在统计学中,过拟合是"产生一个与特定数据集过于紧密或完全对应的分析,因此可能无法可靠地适应额外数据或预测未来观察结果"。过拟合模型是一个包含比数据可以证明更多参数的统计模型。过拟合的本质是无意中将一些残差变异(即噪声)提取出来,就好像这种变异代表了潜在的模型结构。欠拟合发生在统计模型无法充分捕捉数据的潜在结构时。欠拟合模型是一个模型,其中一些在正确指定的模型中会出现的参数或项缺失
模型预测: [相关]过拟合发生在模型相对于其训练数据量而言具有太多参数时,导致它过度记忆训练数据,并在新的、未见过的数据上表现不佳。[完全支持][效用:5]</s>

检索系统正确检索了必要的文档并生成了完全有根据的输出。

请注意,此演示使用较小的语料库和具有完整推理算法的Self-RAG。对于完整评估,您需要设置检索器或下载我们的检索结果。请按照推理中的说明进行操作。

检索器设置

默认情况下,我们使用Contriever作为我们的检索组件。

下载数据

下载DPR中使用的预处理段落数据。

cd retrieval_lm
wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz

然后,下载生成的段落。我们使用Contriever-MSMARCO

wget https://dl.fbaipublicfiles.com/contriever/embeddings/contriever-msmarco/wikipedia_embeddings.tar

运行检索器

您可以通过运行以下命令来执行段落检索。

cd retrieval_lm
python passage_retrieval.py \
    --model_name_or_path facebook/contriever-msmarco --passages psgs_w100.tsv \
    --passages_embeddings "wikipedia_embeddings/*" \
    --data 你的输入文件  \
    --output_dir 你的输出文件 \
    --n_docs 20

你的输入文件应该是jsonjsonl格式。每个实例必须包含questioninstruction,这将在检索过程中用作查询。

为你自己的数据生成嵌入

你可以通过运行以下命令为你自己的数据生成嵌入。(该脚本改编自Contriever仓库。)请注意,为大规模语料库(>1000万文档)生成嵌入可能需要一些时间,我们建议在多个GPU上运行。

cd retrieval_lm
for i in {0..3}; do
  export CUDA_VISIBLE_DEVICES=${i}
  python generate_passage_embeddings.py  --model_name_or_path facebook/contriever-msmarco \
  --output_dir 你的输出目录 \
  --passages 你的段落数据 --shard_id ${i}  --num_shards 4 > ./log/nohup.my_embeddings.${i} 2>&1 &

训练

Self-RAG训练两个模型,CriticGenerator,它们都扩展了反思标记的词汇表,并使用标准的下一个标记预测目标进行训练。

或者,你可以下载我们由15万个实例组成的训练数据这里

收集反思标记

我们从GPT-4收集训练数据。用于为每种特殊标记类型调用GPT-4的脚本可在data_creation/critic获得。

或者,你可以在这里下载我们的训练数据。

Critic训练

一旦你创建或下载了训练数据,运行下面的命令对Llama2-7B进行critic训练的微调。

cd data_creation
torchrun --nproc_per_node=2 \
  --master_port=2568 train_special_tokens.py \
  --model_name_or_path meta-llama/Llama-2-7b-hf \
  --data_path 训练数据文件路径 \
  --bf16  True \
  --output_dir CRITIC模型路径 \
  --num_train_epochs 3  \
  --per_device_train_batch_size 1 --per_device_eval_batch_size 1 \
  --gradient_accumulation_steps 8 \
  --evaluation_strategy "no" \
  --save_strategy "steps" \
  --save_steps 300 \
  --save_total_limit 1 \
  --learning_rate 2e-5 \
  --weight_decay 0. \
  --warmup_ratio 0.01 \
  --lr_scheduler_type "cosine" \
  --logging_steps 10 \
  --fsdp "full_shard auto_wrap"

Generator数据创建

创建Generator训练数据的代码在generator_data_creation下。请参阅README.md中的说明。

或者,你可以在HuggingFace数据集这里下载我们的训练数据

Generator训练

对于generator训练,我们使用DeepSpeed来提高训练效率。你可以通过运行下面的脚本来进行训练,设置好训练数据路径后。

cd retrieval_lm
bash script_finetune_7b.sh

对于13B模型训练,使用training_13b。我们使用8个40GB内存的A100进行7B模型训练,使用4个80GB内存的A100进行13B训练。7B模型应该可以在1-2个A100上运行,尽管训练可能会很慢。

推理

对于论文中进行的任务评估,请在这里下载数据。

每个文件都已经包含了检索到的文档,所以如果你不想在推理过程中运行检索器,你可以简单地在contexts加载检索到的文档。

下面,我们描述Self-RAG和基线。

短文本生成(PubHealth, ARC-Challenge, TriviaQA, PopQA)

由于我们通常只为短文本生成任务检索一次,我们提供了一个易于运行的评估脚本,利用Contriever离线预先检索的文档。请参见下面的各个命令。

问答

python run_short_form.py \
--model_name selfrag/selfrag_llama2_7b \
--input_file eval_data/popqa_longtail_w_gs.jsonl \
--mode 模式 --max_new_tokens 100 \
--threshold 0.2 \
--output_file 你的输出文件 \
--metric match --ndocs 10 --use_groundness --use_utility --use_seqscore \
--dtype half

mode指定推理时的模式,可选择['adaptive_retrieval', 'no_retrieval', 'always_retrieve']

  • adaptive_retrieval根据threshold或Self-RAG预测进行检索
  • no_retrieval在推理时禁用检索
  • always_retrieve始终进行检索。

对于13B,如果你在单个24GB内存的GPU上运行,可能会遇到内存不足的问题。你可以通过设置--world_size在多个GPU上运行推理。

ARC Challenge

python run_short_form.py \
  --model_name selfrag/selfrag_llama2_7b \
  --input_file eval_data/arc_challenge_processed.jsonl \
  --max_new_tokens 50 --threshold 0.2 \
  --output_file 输出文件名 \
  --metric match --ndocs 5 --use_groundness --use_utility --use_seqscore \
  --task arc_c

PubHealth

python run_short_form.py \
  --model_name selfrag/selfrag_llama2_7b \
  --input_file eval_data/health_claims_processed.jsonl \
  --max_new_tokens 50 \
  --threshold 0.2 --output_file 输出文件名 \
  --metric match --ndocs 5 \
  --use_groundness --use_utility --use_seqscore \
  --task fever

长文本生成(ASQA, FactScore)

对于长文本问答,你可以使用检索模型运行评估,也可以使用预先给定的段落运行评估。 目前,我们正在努力减少运行时内存需求(DPR / Contriever与整个英语维基百科嵌入需要100 GB RAM),加快长文本生成的速度,并首先发布使用一小组初始检索文档(~20)的推理代码。

注意:我们当前的实现专门为目标任务数据集的评估而设计。我们计划更新我们的代码库,使接口更简单,更易于使用。当我们发布另一个版本时,我们会宣布。

使用预先检索的段落运行推理

对于ASQA,请运行以下命令,

python run_long_form_static.py \
  --model_name selfrag/selfrag_llama2_7b \
  --ndocs 5 --max_new_tokens 300 --threshold 0.2 \
  --use_grounding --use_utility --use_seqscore \
  --task asqa --input_file eval_data/asqa_eval_gtr_top100.json \
  --output_file 你的输出文件名 --max_depth 7 --mode always_retrieve \

对于FactScore,

python run_long_form_static.py \
  --model_name selfrag/selfrag_llama2_7b \
  --ndocs 5 --max_new_tokens 300 --threshold 0.2 \
  --use_grounding --use_utility --use_seqscore \
  --task factscore --input_file eval_data/factscore_unlabeled_alpaca_13b_retrieval.jsonl \
  --output_file 你的输出文件名 --max_depth 7 \
长文本生成的关键参数

Self-RAG的推理有几个关键参数。

  • w_rel(默认1.0):w_rel控制在波束搜索过程中对isRel(评判检索到的段落是否相关的评价标记)标记概率的强调。
  • w_sup(默认1.0):w_sup控制在波束搜索过程中对isSup(评判生成是否被文档支持的评价标记)标记概率的强调。
  • w_use(默认0.5):w_use控制在波束搜索过程中对isUse(整体质量的评价标记)标记概率的强调。
  • threshold(默认0.2):此阈值控制自适应检索的频率。
  • max_depth(默认6):这对应于论文中的T,它定义了搜索的最大深度。
  • beam_width(默认2):这控制段级波束搜索中波束的大小。

更多详细信息,请参阅我们论文中的详细说明(第3.3节)和分析(第5节)。

运行评估

对于长文本评估,设置外部库或仓库以运行评估。

  • factscore==v0.1.5(生物) 请按照FactScore官方仓库的说明设置你的环境。
python -m factscore.factscorer --data_path 你的输出文件  --model_name retrieval+ChatGPT --cache_dir 你的缓存目录 --openai_key 你的OPEN_AI_密钥 --verbose

ALCE 为长篇问答提供了使用多种不同指标的全面评估。对于您的首次评估,请安装 ALCE 仓库并下载数据。

git clone https://github.com/princeton-nlp/ALCE.git
python3 -m alce_env
cd ALCE
bash download_data.sh

对于 ASQA,您可以按以下方式运行评估。请注意,ASQA 评估需要基于 T5-XXL (11B) 的 NLI 模块。

python eval.py --f YOUR_OUTPUT_FILE --citations --qa --mauve

基准测试

重新运行基准测试的代码可在 run_baseline_lm.py 找到。 要运行检索增强基准测试,请确保下载包含检索段落的任务输入文件。

普通语言模型基准测试

  • Huggingface 模型
python run_baseline_lm.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
 --max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH --task qa --prompt_name "prompt_no_input"

例如,PubHealth

python run_baseline_lm.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file eval_data/health_claims_processed.jsonl \
--max_new_tokens 20 \
--metric accuracy \
--result_fp llama2_7b_pubhealth_results.json \
--task fever

注意:对于 PubHealth 和 ARC,请传入任务名称(ARC = arc_c 和 PubHealth = fever)以正确设置指令。

  • OpenAI API

对于 OpenAI API 模型,您还需要在这里设置组织密钥。您还需要有一个包含 OpenAI API 密钥的 txt 文件。

python run_baseline_lm.py \
--model_name gpt-3.5-turbo-0301 \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
--max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH \
 --task qa \
--api_key YOUR_OPEN_AI_API_KEY_FILE \
--prompt_name "prompt_no_input"

检索增强基准测试

  • Huggingface 模型
python run_baseline_refactor.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
 --max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH --task qa \
--mode retrieval \
--prompt_name "prompt_no_input_retrieval"
  • OpenAI API
python run_baseline_lm.py \
--model_name gpt-3.5-turbo-0301 \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
--max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH \
 --task qa \
--api_key YOUR_OPEN_AI_API_KEY_FILE \
--mode retrieval \
--prompt_name "prompt_no_input_retrieval"

常见问题

问题1:如何使用 Self-RAG 方案训练新的预训练语言模型? -- 如果您使用的是 Hugging Face transformers,您可以简单地在我们的训练脚本 script_finetune_7b.sh 中更改 model_name_or_pathtokenizer_name。如果您想使用自己的微调脚本,请确保添加特殊标记并屏蔽段落上下文,如此问题中所讨论的。

问题2:你们计划发布基于 Mistral-7B 的 Self-RAG 吗? -- 目前我的时间有限,无法这样做,但社区已经训练了一个基于 Mistral-7B 的 Self-RAG 版本 SciPhi-Self-RAG-Mistral-7B-32k。如果我们能够在 Mistral-7B 上训练 Self-RAG 并发布检查点,我们会通知大家。

联系方式

如果您有问题,请提出一个问题并提及 @AkariAsai,或发送电子邮件至 akari[at]cs.washington.edu。

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号