Project Icon

motif

利用大语言模型偏好生成奖励函数的强化学习框架

Motif是一个新型强化学习框架,通过大型语言模型的偏好生成奖励函数。它分为数据集注释、奖励函数训练和强化学习三个阶段。在NetHack游戏中,Motif展现出优秀性能,生成符合人类直觉的行为,并可通过提示词灵活调整。这种方法为开发智能AI代理提供了新的研究方向,具有良好的扩展潜力。

概述

本仓库包含了Motif的PyTorch代码,用于在NetHack游戏中训练AI代理,其奖励函数源自大语言模型的偏好。

Motif: 来自人工智能反馈的内在动机

作者:Martin Klissarov* & Pierluca D'Oro*, Shagun Sodhani, Roberta Raileanu, Pierre-Luc Bacon, Pascal Vincent, Amy ZhangMikael Henaff

motif

Motif从NetHack游戏交互数据集中获取带有说明的观察对,引出大语言模型(LLM)对这些对的偏好。它自动地将LLM的常识提炼为奖励函数,用于通过强化学习训练代理。

为便于比较,我们在pickle文件motif_results.pkl中提供了训练曲线,该文件包含以任务为键的字典。对于每个任务,我们提供了Motif和基线方法在多个种子下的时间步和平均回报列表。

如下图所示,Motif包含三个阶段:

  1. 数据集标注:利用LLM对带说明的观察对的偏好创建标注后的数据集;
  2. 奖励训练:使用标注后的数据集对和LLM的偏好作为监督信号训练奖励函数;
  3. 强化学习训练:使用Motif的奖励函数训练代理。

我们通过提供必要的数据集、命令和原始结果,详细说明了每个阶段,以便重现论文中的实验。

motif

我们通过NetHack学习环境评估Motif在具有挑战性、开放式和程序生成的NetHack游戏中的表现。我们研究了Motif如何主要生成符合人类直觉的行为,这些行为可以通过提示修改轻松调整,以及其扩展性。

motif

motif

要安装整个流程所需的依赖项,只需运行pip install -r requirements.txt

使用Llama 2进行数据集标注

在第一阶段,我们使用一个带有说明(即游戏中的消息)的观察对数据集,这些观察对是由经过强化学习训练以最大化游戏分数的代理收集的。 我们在本仓库中提供了该数据集。 我们将不同部分存储在motif_dataset_zipped目录中,可以使用以下命令解压缩。

cat motif_dataset_zipped/motif_dataset_part_* > motif_dataset.zip; unzip motif_dataset.zip; rm motif_dataset.zip

我们提供的数据集包含了Llama 2模型给出的一组偏好,存储在preference/目录中,使用了论文中描述的不同提示。 包含标注的.npy文件名遵循模板llama{size}b_msg_{instruction}_{version},其中size是来自集合{7,13,70}的LLM大小,instruction是引入给LLM的提示中的指令,来自集合{defaultgoal, zeroknowledge, combat, gold, stairs}version是要使用的提示模板版本,来自集合{default, reworded}。 以下是可用标注的摘要:

标注论文中的用例
llama70b_msg_defaultgoal_default主要实验
llama70b_msg_combat_default引导向_怪物杀手_行为
llama70b_msg_gold_default引导向_黄金收集者_行为
llama70b_msg_stairs_default引导向_下降者_行为
llama7b_msg_defaultgoal_default扩展实验
llama13b_msg_defaultgoal_default扩展实验
llama70b_msg_zeroknowledge_default零知识提示实验
llama70b_msg_defaultgoal_reworded提示重写实验

为创建标注,我们使用vLLMLlama 2的聊天版本。如果你想生成自己的标注或重现我们的标注过程,请确保按照官方说明下载模型(获取模型权重可能需要几天时间)。

标注脚本假设数据集将使用n-annotation-chunks参数分成不同的块进行标注。这允许根据可用资源进行并行处理,并且对重启/抢占具有鲁棒性。要使用单个块运行(即处理整个数据集),并使用默认提示模板和任务规范进行标注,请运行以下命令。

python -m scripts.annotate_pairs_dataset --directory motif_dataset \
                                 --prompt-version default --goal-key defaultgoal \
                                 --n-annotation-chunks 1 --chunk-number 0 \
                                 --llm-size 70 --num-gpus 8

