项目介绍:Transformer Reinforcement Learning X (trlX)
项目背景
Transformer Reinforcement Learning X,简写为trlX,是为大规模语言模型进行强化学习微调而设计的分布式训练框架。它着力于在特定奖励函数或奖励标注数据集的指导下,微调大型语言模型。
主要功能
trlX支持在🤗 Hugging Face平台上的模型进行训练,使用Accelerate强化学习协议来配置该过程,允许微调模型,比如最大到20B参数的facebook/opt-6.7b
、EleutherAI/gpt-neox-20b
及google/flan-t5-xxl
。对于更大参数规模的模型,trlX利用NVIDIA NeMo的并行技术,提供高效的训练支持。
目前,trlX实现了以下两种强化学习算法:
- 邻近策略优化(Proximal Policy Optimization, PPO):支持在Accelerate和NeMo两种训练平台上进行训练。
- 隐式语言Q学习(Implicit Language Q-Learning, ILQL):同样支持在Accelerate和NeMo两种平台上。
使用说明
安装步骤
git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu118
pip install -e .
训练模型
用户可以使用奖励函数或是奖励标注数据集来进行模型训练:
-
使用奖励函数:
trainer = trlx.train('gpt2', reward_fn=lambda samples, **kwargs: [sample.count('cats') for sample in samples])
-
使用奖励标注数据集:
trainer = trlx.train('EleutherAI/gpt-j-6B', samples=['dolphins', 'geese'], rewards=[1.0, 100.0])
-
使用提示-完成数据集:
trainer = trlx.train('gpt2', samples=[['Question: 1 + 2 Answer:', '3'], ['Question: Solve this equation: ∀n>0, s=2, sum(n ** -s). Answer:', '(pi ** 2)/ 6']])
配置超参数
默认配置可以通过trlx.data.default_configs
进行调整,以管理内存使用和优化训练。
保存模型
训练完成后,可以将结果保存为Hugging Face的预训练语言模型格式:
trainer.save_pretrained('/path/to/output/folder/')
其他功能
- 分布式训练:可以通过🤗 Accelerate或NeMo-Megatron进行分布式训练,以处理更大的数据和模型。
- 日志记录:使用标准Python日志库进行训练信息的记录,用户可以控制日志的详细程度。
贡献及引用
关于如何为该项目做出贡献的信息和项目引用方式,请参阅项目文档及贡献指南。
致谢
感谢Leandro von Werra贡献了trl库,该项目的灵感来源之一。