DPO: 直接偏好优化
新增功能: 除了原始的 DPO 算法外,本代码库现在还支持'保守式' DPO和IPO。
对于保守式 DPO,你只需在进行 DPO 训练时额外传递参数 loss.label_smoothing=X
,其中 X 是介于 0 和 0.5 之间的值(0 表示原始 DPO 损失)。这个参数本质上是保守程度参数,即训练偏好数据中不正确(偏好方向相反)的比例。从 0.1 左右开始可能比较合理,但我还没有测试过(而且这将取决于偏好数据集)。
对于 IPO,只需传递 loss=ipo
和 loss.beta=X
,其中 X 是一个非负值(与 DPO/保守式 DPO 相同)。
这个代码库是什么?
本代码库包含了 DPO 算法的参考实现,用于从偏好数据训练语言模型,如论文直接偏好优化:你的语言模型其实是一个奖励模型中所述。
这里的代码支持任何因果 HuggingFace 模型 - 查看我们在 config/model
中的示例来添加你自己的模型。添加自己的数据集也很容易。请参阅README 部分了解如何添加数据集。
DPO 流程有两个阶段:
- 对感兴趣的数据集进行监督微调(SFT)。
- 使用偏好数据(最好与 SFT 示例来自相同分布)对第 1 步的模型进行偏好学习。
本代码库中的文件包括:
train.py
:训练的主入口点(用于 SFT 或基于偏好的 DPO 训练)trainers.py
:训练器类(例如,实现学习循环以及多 GPU 逻辑)utils.py
:多个其他文件使用的一些便捷函数preference_datasets.py
:用于 SFT 和基于偏好的 DPO 训练的数据集处理逻辑;如果你想用自己的数据进行训练,需要在这里进行一些添加
运行 SFT
对于 DPO,SFT 阶段基本上确保我们训练的偏好数据在实际进行偏好学习之前对我们的策略来说是同分布的。
在 Anthropic-HH 数据上用批量大小 64 运行 Pythia 6.9B 的 SFT:
python -u train.py model=pythia69 datasets=[hh] loss=sft exp_name=anthropic_dpo_pythia69 gradient_accumulation_steps=2 batch_size=64 eval_batch_size=32 trainer=FSDPTrainer sample_during_eval=false
在 Anthropic-HH + Stanford Human Preference 数据上用批量大小 64 运行自定义模型(例如,本地路径的 Llama)的 SFT:
python -u train.py model=blank_model model.name_or_path=/PATH/TO/LLAMA/WEIGHTS model.block_name=LlamaDecoderLayer datasets=[hh,shp] loss=sft exp_name=anthropic_shp_sft_llama_7b gradient_accumulation_steps=2 batch_size=64 eval_batch_size=32 trainer=FSDPTrainer sample_during_eval=false
注意:由于我们没有使用预定义的模型配置,我们还需要传递
model.block_name
来告诉 FSDP 要包装哪些模块。
默认情况下,每 20k 个样本进行一次评估。你可以通过 eval_every
参数更改这个设置。如果你不传递 sample_during_eval=false
,每次评估时也会进行采样。
要运行不同的模型,可以在 config/model
中添加新的模型配置,或者使用 blank_model
选项作为 model
,并显式传递 model.name_or_path
(如果使用 FSDP 训练器,还需传递 model.block_name
)。例如,对于 GPT-2,这将看起来像:
python -u train.py ... model=blank_model model.name_or_path=gpt2-xl model.block=GPT2Block
运行 DPO
要运行 DPO,使用与 SFT 相同的命令,但传递 loss=dpo
、loss.beta=所需的BETA值
(0.1-0.5 是一个不错的起点),以及 model.archive=/path/to/checkpoint/from/sft/step-XXXX/policy.pt
。如果 SFT 成功完成,你应该还有一个来自训练结束的 /.../LATEST/policy.pt
。
在 Pythia 6.9B 上运行 DPO,有效批量大小为 64:
python -u train.py model=pythia69 datasets=[hh] loss=dpo loss.beta=0.1 model.archive=/path/to/checkpoint/from/sft/step-XXXX/policy.pt exp_name=anthropic_dpo_pythia69 gradient_accumulation_steps=2 batch_size=32 eval_batch_size=32 trainer=FSDPTrainer sample_during_eval=false
注意:
eval_every
是以样本为单位计算的。
完整示例
让我们通过一个完整的示例,在 Anthropic-HH 数据集上训练 pythia 2.8B。
你可以在这里查看此示例的 wandb 输出样本(标记为 readme-example
)。
步骤 1:设置环境
首先,创建一个虚拟环境并安装依赖项。推荐使用 Python 3.8+。
python3 -m venv env
source env/bin/activate
pip install -r requirements.txt
步骤 2:运行 SFT
我们将利用 FSDP 的 bfloat16 混合精度来加速训练;我们通常能看到约 50% 的速度提升。默认情况下,SFT 将在选定数据集的混合上运行一个 epoch。数据集将按需下载并在本地缓存。
python -u train.py model=pythia28 datasets=[hh] loss=sft exp_name=anthropic_dpo_pythia28 gradient_accumulation_steps=2 batch_size=64 eval_batch_size=32 trainer=FSDPTrainer sample_during_eval=false model.fsdp_policy_mp=bfloat16
注意:这个命令是在一台配备 4 个 80GB A100 的机器上运行的;在这种硬件上,SFT 大约需要 1 小时 30 分钟。如果你的计算资源较少,可能需要增加梯度累积步数,SFT 将花费更长时间。
你可以在这里查看 SFT 步骤的 wandb 输出样本。
步骤 3:运行 DPO
检查 wandb(如果启用,默认是启用的)或你的输出日志以找到本地运行目录。要运行 DPO,你需要最终权重的路径,它看起来像 /some/cache/dir/YOUR_USERNAME/pythia28_hh_sft_bf16_2023-06-21_16-58-17_973996/LATEST/policy.pt
。LATEST
目录包含训练结束时的最终权重集。
python -u train.py model=pythia28 datasets=[hh] loss=dpo loss.beta=0.1 exp_name=anthropic_dpo_pythia28 gradient_accumulation_steps=2 batch_size=64 eval_batch_size=32 trainer=FSDPTrainer sample_during_eval=false model.fsdp_policy_mp=bfloat16 model.archive=/path/to/archive/from/sft/LATEST/policy.pt
在 4 个 80GB A100 上,DPO 训练大约需要 2 小时 45 分钟。
你可以在这里查看 DPO 步骤的 wandb 输出样本。
自定义训练
训练选项位于 config/config.yaml
、config/model/blank_model.yaml
和 config/loss/dpo.yaml
中。有关这些选项的作用,请参阅这些文件中的注释。
你可以通过传递 model=some_model
来使用预配置的模型,其中 config/model/some_model.yaml
存在。我们已经给出了一些示例。
如果你想使用另一个模型,只需为该模型创建一个新配置(参照我们的示例;它必须是一个 .yaml
文件!),或者使用 model=blank_model
,并带上 model.name_or_path=名称或路径
,如果模型的名称/路径不同,可选择性地加上 model.tokenizer_name_or_path=分词器名称或路径
,以及 model.block_name=TRANSFORMER_BLOCK的名称
(如果你使用 FSDP)。你可能想要更改的其他选项是 dpo 损失选项,即 loss.beta
和 loss.reference_free
(参见 config/loss/dpo.yaml
)。
训练器类
我们在trainers.py
中实现了三种不同的训练器类:
-
BasicTrainer
:对于多个GPU,简单地将模型在它们之间进行分区。例如,对于两个GPU,模型的前半部分层将在GPU 0上,后半部分将在GPU 1上。这个训练器有效地增加了可用的GPU内存,但不会同时使用多个GPU进行计算(所以不会获得加速)。 -
FSDPTrainer
:使用PyTorch的完全分片数据并行(FSDP)实现来在可用的GPU之间分片每个transformer块。当每个GPU的批量大小>1时,应该比BasicTrainer
获得显著的加速。每个GPU的批量大小等于batch_size / (gradient_accumulation_steps * num_gpus)
。使用此训练器时,您可能需要在启动脚本中运行ulimit -n 64000
,然后再调用train.py
;例如,ulimit -n 64000; python train.py ...
。 -
TensorParallelTrainer
:使用PyTorch张量并行(通过这个包装器)在可用的GPU之间分片每个线性层。这个训练器是实验性的,但应该可以工作。
**警告:**对于FSDPTrainer
和特别是TensorParallelTrainer
来说,采样可能会非常慢(分别参见这个问题和这个问题)。建议对这些训练器传递sample_during_eval=false
。
我应该使用哪个训练器?
对于单GPU训练,使用BasicTrainer
。对于多GPU设置,FSDPTrainer
很可能是最佳选择,尽管尚未对这些进行基准测试。
添加新数据集
添加新的/自定义数据集很容易,通常不会花费超过10分钟左右。将您的数据集添加到preference_datasets.py
中(我们已经实现了Anthropic-HH、Stanford Human Preferences和StackExchange作为参考)。按照我们的参考数据集(在函数get_se()
、get_shp()
、get_hh()
中);您基本上需要返回一个字典,将每个提示映射到另一个包含三个值的字典:
responses: List[str]
:给出偏好的响应列表pairs: List[Tuple[int]]
:偏好对,其中每个元组的第一个值是首选响应,第二个值是非首选响应sft_target: str
:SFT期间用于此提示的响应(此响应可能是也可能不是responses
中的一个值)
一旦添加了您的数据集,例如xyz
,您可以通过在SFT或DPO训练命令中传递datasets=[xyz]
来对其进行训练。
确保您已更新preference_datasets:get_dataset()
以在传入其名称时返回您的新数据集!
在多个GPU上加快训练的技巧
当有多个GPU可用时,建议使用FSDP以加快训练速度。通常,您应该尝试在每个GPU上使用至少2的批量大小(即batch_size // (grad_accumulation_steps * N_GPUS)
至少为2)以从FSDP获得比BasicTrainer
更快的速度。实现这一点的一种方法是使用混合精度。本仓库通过FSDP实现混合精度。通过传递model.fsdp_policy_mp=bfloat16
或model.fsdp_policy_mp=float16
来启用混合精度(目前仅支持FSDPTrainer
)(仅测试了bfloat16
)。另一种减少内存使用的方法是激活检查点(或梯度检查点),可以通过activation_checkpointing=true
启用(也仅为FSDPTrainer
实现)。激活检查点并不总是能提高吞吐量,但如果您每个GPU的批量大小被限制在1,那么值得一试。
有关优化FSDP的更多信息,请参阅这篇文章。
引用DPO
如果DPO或本仓库在您的研究中有用,您可以使用以下BibTeX条目:
@inproceedings{
rafailov2023direct,
title={Direct Preference Optimization: Your Language Model is Secretly a Reward Model},
author={Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D Manning and Stefano Ermon and Chelsea Finn},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023},
url={https://arxiv.org/abs/2305.18290}
}