请注意,默认行为是通过将标注附加到指定配置的文件来恢复标注过程,除非通过--ignore-existing标志另有说明。也可以使用--custom-annotator-string标志手动选择为标注创建的'.npy'文件的名称。可以使用单个32GB内存的GPU进行--llm-size 7--llm-size 13的标注。 你可以使用8个GPU的节点进行--llm-size 70的标注。这里我们提供使用NVIDIA V100s 32G GPU对100k对数据集进行标注的粗略时间估计,这应该能大致重现我们的大多数结果(我们的结果是使用500k对获得的)。

模型所需标注资源
Llama 2 7b约32 GPU小时
Llama 2 13b约40 GPU小时
Llama 2 70b约72 GPU小时

奖励训练

在第二阶段,我们通过交叉熵将大语言模型的偏好提炼为奖励函数。要使用默认超参数启动奖励训练,请使用以下命令。

python -m scripts.train_reward  --batch_size 1024 --num_workers 40  \
        --reward_lr 1e-5 --num_epochs 10 --seed 777 \
        --dataset_dir motif_dataset --annotator llama70b_msg_defaultgoal_default \
        --experiment standard_reward --train_dir train_dir/reward_saving_dir

奖励函数将通过位于--dataset_dir中的annotator的标注进行训练。然后,结果函数将保存在train_dir下的--experiment子文件夹中。

强化学习训练

最后,我们通过强化学习使用得到的奖励函数训练一个智能体。要在NetHackScore-v1任务上使用默认超参数训练智能体,结合内在和外在奖励进行实验,可以使用以下命令。

python -m scripts.main --algo APPO --env nle_fixed_eat_action --num_workers 24 \
        --num_envs_per_worker 20 --batch_size 4096 --reward_scale 0.1 --obs_scale 255.0 \
        --train_for_env_steps 2_000_000_000 --save_every_steps 10_000_000 \       
        --keep_checkpoints 5 --stats_avg 1000 --seed 777  --reward_dir train_dir/reward_saving_dir/standard_reward/ \
        --experiment standard_motif --train_dir train_dir/rl_saving_dir \
        --extrinsic_reward 0.1 --llm_reward 0.1 --reward_encoder nle_torchbeast_encoder \
        --root_env NetHackScore-v1 --beta_count_exponent 3 --eps_threshold_quantile 0.5

要更改任务,只需修改--root_env参数。下表明确列出了与论文中呈现的实验相匹配所需的值。NetHackScore-v1任务的extrinsic_reward值为0.1,而其他所有任务的值为10.0,以激励智能体达到目标。

环境root_env
得分NetHackScore-v1
楼梯NetHackStaircase-v1
楼梯(第3层)NetHackStaircaseLvl3-v1
楼梯(第4层)NetHackStaircaseLvl4-v1
神谕NetHackOracle-v1
神谕-清醒NetHackOracleSober-v1

此外,如果你只想使用来自大语言模型的内在奖励而不使用环境奖励来训练智能体,只需设置--extrinsic_reward 0.0。在仅使用内在奖励的实验中,我们只在智能体死亡时终止回合,而不是在智能体达到目标时终止。这些修改后的环境列在下表中。

环境root_env
楼梯(第3层)- 仅内在奖励NetHackStaircaseLvl3Continual-v1
楼梯(第4层)- 仅内在奖励NetHackStaircaseLvl4Continual-v1

可视化你的强化学习智能体

我们还提供了一个脚本来可视化你训练的强化学习智能体。这可以提供对其行为的重要洞察,同时还会生成每个回合的顶级消息,有助于理解它试图优化的目标。你只需运行以下命令即可。

python -m scripts.visualize --train_dir train_dir/rl_saving_dir --experiment standard_motif

引用

如果你在我们的工作基础上进行研究或发现它有用,请使用以下bibtex引用。

@article{klissarovdoro2023motif,
    title={Motif: Intrinsic Motivation From Artificial Intelligence Feedback},
    author={Klissarov, Martin and D'Oro, Pierluca and Sodhani, Shagun and Raileanu, Roberta and Bacon, Pierre-Luc and Vincent, Pascal and Zhang, Amy and Henaff, Mikael},
    year={2023},
    month={9},
    journal={arXiv preprint arXiv:2310.00166}
}

许可证

Motif的大部分内容采用CC-BY-NC许可,但项目的某些部分采用单独的许可条款:sample-factory采用MIT许可。

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号