与人类相关的损失函数 (HALOs) :innocent:
此仓库允许您设计新的与人类相关的损失函数 (HALOs),以大规模地将LLMs与离线人类反馈对齐(详细阅读我们的技术报告或完整论文)。 该仓库已被用于创建Archangel,这是有史以来最大的与人类反馈对齐的LLM套件,并已在从1B到30B的规模上进行测试。
此仓库借鉴了写得非常好的DPO仓库,并保留了许多原始设计选择。 我们引入的一些关键更改包括:
- 使数据加载更具模块化,以便您可以轻松编写自己的数据加载器
- 使训练器更具模块化,因此每个HALO都有自己的训练器子类
- 添加代码以使用GPT-4进行开放式评估
- 支持超越SFT和DPO的损失(包括KTO,PPO(离线,非政策变体)和SLiC)
要首先对一个模型进行SFT,请运行类似以下的命令
python train.py loss=sft model=llama7b datasets=[shp,hh,oasst] exp_name=llama7b_sft mode=train ++cache_dir=/data/models
这会将模型保存到/data/models/llama7b_sft/LATEST/policy.pt
。然后要使用KTO对模型进行对齐,请运行类似以下的命令
python train.py loss=kto model=llama7b datasets=[shp,hh,oasst] exp_name=llama7b_kto mode=train ++cache_dir=/data/models ++model.load_from=llama7b_sft/LATEST/policy.pt
这会将模型保存到/data/models/llama7b_kto/LATEST/policy.pt
。
快速开始
假设我们要实现一个名为Kahneman-Tversky优化(KTO)的新HALO。 此仓库已经根据我们的报告中的详细信息实现了这一点,但我们假装并没有实现它。 我们该怎么做?
-
首先,创建并激活conda环境。
conda env create -f environment.yml
conda activate halos
如果您无法创建conda环境,或在安装过程中遇到问题,请尝试执行以下步骤:
conda create -n halos3 python=3.10.12 pip3 install numpy==1.24.3 ninja==1.11.1.1 packaging==23.1 conda install pytorch==2.1.1 pytorch-cuda=12.1 -c pytorch -c nvidia pip3 install flash-attn==2.3.3 pip3 install transformers==4.35.2 datasets hydra-core==1.3.2 wandb==0.15.3 openai==1.6.1 accelerate==0.21.0 tensor-parallel==1.2.4
-
确定是否需要一个新的数据集。如果您有一个名为
foo
的数据集,请在dataloader.py
中添加一个名为get_foo
的函数,该函数将返回一个Dataset
实例。该函数应具有以下签名,其中前缀和后缀确定数据集的格式(参见config.yaml
),split
应为train
或test
:def get_foo(split: str, human_prefix: str, human_suffix: str, assistant_prefix: str, assistant_suffix: str) -> Dataset:
-
确定是否需要一个新的数据加载器。KTO不使用偏好对,只是知道输出是可取的还是不可取的。 这意味着我们使用
dataloader.UnpairedPreferenceDataLoader
。然而,该数据加载器假设您正在处理的原始数据集包含偏好对,如 Anthropic HH 或 SHP。 如果您需要自定义数据加载器,可以通过扩展基本的DataLoader
类在同一个Python文件中实现它。 -
在
trainers.py
中编写一个训练器。根据其是否使用偏好对,应该子类化UnpairedPreferenceTrainer
或PairedPreferenceTrainer
。 如果需要高度自定义的行为,可以直接子类化BasicTrainer
。我们可以如下实现KTO的简单版本(注意,这不同于
KTOTrainer
中真正的KTO版本,它不假设每个批次中均存在选择的和拒绝的示例)。要创建SimpleKTOTrainer,我们只需子类化
trainers.UnpairedPreferenceTrainer
为trainers.SimpleKTOTrainer
并重写损失函数定义。KTO有一个超参数beta,可以通过self.config.loss.beta
访问:class SimpleKTOTrainer(UnpairedPreferenceTrainer): """一个简单的KTO版本,用于介绍HALOs仓库。""" def loss(self, policy_chosen_logps: torch.FloatTensor, policy_rejected_logps: torch.FloatTensor, reference_chosen_logps: torch.FloatTensor, reference_rejected_logps: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: """为一批策略和参考模型日志概率计算Kahneman-Tversky损失。 对于n/2选择示例和n/2拒绝示例(属于n个不同输入)的每一批,按如下方式计算损失。 如果生成y ~ p_chosen,其中x' ~ 是具有拒绝生成的示例,我们有'选择'损失: L(x, y) := 1 - sigmoid(beta * (log p_policy(y|x) - log p_reference(y|x) - KL(p_policy(y_rejected|x') || p_reference(y_rejected|x'))) 如果生成y ~ p_rejected,其中x' ~ 是具有选择生成的示例,我们有'拒绝'损失: L(x, y) := 1 - sigmoid(beta * KL(p_policy(y_chosen|x') || p_reference(y_chosen|x')) - [log p_policy(y|x) - log p_reference(y|x)]) """ # 在此处实现您的代码 return losses, chosen_rewards, rejected_rewards
-
在config/loss文件夹中添加一个文件,指定损失的详细信息:
name: kto-simple beta: 0.1 # 简单KTO的温度参数;值越低,表示对参考模型的关注越少 trainer: SimpleKTOTrainer # 在trainers.py中实现 dataloader: UnpairedPreferenceDataLoader # 已经在dataloaders.py中存在 use_reference_model: true # 因为损失定义包括参考模型
-
现在我们可以开始训练模型了!让我们在SHP、Anthropic HH和Open Assistant数据集上训练一个Llama-7B模型。 由于Llama-7B的相应条目在config/model/llama7b.yaml中,我们运行带有Hydra的命令:
python train.py loss=kto-simple model=llama7b datasets=[shp,hh,oasst] exp_name=kto-simple_llama7b mode=train ++cache_dir=/data/models
这将从头开始对齐一个Llama-7B模型。如果我们想要对齐已经用HALOs仓库微调过的模型, 可以在命令末尾添加类似
++model.load_from=/data/models/sft_llama7b/LATEST/policy.pt
。就这样!您的模型将被保存到
/data/models/kto-simple_llama7b/LATEST/policy.pt
。 -
让我们从新训练的模型中采样一些生成。采样配置位于
config/config.yaml
或models/
下。 我们可以使用以下命令在批量为32的情况下,从新训练的模型中采样512个生成,这将在samples/{config.exp_name}.json
下创建一个JSON文件。python eval.py --config-path=/data/models/kto-simple_llama7b/config.yaml ++mode=sample ++n_samples=512 ++model.eval_batch_size=32 ++samples_dir=samples/
-
在设置
OPENAI_API_KEY
后,我们可以运行以下命令使用GPT-4评估我们的对齐模型,该命令将比较对齐模型的生成与数据中的人类选择响应:python compare.py -f samples/kto-simple_llama7b.json -mc 512 -bk chosen -ck policy -r result.jsonl
常见问题
-
您支持多节点训练吗?
不,目前该仓库仅支持单节点训练。多节点训练将在未来某个时候添加。 每个Archangel套件中的模型都是使用单个节点上的8个A100 GPU训练的。
-
如何保存中间检查点?
在config/config.yaml中将intermediate_checkpoints设置为true,或者在命令行中使用++config.intermediate_checkpoints=true。 每隔config.eval_every步,实验目录($cache_dir/$exp_name)中将保存一个检查点。
-
在哪里可以找到所有的Archangel模型?
它们都在Huggingface Hub上: <SOURCE_TEXT> | 模型 | PPO | DPO | KTO | SFT | SLIC | SFT+PPO | SFT+DPO | SFT+KTO | CSFT | SFT+CSFT | | ------------- |:-------------:|-------------:|-------------:|-------------:|-------------:|-------------:|-------------:|-------------:| -------------:|-------------:| | pythia1-4b | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 |
| pythia2-8b | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 |
| pythia6-9b | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 |
| pythia12-0b | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 |
| llama7b | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 |
| llama13b | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 |
| llama30b | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 | 权重 |
引用
如果您发现此仓库或技术论文在您的研究中有用,请随意引用我们的工作:
@techreport{ethayarajh2023halos,
author = {Ethayarajh, Kawin and Xu, Winnie, and Jurafsky, Dan and Kiela, Douwe},
title = {Human-Aware Loss Functions (HALOs)},
institution = {Contextual AI},
note = {https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf},
year = {2023},
}
</SOURCE_TEXT>