Landmark 注意力机制
本仓库包含了我们论文中所描述的 landmark 注意力机制的实现:
Landmark 注意力:Transformer 的随机访问无限上下文长度
Amirkeivan Mohtashami, Martin Jaggi
NeurIPS 2023: https://arxiv.org/abs/2305.16300
仓库结构
该仓库包含三个代码库,位于以下目录:
lm_benchmark
: 该目录包含用于在 PG19 和 arXiv Math 数据集上进行语言建模的代码。llama_legacy
: 该目录包含用于获得论文中报告的 LLaMA 微调结果的代码。此目录中的代码已冻结以允许复现结果。因此,除非试图精确复制我们的结果,否则我们建议使用llama
目录下的代码。llama
: 该目录包含 landmark 注意力机制的当前实现。该目录包括 landmark 注意力的高级实现和与 Flash Attention 结合的 Triton 实现。作为示例,该目录包含将实现应用于 LLaMA 模型的代码。
注意:在项目开发过程中,我们决定更新某些组件的名称。然而,由于这个决定是在项目后期做出的,你可能会在代码中遇到旧名称的引用(例如 mem
而不是 landmark
)。我们正在努力解决这个问题。
语言建模基准
训练
对于训练,landmark 标记在数据准备期间添加。以下命令是在 PG19 上训练模型的示例,每 50 个标记添加一个 landmark 标记:
python main.py \
--config_format rotary \
--model landmark \
--n_embd 1024 \
--n_head 8 \
--n_layer 12 \
--batch_size 16 \
--sequence_length 512 \
--acc_steps 8 \
--wandb_project memory-llm \
--dataset pg19 \
--iterations 240000 \
--dropout 0.0 \
--positional_encoder rotary \
--softmax_func mem_opt \
--mem_freq 50 \
--wandb \
--save_checkpoint_freq 20000
要在多 GPU 上运行,请使用 torchrun(例如 torchrun --nproc_per_node=4
)并将 --distributed_backend nccl
传递给 main.py
脚本。我们建议首先在单个 GPU 上运行脚本直到训练开始,然后再切换到多 GPU 设置。这是因为第一个节点将必须执行数据初始化,这可能需要很长时间,导致多 GPU 设置中的同步超时。然而,一旦执行了初始化,结果将存储在磁盘上,因此下次运行将会很快。
在运行训练脚本之前,你需要初始化数据集。有关说明,请使用位于 data/
内相应数据集文件夹中的 prepare.py
脚本。
推理
该代码支持各种设置下的推理。要执行标准评估,请禁用缓存并使用与评估长度(由 --eval_seq_length
指定)相同的块大小(使用 --mid_length
标志指定)。使用 mem_cache
时可以使用 landmark。脚本 eval_cmd_generator.py
可用于生成包含执行对应于论文表 1 和表 2 的评估的命令的 bash 脚本。需要在脚本内更新输出模型的路径。
LLaMA 微调
用于微调 LLaMA 和测试最终模型的代码作为独立项目在子目录"llama"中提供。运行微调的示例(从子目录内)如下:
torchrun --nproc_per_node=8 train.py \
--model_name_or_path /llama_weights/7B_hf/ \
--bf16 True \
--output_dir /llama-redpajama-mem-15000-with-mem/ \
--cache_dir /hf-cache/ \
--num_train_epochs 1 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 2000 \
--save_total_limit 2 \
--learning_rate 2e-5 \
--weight_decay 0.1 \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
--tf32 True \
--max_steps 15000
在上面的示例中,LLaMA 权重(转换为 huggingface 格式)应该在 /llama_weights/7B_hf/
中。
微调权重
我们已经发布了原始 LLaMA 7B 和在 RedPajama 数据集上使用 landmark 注意力微调 15000 步的相同模型之间的权重差异 在这里。你可以使用 weight_diff.py
脚本恢复权重:
python weight_diff.py recover --path_raw <path_to_original_llama7b_weights> --path_diff <path_to_weight_diff> --path_tuned <path_to_store_recovered_weights>
有关如何使用 landmark 进行推理的示例,请查看 run_test.py
。
Triton 实现
我们添加了我们的方法和 Flash Attention 组合的 Triton 实现,这显著降低了内存使用并提高了性能。使用这个实现,我们训练了上下文长度为 2048 的 LLaMA 7B(而不是 512)。此外,通过应用以下更改,可以将 landmark 注意力添加到任何模型中:
- 以块大小的规则间隔将 landmark 标记添加到输入中。
- (可选)创建一个布尔掩码,指示哪些标记是 landmark。可以将掩码传递给 landmark 注意力函数,以确保正确放置 landmark。为获得最高速度,可以跳过此步骤。
- 用
fused_landmark_attention
替换torch.nn.functional.scaled_dot_product_attention
。
请注意,该实现依赖于最新版本的 Triton,这与最新版本的 PyTorch 存在冲突。因此,提供了一个特殊的 install_deps.sh
脚本来安装依赖项。
最后,请注意当前实现做出以下假设:
- 该实现假设 landmark 块与 Flash Attention 中用于计算注意力的块具有相同的大小。这限制了块的最大大小,因为整个 landmark 块应该适合 GPU 的本地内存。然而,使用 bfloat16 应该可以使用大小为 64 或 128 的块,这对于 landmark 块应该足够了。
- 该实现假设键和查询数量的差异是块大小的倍数。因此,在自回归生成部分中,当标记一个接一个生成时,必须应用正常注意力。在达到生成之前,该实现仍然可以用于遍历输入。 请注意,这不是一个很大的限制,因为在一次生成一个标记时,注意力矩阵只有一行,限制了 Flash Attention 的好处。
- 虽然高级实现允许将 landmark 标记放置在任何位置,但融合实现假设 landmark 标记定期放置在每个块的末尾。由于我们在推理时总是使用这种模式,这应该不会被注意到。