简单偏好优化 (SimPO)
本代码库包含我们的论文 SimPO: Simple Preference Optimization with a Reference-Free Reward 的代码和发布的模型。我们提出了一种比 DPO(直接偏好优化)更简单、更有效的偏好优化算法,不需要参考模型。SimPO 在不同设置下,在 AlpacaEval 2、MT-Bench 和 Arena-Hard 基准测试中均优于 DPO 及其最新变体。
🆕 更新日志
- [2024.07.17] 我们发布了一款新的 SimPO 模型 gemma-2-9b-it-SimPO,该模型通过使用 UltraFeedback 数据 在 Google's gemma-2 9B 模型上进行策略优化,由 ArmoRM 注释其偏好数据,在 AlpacaEval 2 中达到了 72.4 的 LC 胜率(#排行榜第一 🎉🎉),在 Arena-Hard 中胜率为 59.1!请在这里查看训练脚本,在这里查看数据生成脚本!
- [2024.07.08] 我们更新了我们的论文(v2)
- 额外基线(RRHF、SLiC-HF、CPO)
- 新的 Llama3-Instruct 配置(v0.2)使用 ArmoRM 作为偏好标签注释器,产生了性能更好的模型 Llama-3-Instruct-8B-SimPO-v0.2,在 AlpacaEval 2 中达到了 53.7 的 LC 胜率,在 Arena-Hard 中胜率为 36.5(训练脚本)!
- 为更好的可复现性更新了 SimPO 训练器。超参数
gamma
改为gamma_beta_ratio
以便于调优。
🔗 快速链接
运行 SimPO 的提示
鉴于对 SimPO 的各种询问,我们提供了一些提示,以帮助您重现我们在论文中的结果,并在您自己的任务中取得更好结果。
环境
我们提供了一个 环境文件,其中包括我们在实验中使用的 Python 包版本。为了达到最佳再现性,我们建议使用相同的包版本。不过,请注意,由于硬件配置和 CUDA 版本的差异,结果可能仍会有所不同。
超参数调优
超参数调优对于 SimPO(和其他偏好优化算法)至关重要。SimPO 的三个主要超参数是 learning_rate
、beta
和 gamma
。
learning_rate
:这是偏好优化中最关键的超参数。过大的学习率(例如 1e-5)会严重降低性能,使模型生成无意义的句子或完全重复的响应。如果资源允许,我们建议在 3e-7、5e-7、8e-7 和 1e-6 之间进行网格搜索。我们发现对于需要高推理能力的领域(如数学),较小的学习率(例如 5e-7)更适合 DPO 和 SimPO。beta
:Beta 控制胜负响应之间的奖励缩放。SimPO 需要比 DPO 大得多的beta
。在我们的预印本中,我们使用了2.0
或2.5
的 beta,但在许多情况下,更大的 beta(例如10
)可能会带来更好的结果。gamma
:Gamma 控制目标奖励边界。我们建议调整 gamma 和 beta 的比率(即gamma / beta
)。我们建议以0.5
为起点,并在0
到1
之间进行网格搜索。调整得当的gamma_beta_ratio
可以提供适度的改进,但它不像其他超参数那样关键。
我们使用以下超参数进行发布模型的训练(注意,在我们的最新更新中,我们将超参数 gamma
改为 gamma_beta_ratio
,因为后者已归一化,并且在不同 beta
值下更易调整)。
设置 | β | γ/β | 学习率 |
---|---|---|---|
Mistral-Base | 2.0 | 0.8 | 3e-7 |
Mistral-Instruct | 2.5 | 0.1 | 5e-7 |
Llama3-Base | 2.0 | 0.5 | 6e-7 |
Llama3-Instruct | 2.5 | 0.55 | 1e-6 |
Llama3-Instruct v0.2 | 10 | 0.3 | 1e-6 |
Gemma | 10 | 0.5 | 8e-7 |
对于 DPO 来说,每种设置的最佳超参数如下。
设置 | β | 学习率 |
---|---|---|
Mistral-Base | 0.01 | 5e-7 |
Mistral-Instruct | 0.01 | 5e-7 |
Llama3-Base | 0.01 | 5e-7 |
Llama3-Instruct | 0.01 | 7e-7 |
Llama3-Instruct v0.2 | 0.01 | 3e-7 |
Gemma | 0.01 | 5e-7 |
在 BOS 中的一致性训练和评估
我们发布的 Llama3 模型使用 Llama3 分词器的初始版本(在此 PR 之前)。我们发现更新后的 Llama3 分词器在 vLLM 中偶尔会引入两个 BOS 标记,这会影响评估结果。因此,请确保在应用 Llama3 聊天模板后,在任何评估中只包括一个 BOS 标记。
特别是,如果你正在训练 Llama3 并在 AlpacaEval 2 和 Arena-Hard 上评估训练过的模型,请使用我们库中提供的模板并确保使用更新前的 Llama3 分词器。
重现 AlpacaEval 2 数字
请确保使用 alpaca-eval==0.6.2
和我们库中的 模型配置 来成功重现 AlpacaEval 2 的结果。因为从 0.6.3
开始,AlpacaEval 在 vllm 解码上有一次重大修订,导致结果与我们的实验不一致。
添加额外的 SFT 损失
CPO_SIMPO 仓库进行了初步实验,并观察到在某些情况下,添加额外的 SFT 损失可以帮助改善结果。在我们自己的实验中,SFT 正则化帮助保留了推理能力(例如 GSM8K),但降低了聊天性能。如果你希望应用 SFT 正则化,可以将 sft_weight
设置为正值(默认情况下为 0)。
发布的模型
Gemma
我们发布了以下两个基于强大的 google/gemma-2-9b-it 模型,通过在策略数据集 princeton-nlp/gemma2-ultrafeedback-armorm 上进行 DPO 和 SimPO 训练而成的模型。对于 GSM 和 MMLU,我们使用 EvalZero 库,该库旨在评估指令调整的 LLMs(即聊天模型而非基础模型)在推理和知识密集型任务上的零次推理性能。更多关于 WildBench 的结果即将发布。
模型 | AE2 LC | AE2 WR | AE2 长度 | AH | AH 长度 | GSM | GSM 长度 | MMLU | MMLU 长度 |
---|---|---|---|---|---|---|---|---|---|
google/gemma-2-9b-it | 51.1 | 38.1 | 1571 | 40.8 | 545 | 87.4 | 395 | 72.7 | 515 |
princeton-nlp/gemma-2-9b-it-DPO | 67.8 | 65.4 | 2016 | 58.9 | 717 | 88.5 | 392 | 72.2 | 624 |
princeton-nlp/gemma-2-9b-it-SimPO | 72.4 | 65.9 | 1833 | 59.1 | 693 | 88.0 | 341 | 72.2 | 441 |
- 与 llama3 模型相比,我们发现 gemma 模型在数学任务(如 GSM)和 MMLU 上表现出显著更少的灾难性遗忘,尽管 ultrafeedback 数据集中只有有限的数学相关数据。这表明 google/gemma-2-9b-it 模型更适合持续的偏好优化。
- SimPO 和 DPO 在所有基准测试中的表现相当,但 SimPO 本质上更简单,资源消耗也更少。
v0.2
我们发现,使用强大的奖励模型来注释偏好优化数据集是至关重要的。在此迭代中,我们重新注释了数据集 princeton-nlp/llama3-ultrafeedback-armorm,使用了更强大的奖励模型 RLHFlow/ArmoRM-Llama3-8B-v0.1。因此,v0.2 模型表现出相比 v0.1 模型显著改善的性能。
注意:我们观察到 SimPO v0.2 模型在生成需要遵循特定结构(如 JSON)的输出时经常遇到困难。这一问题来源于多种因素:llama3-instruct 模型倾向于遗忘以及训练时使用的大学习率(如 1e-6),这导致偏离了原始模型。为了解决这个问题,我们基于 google/gemma-2-9b-it 开发了 SimPO 模型。我们发现,更换初始模型明显减轻了遗忘问题,并减少了学习率的影响。
模型 | AE2 LC | AE2 WR | AH | |
---|---|---|---|---|
Llama 3 Instruct 8B RRHF v0.2 | princeton-nlp/Llama-3-Instruct-8B-RRHF-v2.0 | 37.9 | 31.6 | 28.8 |
Llama 3 Instruct 8B SLiC-HF v0.2 | princeton-nlp/Llama-3-Instruct-8B-SLiC-HF-v2.0 | 33.9 | 32.5 | 29.3 |
Llama 3 Instruct 8B DPO v0.2 | princeton-nlp/Llama-3-Instruct-8B-DPO-v0.2 | 48.2 | 47.5 | 35.2 |
Llama 3 Instruct 8B IPO v0.2 | princeton-nlp/Llama-3-Instruct-8B-IPO-v0.2 | 46.8 | 42.4 | 36.6 |
Llama 3 Instruct 8B CPO v0.2 | princeton-nlp/Llama-3-Instruct-8B-CPO-v0.2 | 34.1 | 36.4 | 30.9 |
Llama 3 Instruct 8B KTO v0.2 | princeton-nlp/Llama-3-Instruct-8B-KTO-v0.2 | 34.1 | 32.1 | 27.3 |
Llama 3 Instruct 8B ORPO v0.2 | princeton-nlp/Llama-3-Instruct-8B-ORPO-v0.2 | 38.1 | 33.8 | 28.2 |
Llama 3 Instruct 8B R-DPO v0.2 | princeton-nlp/Llama-3-Instruct-8B-RDPO-v0.2 | 48.0 | 45.8 | 35.1 |
Llama 3 Instruct 8B SimPO v0.2 | princeton-nlp/Llama-3-Instruct-8B-SimPO-v0.2 | 53.7 | 47.5 | 36.5 |
v0.1
以下是我们预印本中评估的完整模型列表。我们使用HuggingFaceH4/ultrafeedback_binarized 数据集训练Mistral Base 和Llama3 Base模型,使用princeton-nlp/mistral-instruct-ultrafeedback 数据集训练Mistral Instruct模型,使用princeton-nlp/llama3-ultrafeedback 数据集训练Llama3 Instruct模型。后两个数据集由llm-blender/PairRM 模型注释。
使用我们的模型进行推理
请参阅 generate.py 脚本以获取加载具有适当聊天模板的模型的详细说明。
安装要求
我们的代码库是基于 alignment-handbook repo。以下步骤将指导您完成安装过程。
首先,使用如 Conda 创建一个 Python 虚拟环境:
conda create -n handbook python=3.10 && conda activate handbook
接下来,安装 PyTorch v2.2.2
。由于这取决于硬件,我们将您引导至 PyTorch 安装页面。
然后,您可以按如下所示安装 alignment-handbook 的剩余软件包依赖项:
git clone https://github.com/huggingface/alignment-handbook.git
cd ./alignment-handbook/
python -m pip install .
您还需要安装 Flash Attention 2,可以通过运行以下命令来完成:
python -m pip install flash-attn --no-build-isolation
训练脚本
我们提供了四个训练配置文件,对应我们论文中报告的四种训练设置。训练配置设置为4xH100 GPUs。您可能需要根据您的计算环境调整 num_processes
和 per_device_train_batch_size
。
- Mistral-Base:
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/mistral-7b-base-simpo.yaml
- Mistral-Instruct:
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/mistral-7b-instruct-simpo.yaml
- Llama3-Base:
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/llama-3-8b-base-simpo.yaml
- Llama3-Instruct:
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/llama-3-8b-instruct-simpo.yaml
- Llama3-Instruct v0.2:
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml scripts/run_simpo.py training_configs/llama-3-8b-instruct-simpo-v2.yaml
评估
我们遵循官方的实现来评估 AlpacaEval 2, Arena-Hard 和 MT-Bench,如下所示(更多细节可以在 eval 目录 下找到):
-
AlpacaEval 2:请参考 AlpacaEval 仓库 进行评估。
-
Arena-Hard:请参考 Arena-Hard-Auto 仓库 进行评估。
-
MT-Bench:请参考 FastChat 仓库 进行评估。
发现问题或有问题?
如果您有任何关于代码或论文的问题,请随时给 Yu(yumeng5@virginia.edu)发送邮件。如果您在使用代码时遇到任何问题,或者想报告一个 Bug,请随时打开一个 issue!请尽量详细说明问题,以便我们能够更好更快地帮助您!
引用
如果您在工作中发现我们的仓库对您有帮助,请引用我们的论文:
@article{meng2024simpo,
title={{SimPO}: Simple Preference Optimization with a Reference-Free Reward},
author={Meng, Yu and Xia, Mengzhou and Chen, Danqi},
journal={arXiv preprint arXiv:2405.14734},
year={2024}
}