提示微调
这是用于重现 EMNLP 2021 论文《The Power of Scale for Parameter-Efficient Prompt Tuning》(Lester et al., 2021)中的实验代码。
这些模型建立在T5X之上,T5X定义了模型和训练循环;Flaxformer定义了实际的模型计算;Flax定义了低级模型层;而Jax提供了实际的执行。我们实现的详细信息可以在这里找到。
目录
安装
- 按照 T5X 安装说明中的前三步创建一个云 TPU 虚拟机(VM)。还要按照步骤五创建一个 Google Cloud Storage(GCS)桶。我们将使用格式为
gs://{bucket-name}/path/to/item/in/bucket
的 URI 读写该桶中的数据。我们将在此存储缓存的数据集以及模型检查点和结果。为了便于参考,以下是一些与 TPU 虚拟机交互最常用的云命令:
# 创建一个云 TPU 虚拟机
$ gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \
--zone ${ZONE} \
--accelerator-type v3-8 \
--version v2-alpha
# SSH 登录到云 TPU 虚拟机
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE}
# 删除云 TPU 虚拟机
$ gcloud alpha compute tpus tpu-vm delete ${TPU_NAME} --zone ${ZONE}
- 现在应该在 TPU 虚拟机实例的命令行中。克隆提示微调库。
git clone --branch=main https://github.com/google-research/prompt-tuning
cd prompt-tuning
- 安装提示微调库。
python3 -m pip install .[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
如果遇到 pip 尝试安装更早的依赖版本(例如 TensorFlow),直到尝试安装版本 0.0.0
而失败的错误,请尝试将 --use-deprecated=legacy-resolver
添加到安装命令中。此错误与依赖项之间的所需版本有关,这种行为通常被称为回溯。如果使用该标志,可能会安装不兼容的库版本,您应注意安装命令输出中关于不匹配的警告。
注意: 如果计划修改提示微调内部并需要可编辑的安装(因此在运行训练时使用已克隆代码中的更改),请使用 -e
标志运行 pip
,如果在安装过程中收到错误,可能需要删除 pyproject.toml
文件。
要运行测试,请使用 [test]
选项安装包(python3 -m pip install .[test] ...
),然后从克隆库的根目录运行 python3 -m pytest
。
训练提示
训练提示类似于 使用 T5X 微调模型; 主要区别在于我们有一套自己的提示微调配置文件可供使用。
我们提供了一个演示脚本 (prompt_tuning/scripts/sst2-demo.sh
),其中包含训练提示所需的所有部分。您可以将其作为起点,或设置指向您的 Google Cloud Storage 桶的路径来直接运行此脚本。
./prompt-tuning/prompt_tuning/scripts/sst2-demo.sh
为了帮助加快迭代速度,我们倾向于在命令行中指定更多的选项,而不是将所有配置捆绑到一个 gin 文件中。以下是一些值得注意的选项:
--gin_search_paths
:一个以逗号分隔的目录列表,用作 gin 文件的路径前缀。我们可以使用prompt_tuning.scripts.find_module ${module}
来找到包含配置的库的安装位置。--gin_file
:要加载的 gin 文件。我们倾向于使用以安装库开始的相对路径,例如prompt_tuning/configs/models/t5_1_1_base_prompt.gin
而不是models/t5_1_1_base_prompt.gin
,以避免混淆。多次使用该标志可以指定将要合并在一起的多个 gin 文件。任何在多个文件中设置的配置选项将使用列表中最后一个文件中的值。--gin.{PARAM}={VALUE}
:此通用覆盖标志将PARAM
设置为VALUE
。这可以方便地设置配置选项,而不需要它们实际作为命令行参数。例如,--gin.utils.SaveCheckpointConfig.keep=20
将保存最后 20 个检查点。
在 Pod 切片上训练提示
随着模型规模的增大,例如 xl 和 xxl,它们无法装载在单个 TPU 虚拟机的 8 个 TPU 上。在这种情况下,我们需要一个 TPU pod 的切片(有关 TPU 架构和可用配置的详细信息,请参阅此处)。在单个 TPU 虚拟机上训练提示和在 Pod 切片上训练提示之间的主要区别在于我们现在有多个 TPU 虚拟机,并且将在各 VM 上运行相同的 SPMD JAX 程序。本页面有更多关于multi-host JAX 程序的资料。本指南 介绍了在 TPU Pod 切片上运行 JAX 程序的简要说明,但我们将在这里概述主要点。
- 创建一个 TPU Pod 切片。 此页面 列出了哪些加速器类型在哪些区域可用。这与创建 TPU 虚拟机相同,只是我们请求的是 32 个 TPU 而不是 8 个。
$ gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \
--zone ${ZONE} \
--accelerator-type v3-32 \
--version v2-alpha
- 安装提示微调库。因为我们现在有 4 个 TPU 虚拟机,且每个虚拟机有 8 个 TPU,我们希望避免直接 SSH 到每个 VM,因为我们需要对每个主机这样做。因此,Google Cloud SSH 命令允许我们使用
--command=
标志指定要运行的命令,并用--worker=all
在所有 VMs(称为工人)上运行该命令。
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--zone ${ZONE} \
--worker=all \
--command="git clone --branch=main https://github.com/google-research/prompt-tuning && cd prompt-tuning && "
python3 -m pip install . -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
-
编写训练提示的脚本。我们提供了一个示例脚本(
/prompt_tuning/scripts/sst2-xxl-demo.sh
),该脚本训练一个提示来解决 SST2 数据集,使用 T5 1.1 lm100k XXL。这可以作为您的起点,或只需填写路径到您的 Google Cloud Storage 桶,以指定保存结果的位置(MODEL_DIR
)和缓存 TFDS 数据的位置(TFDS_DATA_DIR
),或将这些路径设置为环境变量。 -
将您的训练脚本复制到每个 worker。如果这是您第一次运行
scp
,可能会遇到错误,请运行错误消息中的ssh-add /.../.ssh/google_compute_engine
命令并再试一次。
$ gcloud alpha compute tpus tpu-vm scp sst2-xxl-demo.sh ${TPU_NAME}: \
--zone=${ZONE}
--worker=all
- 执行您的训练脚本。
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--zone ${ZONE} \
--worker=all \
--command="./sst2-xxl-demo.sh"
如果某个 worker 在训练过程中出错,其他 worker 上将有正在使用 TPU 的进程。这将阻止您重新启动作业,直到这些进程终止并释放 TPU。以下命令应结束所有这些进程。您可能会看到来自最初出错 worker 的 kill
命令手册页。
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--zone ${ZONE} \
--worker=all \
--command="sudo lsof -t /dev/accel0 | xargs kill -9"
自定义依赖项
要使用自定义组件(如您自己的数据集)训练提示,请遵循 T5X 自定义组件说明
如果将代码打包为一个可通过 pip 安装的 python 包,您不会受限于单个目录,可以使用 python3 -m prompt_tuning.scripts.find_module {your_module}
来帮助设置
gin_search_paths
,以便找到库中捆绑的 gin 配置。
注意: 如果计划将 gin 配置捆绑到可安装的包中,请确保包含配置文件的目录具有 __init__.py
,因为 gin 需要文件在 python 包中。
如果您的自定义组件部分是 gin 可配置的,它们需要在您的 gin 文件中显式导入;如果它们在解析 gin 文件后被导入,将导致错误。如果您的依赖项中没有 gin 可配置项,可以通过传递 --gin.MIXTURE_OR_TASK_MODULE="'path.to.your.module'
来避免编写 gin 文件。这样将自动导入您的模块,适用于仅交换数据集的情况。
使用提示进行推理
我们建议使用提示进行推理的方法是加载用来初始化模型的初始检查点,并从文件中加载提示。正如这一节关于部分加载中所解释的那样,T5X 支持在初始化一些模型参数时加载其他参数。我们将其与 from_array
提示初始化器结合使用,从原始检查点重新加载冻结参数并从一个文件中加载提示。
configs/runs/prompt_eval.gin
设置了此配置,您只需提供一个 PROMPT_FILE
。如果您的模型是使用任何 prompts/
配置文件训练的,可以将它们从评估脚本的参数中删除。
包含的 sst2-demo-eval.sh
脚本显示了以这种方式进行评估的示例。只需设置 EVAL_DIR
和 TFDS_DATA_DIR
环境变量以存储评估输出和 tensorflow 数据集缓存的路径即可。
在 T5X 中,评估脚本假定您的数据集有标签,并输出数据集的度量函数的最终结果。推理脚本不需要标签,相反输出模型的预测。我们提供了类似的 prompt_infer.gin
文件,以便与推理脚本一起使用。
如果要使用提示微调训练运行生成的 t5x 检查点进行推理或评估,可以直接使用 T5X 的 (eval|infer).gin
配置。需要更新 utils.RestoreChekcpointConfig
。应将 path
设置为新检查点,assignment_map=()
和 fallback_to_scratch=False
。
模型配置
所有模型、训练、评估、保存、恢复等配置均通过 gin 完成。请参阅 gin-config 仓库 了解 gin 的一般介绍和 该入门手册
我们遵循 T5X 的配置布局:
runs/
:包含模型实际训练的配置。此处配置数据集和评估等内容。architectures/
:包含模型工作方式的配置。此处配置编码器-解码器与仅解码器、嵌入共享等。models/
:包含设置特定模型参数的配置,如层数或嵌入表的大小。还配置了 T5X 模型包装器。models/decoding/
:包含用于在推理期间更改模型生成文本方式的易用配置,包括光束搜索和核采样的配置。models/sizes/
:包含创建不同大小模型的各种设置,这些设置与默认版本结合以创建指定大小版本,例如,t5_1_1_prompt.gin
和sizes/large.gin
结合创建 T5 1.1 Large 模型。一些常见的组合已经作为 gin 文件提供了正确的包含(例如上述的t5_1_1_large_prompt.gin
)。注意: 这些大小文件需要在主模型文件之后。
prompts/
:我们的额外目录包含设置PROMPT
gin 变量的配置文件,允许根据添加的提示文件通过--gin_file
参数轻松切换提示初始化(需要在models/
gin 文件之后)。
gin 配置文件顺序
在命令行中指定 --gin_file
参数时,顺序很重要。gin 文件的指定顺序大致为:
models/*.gin
prompts/*.gin
models/sizes/*.gin
models/decoding/*.gin
runs/*.gin
必填字段
T5X 有一些必填字段,如 MIXTURE_OR_TASK_NAME
或 TASK_FEATURE_LENGTHS
。我们增加了两个:
PROMPT_LENGTH
:我们使用的提示的长度,这在几个不同的地方使用,因此我们要求它作为 gin 宏,以便在多个地方引用并确保值一致。PROMPT
:这是用于 FlaxformerPromptX
子类中的实际提示模块配置。
注意: 提示微调当前不支持打包样例。这意味着我们的最大目标长度只需足够容纳每个样例的目标。因此,TASK_FEATURE_LENGTHS
映射中的 targets
键可以短得多,例如对于许多 SuperGLUE (Wang et al., 2019) 任务来说,大约为 4,而 P5X 的默认值是 62。
提示初始化
提示参数的初始化有几种选项。我们支持论文第 3.2 节中的各种方法,以及从文件初始化。这允许根据从 MNLI 学习到的提示来训练 BoolQ。
所有初始化函数遵循 Flax 初始化器 API,即作为初始化函数闭包的参数化函数。实际初始化函数总是具有如下签名:
def initializer(rng: Array, shape: Sequence[int]) -> Array:
...
我们在 configs/prompts
目录中提供了每种初始化方案的 gin 配置文件。可以通过 --gin_file=path/to/configs/prompts/scheme.gin
包含 gin 文件来使用这些文件。该文件需要在主模型文件之后,否则默认方法(随机均匀)将覆盖所选择的方法。一些初始化方法需要通过覆盖标志或在 gin 文件中设置额外的 gin 值。
随机均匀
一种标准的随机初始化,类似于为嵌入初始化所用的方法。这是默认方法,不需要 gin 文件。可以通过覆盖 prompt_init/linen.initializers.uniform.scale=N
参数来调整随机值的比例。
采样词汇表
使用 from_sample_of_embeddings
初始化器,针对每个提示位置采样一个令牌嵌入用于初始化。可以使用 prompt_init/prompts.from_samples_of_embeddings.population_size
参数限制采样前 n
个嵌入。
这可通过
--gin_file=prompt_tuning/configs/prompts/from_sampled_vocab.gin
使用。此方法使用从初始模型检查点提取的嵌入表。您还可以提供自己的嵌入文件,以
--gin_file=prompt_tuning/configs/prompts/from_sampled_vocab_numpy.gin
使用。此方法要求您提供一个 EMBEDDING_FILE
值,该值是模型嵌入表的 numpy 数组。可以使用
[prompt_tuning.scripts.extract_variable](https://github.com/google-research/prompt-tuning/tree
我们支持通过 from_embedded_list
初始化将提示时间步数初始化为类标签(即 verbalizers)的嵌入。用户提供需要使用的单词列表(类标签)。每个词都由提供的词汇表进行分词;由提供的词汇表表嵌入;如果需要,跨子分词聚合;并用于初始化提示时间步长。如果提供的标记未覆盖完整的提示长度,则使用提供的备用初始化器初始化缺失的标记。
我们可以匹配论文中的内容,通过将这种初始化与上面提到的初始化组合,未填充的提示标记通过从嵌入表中采样来填充。它可以与
--gin_file=prompt_tuning/configs/prompts/from_class_labels.gin
一起使用。这需要设置 CLASS_LABELS
,这是您希望嵌入作为提示初始化的一组单词。您还可以通过
--gin_file=prompt_tuning/configs/prompts/from_class_labels_numpy.gin
提供自己的嵌入文件(与上述相同)。这还需要设置 EMBEDDING_FILE
。
从字符串
我们也支持使用某些字符串的嵌入初始化提示,通常用于从离散提示或任务描述开始。这使用 from_embedded_string
初始化器。字符串由提供的词汇表进行分词,每个标记在提供的嵌入表中查找,并使用字符串的结果嵌入表示进行提示初始化。如果提供的标记未覆盖完整的提示长度,则使用提供的备用初始化器初始化缺失的标记。
注意: 词汇表只是将字符串转换为 id 序列,您需要确保字符串与您的 SeqIO 任务的任何文本格式化结果(标点符号周围的空格等)匹配。
从文件
您还可以使用 from_array
初始化器从文件加载提示,以实现任务间的转移。这可以通过
--gin_file=prompt_tuning/configs/prompts/from_file.gin
完成。这需要设置 PROMPT_FILE
,路径指向包含要加载的提示的 Numpy 文件。训练时默认会发出 Numpy 版本的提示,但也可以通过上述提到的脚本提取提示。
发布的模型检查点
我们已经发布了 T5 1.1 检查点的 T5X 本地检查点,这些检查点已进行了 100K 步的语言模型适应。
- t5_1_1_lm100k_small(约 7700 万参数): gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_small/checkpoint_1100000
- t5_1_1_lm100k_base(约 2.5 亿参数): gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_base/checkpoint_1100000
- t5_1_1_lm100k_large(约 8 亿参数): gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_large/checkpoint_1100000
- t5_1_1_lm100k_xl(约 30 亿参数): gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xl/checkpoint_1100000
- t5_1_1_lm100k_xxl(约 110 亿参数): gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xxl/checkpoint_1100000
这些是从公开的 Mesh TensorFlow 检查点转换而来的。
发布的提示
我们已经在各种任务上发布了预训练的提示,并计划随着时间的推移增加这些提示。
提示可以在
pretrained_prompts
目录中找到。从那里,每个子目录按训练的模型对提示进行分组。引用与库捆绑在一起的这些提示的最简单方法是:
--PROMPT_FILE=`python3 -m prompt_tuning.scripts.find_module prompt_tuning`/pretrained_prompts/{MODEL_SIZE}/{PROMPT}.npy
由于并行计算的固有随机性,有一些设置需要在训练和评估之间匹配以获得完全相同的结果。每个模型子目录中有一个 README.md
指定了这些设置应该是什么。最重要的设置匹配项是批量大小、TPU 拓扑和模型并行分区。表格包含了如果你在 t5x.eval
中使用这些提示时应该预期看到的分数。
额外资源
这是关于提示调整的其他资源的集合。
如何引用
如果你使用这项工作作为起点,请引用
@inproceedings{lester-etal-2021-power,
title = "The Power of Scale for Parameter-Efficient Prompt Tuning",
author = "Lester, Brian and
Al-Rfou, Rami and
Constant, Noah",
booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
month = nov,
year = "2021",
address = "Online and Punta Cana, Dominican Republic",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2021.emnlp-main.243",
doi = "10.18653/v1/2021.emnlp-main.243",
pages = "3045--3059",
}
注意
这不是官方支持的 Google 产品。