SELF-RAG: 通过自我反思学习检索、生成和评判
这包括原始实现的SELF-RAG: 通过自我反思学习检索、生成和评判(ICLR 2024,口头报告前1%),作者为Akari Asai、Zeqiu Wu、Yizhong Wang、Avirup Sil和Hannaneh Hajishirzi。
网站 | 7B模型 | 13B模型 | 论文 | 训练数据 | Twitter摘要 | 更新
Self-RAG(右图)是一个新的框架,用于训练任意语言模型学习检索、生成和评判,以提高生成内容的事实性和质量,同时不影响大型语言模型的多功能性。
与广泛采用的检索增强生成(RAG;左图)方法不同,Self-RAG根据需求进行检索(例如,可以多次检索或完全跳过检索),针对不同的查询,并通过预测反思标记作为生成的组成部分,从多个细粒度方面对自身生成进行评判。我们进行分段束搜索,以选择能最大化多样化偏好效用的输出。
如果您发现我们的代码、数据、模型或论文有用,请引用以下论文:
@inproceedings{
asai2024selfrag,
author={Asai, Akari and Wu, Zeqiu and Wang, Yizhong and Sil, Avirup and Hajishirzi, Hannaneh},
title={Self-{RAG}: Learning to Retrieve, Generate, and Critique through Self-Reflection},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=hSyW5go0v8}
}
更新
- 2023年10月:首次发布代码、模型和论文。
内容
安装
通过运行以下命令安装依赖的Python库。
pip install -r requirements.txt
请使用最新版本的vllm
,因为旧版本可能无法通过SamplingParam
设置skip_special_tokens
,这是由(这个PR)添加的。
您也可以通过运行以下命令创建conda环境。
conda env create -f environment.yml
快速开始
您可以从HuggingFace Hub下载Self-RAG。对于推理,我们建议使用vllm,因为它可以显著加快推理速度。
from vllm import LLM, SamplingParams
model = LLM("selfrag/selfrag_llama2_7b", download_dir="/gscratch/h2lab/akari/model_cache", dtype="half")
sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)
def format_prompt(input, paragraph=None):
prompt = "### 指令:\n{0}\n\n### 回复:\n".format(input)
if paragraph is not None:
prompt += "[检索]<段落>{0}</段落>".format(paragraph)
return prompt
query_1 = "找出不同项:twitter、instagram、whatsapp。"
query_2 = "你能告诉我美洲驼和羊驼的区别吗?"
queries = [query_1, query_2]
# 对于不需要检索的查询
preds = model.generate([format_prompt(query) for query in queries], sampling_params)
for pred in preds:
print("模型预测: {0}".format(pred.outputs[0].text))
输出:
模型预测: Twitter、Instagram和WhatsApp都是社交媒体平台。[无检索]WhatsApp是不同项,因为它是一个消息应用,而Twitter和Instagram主要用于分享照片和视频。[效用:5]</s>
模型预测: 好的![检索]<段落><段落>
如您所见,在第一个查询中,当不需要检索时,Self-RAG开始生成回答而不进行检索。另一方面,对于第二个查询,Self-RAG输出了[检索]
标记,因为这个问题需要更细粒度的事实依据。
对于需要事实依据的查询,您可以插入一个段落。Self-RAG可以在生成过程中随时检索和插入段落,只要它们被上下文标记特殊标记<段落>
、</段落>
包围,就能识别它们。
# 对于需要事实依据的查询
prompt = format_prompt("你能告诉我美洲驼和羊驼的区别吗?", "羊驼(Lama pacos)是南美洲骆驼科哺乳动物的一种。它与美洲驼相似,常常被混淆。羊驼比美洲驼小得多,与美洲驼不同的是,它们不是被培育为工作动物,而是专门为了它们的纤维而被培育。")
preds = model.generate([prompt], sampling_params)
print([pred.outputs[0].text for pred in preds])
# ['[相关]羊驼比美洲驼小得多,与美洲驼不同的是,它们不是被培育为工作动物,而是专门为了它们的纤维而被培育。[完全支持][效用:5]</s>']
Self-RAG找到相关的插入文档,并生成完全由证据支持的答案。
使用在线检索模型进行评估
您也可以按需进行检索并与Self-RAG一起使用。由于在完整的英文维基百科上运行检索需要大量RAM和多个GPU,我们为演示目的创建了一个只包含维基百科文章介绍段落的子集。
首先,请下载语料库和嵌入(共9GB)。
git clone git@github.com:AkariAsai/self-rag.git
cd retrieval_lm
bash download_demo_corpus.sh
如果脚本不起作用,您可以从Google Drive或HF数据集下载数据。
然后,您可以在retrieval_lm
下运行脚本。我们在1个RTX 6000(24GB)和100G RAM上测试了该脚本(但应该可以在更小的RAM上运行)。
from passage_retrieval import Retriever
retriever = Retriever({})
retriever.setup_retriever_demo("facebook/contriever-msmarco", "enwiki_2020_intro_only/enwiki_2020_dec_intro_only.jsonl", "enwiki_2020_intro_only/enwiki_dec_2020_contriever_intro/*", n_docs=5, save_or_load_index=False)
retrieved_documents = retriever.search_document_demo(query_3, 5)
prompts = [format_prompt(query_3, doc["title"] +"\n"+ doc["text"]) for doc in retrieved_documents]
preds = model.generate(prompts, sampling_params)
top_doc = retriever.search_document_demo(query_3, 1)[0]
print("参考: {0}\n模型预测: {1}".format(top_doc["title"] + "\n" + top_doc["text"], preds[0].outputs[0].text))
输出:
参考: 过拟合
在统计学中,过拟合是"产生一个与特定数据集过于紧密或完全对应的分析,因此可能无法可靠地适应额外数据或预测未来观察结果"。过拟合模型是一个包含比数据可以证明更多参数的统计模型。过拟合的本质是无意中将一些残差变异(即噪声)提取出来,就好像这种变异代表了潜在的模型结构。欠拟合发生在统计模型无法充分捕捉数据的潜在结构时。欠拟合模型是一个模型,其中一些在正确指定的模型中会出现的参数或项缺失
模型预测: [相关]过拟合发生在模型相对于其训练数据量而言具有太多参数时,导致它过度记忆训练数据,并在新的、未见过的数据上表现不佳。[完全支持][效用:5]</s>
检索系统正确检索了必要的文档并生成了完全有根据的输出。
请注意,此演示使用较小的语料库和具有完整推理算法的Self-RAG。对于完整评估,您需要设置检索器或下载我们的检索结果。请按照推理中的说明进行操作。
检索器设置
默认情况下,我们使用Contriever作为我们的检索组件。
下载数据
下载DPR中使用的预处理段落数据。
cd retrieval_lm
wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
然后,下载生成的段落。我们使用Contriever-MSMARCO
wget https://dl.fbaipublicfiles.com/contriever/embeddings/contriever-msmarco/wikipedia_embeddings.tar
运行检索器
您可以通过运行以下命令来执行段落检索。
cd retrieval_lm
python passage_retrieval.py \
--model_name_or_path facebook/contriever-msmarco --passages psgs_w100.tsv \
--passages_embeddings "wikipedia_embeddings/*" \
--data 你的输入文件 \
--output_dir 你的输出文件 \
--n_docs 20
你的输入文件应该是json
或jsonl
格式。每个实例必须包含question
或instruction
,这将在检索过程中用作查询。
为你自己的数据生成嵌入
你可以通过运行以下命令为你自己的数据生成嵌入。(该脚本改编自Contriever仓库。)请注意,为大规模语料库(>1000万文档)生成嵌入可能需要一些时间,我们建议在多个GPU上运行。
cd retrieval_lm
for i in {0..3}; do
export CUDA_VISIBLE_DEVICES=${i}
python generate_passage_embeddings.py --model_name_or_path facebook/contriever-msmarco \
--output_dir 你的输出目录 \
--passages 你的段落数据 --shard_id ${i} --num_shards 4 > ./log/nohup.my_embeddings.${i} 2>&1 &
训练
Self-RAG训练两个模型,Critic和Generator,它们都扩展了反思标记的词汇表,并使用标准的下一个标记预测目标进行训练。
- 步骤1:Critic数据创建:使用GPT4生成Critic训练数据。
- 步骤2:Critic训练:使用新的特殊标记训练Critic。
- 步骤3:Generator数据创建:使用Critic和检索器生成Generator训练数据。
- 步骤4:Generator训练:使用新的特殊标记训练Generator。
或者,你可以下载我们由15万个实例组成的训练数据这里。
收集反思标记
我们从GPT-4收集训练数据。用于为每种特殊标记类型调用GPT-4的脚本可在data_creation/critic获得。
或者,你可以在这里下载我们的训练数据。
Critic训练
一旦你创建或下载了训练数据,运行下面的命令对Llama2-7B进行critic训练的微调。
cd data_creation
torchrun --nproc_per_node=2 \
--master_port=2568 train_special_tokens.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--data_path 训练数据文件路径 \
--bf16 True \
--output_dir CRITIC模型路径 \
--num_train_epochs 3 \
--per_device_train_batch_size 1 --per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 300 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.01 \
--lr_scheduler_type "cosine" \
--logging_steps 10 \
--fsdp "full_shard auto_wrap"
Generator数据创建
创建Generator训练数据的代码在generator_data_creation下。请参阅README.md中的说明。
或者,你可以在HuggingFace数据集或这里下载我们的训练数据
Generator训练
对于generator训练,我们使用DeepSpeed来提高训练效率。你可以通过运行下面的脚本来进行训练,设置好训练数据路径后。
cd retrieval_lm
bash script_finetune_7b.sh
对于13B模型训练,使用training_13b
。我们使用8个40GB内存的A100进行7B模型训练,使用4个80GB内存的A100进行13B训练。7B模型应该可以在1-2个A100上运行,尽管训练可能会很慢。
推理
对于论文中进行的任务评估,请在这里下载数据。
每个文件都已经包含了检索到的文档,所以如果你不想在推理过程中运行检索器,你可以简单地在contexts
加载检索到的文档。
下面,我们描述Self-RAG和基线。
短文本生成(PubHealth, ARC-Challenge, TriviaQA, PopQA)
由于我们通常只为短文本生成任务检索一次,我们提供了一个易于运行的评估脚本,利用Contriever离线预先检索的文档。请参见下面的各个命令。
问答
python run_short_form.py \
--model_name selfrag/selfrag_llama2_7b \
--input_file eval_data/popqa_longtail_w_gs.jsonl \
--mode 模式 --max_new_tokens 100 \
--threshold 0.2 \
--output_file 你的输出文件 \
--metric match --ndocs 10 --use_groundness --use_utility --use_seqscore \
--dtype half
mode
指定推理时的模式,可选择['adaptive_retrieval', 'no_retrieval', 'always_retrieve']
。
adaptive_retrieval
根据threshold
或Self-RAG预测进行检索no_retrieval
在推理时禁用检索always_retrieve
始终进行检索。
对于13B,如果你在单个24GB内存的GPU上运行,可能会遇到内存不足的问题。你可以通过设置--world_size
在多个GPU上运行推理。
ARC Challenge
python run_short_form.py \
--model_name selfrag/selfrag_llama2_7b \
--input_file eval_data/arc_challenge_processed.jsonl \
--max_new_tokens 50 --threshold 0.2 \
--output_file 输出文件名 \
--metric match --ndocs 5 --use_groundness --use_utility --use_seqscore \
--task arc_c
PubHealth
python run_short_form.py \
--model_name selfrag/selfrag_llama2_7b \
--input_file eval_data/health_claims_processed.jsonl \
--max_new_tokens 50 \
--threshold 0.2 --output_file 输出文件名 \
--metric match --ndocs 5 \
--use_groundness --use_utility --use_seqscore \
--task fever
长文本生成(ASQA, FactScore)
对于长文本问答,你可以使用检索模型运行评估,也可以使用预先给定的段落运行评估。 目前,我们正在努力减少运行时内存需求(DPR / Contriever与整个英语维基百科嵌入需要100 GB RAM),加快长文本生成的速度,并首先发布使用一小组初始检索文档(~20)的推理代码。
注意:我们当前的实现专门为目标任务数据集的评估而设计。我们计划更新我们的代码库,使接口更简单,更易于使用。当我们发布另一个版本时,我们会宣布。
使用预先检索的段落运行推理
对于ASQA,请运行以下命令,
python run_long_form_static.py \
--model_name selfrag/selfrag_llama2_7b \
--ndocs 5 --max_new_tokens 300 --threshold 0.2 \
--use_grounding --use_utility --use_seqscore \
--task asqa --input_file eval_data/asqa_eval_gtr_top100.json \
--output_file 你的输出文件名 --max_depth 7 --mode always_retrieve \
对于FactScore,
python run_long_form_static.py \
--model_name selfrag/selfrag_llama2_7b \
--ndocs 5 --max_new_tokens 300 --threshold 0.2 \
--use_grounding --use_utility --use_seqscore \
--task factscore --input_file eval_data/factscore_unlabeled_alpaca_13b_retrieval.jsonl \
--output_file 你的输出文件名 --max_depth 7 \
长文本生成的关键参数
Self-RAG的推理有几个关键参数。
w_rel
(默认1.0):w_rel
控制在波束搜索过程中对isRel
(评判检索到的段落是否相关的评价标记)标记概率的强调。w_sup
(默认1.0):w_sup
控制在波束搜索过程中对isSup
(评判生成是否被文档支持的评价标记)标记概率的强调。w_use
(默认0.5):w_use
控制在波束搜索过程中对isUse
(整体质量的评价标记)标记概率的强调。threshold
(默认0.2):此阈值控制自适应检索的频率。max_depth
(默认6):这对应于论文中的T
,它定义了搜索的最大深度。beam_width
(默认2):这控制段级波束搜索中波束的大小。
更多详细信息,请参阅我们论文中的详细说明(第3.3节)和分析(第5节)。
运行评估
对于长文本评估,设置外部库或仓库以运行评估。
factscore==v0.1.5
(生物) 请按照FactScore官方仓库的说明设置你的环境。
python -m factscore.factscorer --data_path 你的输出文件 --model_name retrieval+ChatGPT --cache_dir 你的缓存目录 --openai_key 你的OPEN_AI_密钥 --verbose
ALCE 为长篇问答提供了使用多种不同指标的全面评估。对于您的首次评估,请安装 ALCE 仓库并下载数据。
git clone https://github.com/princeton-nlp/ALCE.git
python3 -m alce_env
cd ALCE
bash download_data.sh
对于 ASQA,您可以按以下方式运行评估。请注意,ASQA 评估需要基于 T5-XXL (11B) 的 NLI 模块。
python eval.py --f YOUR_OUTPUT_FILE --citations --qa --mauve
基准测试
重新运行基准测试的代码可在 run_baseline_lm.py 找到。 要运行检索增强基准测试,请确保下载包含检索段落的任务输入文件。
普通语言模型基准测试
- Huggingface 模型
python run_baseline_lm.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
--max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH --task qa --prompt_name "prompt_no_input"
例如,PubHealth
python run_baseline_lm.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file eval_data/health_claims_processed.jsonl \
--max_new_tokens 20 \
--metric accuracy \
--result_fp llama2_7b_pubhealth_results.json \
--task fever
注意:对于 PubHealth 和 ARC,请传入任务名称(ARC = arc_c
和 PubHealth = fever
)以正确设置指令。
- OpenAI API
对于 OpenAI API 模型,您还需要在这里设置组织密钥。您还需要有一个包含 OpenAI API 密钥的 txt 文件。
python run_baseline_lm.py \
--model_name gpt-3.5-turbo-0301 \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
--max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH \
--task qa \
--api_key YOUR_OPEN_AI_API_KEY_FILE \
--prompt_name "prompt_no_input"
检索增强基准测试
- Huggingface 模型
python run_baseline_refactor.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
--max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH --task qa \
--mode retrieval \
--prompt_name "prompt_no_input_retrieval"
- OpenAI API
python run_baseline_lm.py \
--model_name gpt-3.5-turbo-0301 \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
--max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH \
--task qa \
--api_key YOUR_OPEN_AI_API_KEY_FILE \
--mode retrieval \
--prompt_name "prompt_no_input_retrieval"
常见问题
问题1:如何使用 Self-RAG 方案训练新的预训练语言模型? -- 如果您使用的是 Hugging Face transformers,您可以简单地在我们的训练脚本 script_finetune_7b.sh 中更改 model_name_or_path
和 tokenizer_name
。如果您想使用自己的微调脚本,请确保添加特殊标记并屏蔽段落上下文,如此问题中所讨论的。
问题2:你们计划发布基于 Mistral-7B 的 Self-RAG 吗? -- 目前我的时间有限,无法这样做,但社区已经训练了一个基于 Mistral-7B 的 Self-RAG 版本 SciPhi-Self-RAG-Mistral-7B-32k。如果我们能够在 Mistral-7B 上训练 Self-RAG 并发布检查点,我们会通知大家。
联系方式
如果您有问题,请提出一个问题并提及 @AkariAsai,或发送电子邮件至 akari[at]cs.washington.edu。