Toolformer - Pytorch (wip)
实现 Toolformer,即MetaAI提出的能使用工具的语言模型。
感谢
-
感谢 Stability.ai 的慷慨赞助,使我们得以从事和开源尖端的人工智能研究。
-
感谢 Enrico 的初始代码提交,促成了不同工具的启动!
-
感谢ChatGPT为这个仓库中的正则表达式解析函数和API调用参数做出的贡献。我对正则表达式一窍不通,AI的帮助实在是太大了(没有任何问题,表现完美)。
安装
$ pip install toolformer-pytorch
用法
下面是一个让语言模型具备当前日期和时间意识的示例用法。
import torch
from toolformer_pytorch import Toolformer, PaLM
# 简单的日历API调用 - 返回字符串的函数
def Calendar():
import datetime
from calendar import day_name, month_name
now = datetime.datetime.now()
return f'今天是{day_name[now.weekday()]},{month_name[now.month]}{now.day}日,{now.year}年。'
# 提示如何使用上述的Calendar函数
prompt = f"""
你的任务是向文本中添加对Calendar API的调用。
API调用应帮助你获取完成文本所需的信息。
你可以通过书写 “[Calendar()]” 来调用API。
以下是一些API调用示例:
输入: 今天是一年中的第一个星期五。
输出: 今天是一年中的第一个[Calendar()]星期五。
输入: 美国总统是乔·拜登。
输出: 美国总统是[Calendar()]乔·拜登。
输入: [输入]
输出:
"""
data = [
"商店在周末从不营业,所以今天关门。",
"距离圣诞节还有30天",
"今天是星期三。"
]
# 模型 - 这里使用PaLM,但任何返回形状为 (batch, seq, num_tokens) 的 logits 的nn.Module都可以
model = PaLM(
dim = 512,
depth = 2,
heads = 8,
dim_head = 64
).cuda()
# toolformer
toolformer = Toolformer(
model = model,
model_seq_len = 256,
teach_tool_prompt = prompt,
tool_id = 'Calendar',
tool = Calendar,
finetune = True
)
# 调用这个函数会
# (1) 使用你的输入(数据)提示模型,并插入到 [输入] 标记中
# (2) 从采样的输出中筛选出正确调用API的
# (3) 使用提供的 `tool` 执行API调用
# (4) 使用专业化的过滤函数过滤掉结果(可以独立使用如下一节所示)
# (5) 在过滤结果上进行微调
filtered_stats = toolformer(data)
# 然后,一旦你看到“微调完成”消息
response = toolformer.sample_model_with_api_calls("距离下一次新年还有多少天?")
# 希望你能看到它调用日历并利用API调用的响应结果……
这篇文章的主要创新点是在定义出适合用于插入API调用的输出的适配评分。这个评分被用来过滤采样的输出,以便微调变压器模型,使其进行减少随后文本困惑度的API调用。
import torch
from toolformer_pytorch import (
Toolformer,
PaLM,
filter_tokens_with_api_response
)
# 模型
palm = PaLM(
dim = 512,
num_tokens = 20000,
depth = 2,
heads = 8,
dim_head = 64
).cuda()
# 模拟一些tokens
mock_start_pos = 512
mock_api_call_length = 10
mock_api_start_id = 19998
mock_api_stop_id = 19999
tokens = torch.randint(0, 20000, (10, 1024)).cuda()
tokens_with_api_response = torch.randint(0, 20000, (10, 1024)).cuda()
tokens_without_api_response = torch.randint(0, 20000, (10, 1024)).cuda()
tokens_with_api_response[:, mock_start_pos] = mock_api_start_id
tokens_with_api_response[:, mock_start_pos + mock_api_call_length] = mock_api_stop_id
tokens_without_api_response[:, mock_start_pos] = mock_api_start_id
tokens_without_api_response[:, mock_start_pos + mock_api_call_length] = mock_api_stop_id
# 过滤
filtered_results = filter_tokens_with_api_response(
model = palm,
tokens = tokens,
tokens_with_api_response = tokens_with_api_response,
tokens_without_api_response = tokens_without_api_response,
filter_threshold = 1.,
api_start_token_id = mock_api_start_id,
api_end_token_id = mock_api_stop_id
)
要在语言模型生成的字符串上调用工具,请使用 invoke_tools
方法
from toolformer_pytorch import invoke_tools
def inc(i):
return i + 1
def dec(i):
return i - 1
function_registry = dict(
inc = inc,
dec = dec
)
text = '进行以下API调用:[inc(1)] 和 [dec(2)] 和 [ignored(3)]'
invoke_tools(function_registry, text)
# 进行以下API调用:[inc(1) → 2] 和 [dec(2) → 1] 和 [ignored(3)]
待办事项
- 创建定制的PaLM生成函数,可以执行外部API调用
- 允许在不同的光标索引处生成tokens
- 需要自定义API标记(在论文中是左右括号)
- 允许定制如何处理函数名、参数或执行和输出中的错误
- Toolformer最终应计算所有统计数据(如适当地采样数量、不同标准过滤的数量、评分分布以及被拒绝的数量)在最终微调之前
- 在
Toolformer
中进行端到端训练- 进行提示和启动数据
- 预过滤启动数据,然后进行API调用,再进行一轮过滤
- 跟踪所有统计数据
- 处理微调
- 数据集的交错+优化器超参数
- 挂载gpt-j
- 测试一个简单的计算器评估数据集
- 在Toolformer中添加默认回调,自动对齐文本并在过滤步骤前检查有效性 - 如果文本未正确复制,过滤步骤无效
- 确保训练了许多
Toolformer
实例的最终模型能够使用多种工具 - 从批量大小1开始逐步增加
引用
@inproceedings{Schick2023ToolformerLM,
title = {Toolformer: Language Models Can Teach Themselves to Use Tools},
author = {Timo Schick and Jane Dwivedi-Yu and Roberto Dessi and Roberta Raileanu and Maria Lomeli and Luke Zettlemoyer and Nicola Cancedda and Thomas Scialom},
year = {2023}
}
@article{Gao2022PALPL,
title = {PAL: Program-aided Language Models},
author = {Luyu Gao and Aman Madaan and Shuyan Zhou and Uri Alon and Pengfei Liu and Yiming Yang and Jamie Callan and Graham Neubig},
journal = {ArXiv},
year = {2022},
volume = {abs/2211.10435}
}
现实是这样一种东西,即使你停止相信它,它也不会消失。 – Philip K. Dick