Project Icon

flash-attention

高效注意力机制加速深度学习模型训练

FlashAttention是一种高效的注意力机制实现,通过IO感知算法和内存优化提升计算速度并降低内存消耗。它支持NVIDIA和AMD GPU,适用于多种深度学习框架。最新的FlashAttention-3版本针对H100 GPU进行了优化。该项目提供Python接口,可集成到现有模型中,有助于加速大规模深度学习模型的训练过程。

FlashAttention

本仓库提供了以下论文中 FlashAttention 和 FlashAttention-2 的官方实现。

FlashAttention: 具有 IO 感知的快速且内存高效的精确注意力机制 Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré 论文:https://arxiv.org/abs/2205.14135 IEEE Spectrum 关于我们使用 FlashAttention 提交 MLPerf 2.0 基准测试的文章FlashAttention

FlashAttention-2: 更快的注意力机制,具有更好的并行性和工作分区 Tri Dao 论文:https://tridao.me/publications/flash2/flash2.pdf FlashAttention-2

使用情况

我们很高兴看到 FlashAttention 在发布后短时间内被广泛采用。这个页面包含了 FlashAttention 被使用的部分列表。

FlashAttention 和 FlashAttention-2 可以免费使用和修改(参见 LICENSE)。如果您使用它,请引用并致谢 FlashAttention。

FlashAttention-3 beta 版本发布

FlashAttention-3 针对 Hopper GPU(如 H100)进行了优化。

博客文章:https://tridao.me/blog/2024/flash3/

论文:https://tridao.me/publications/flash3/flash3.pdf

FlashAttention-3 在 H100 80GB SXM5 上使用 FP16 的速度提升

这是一个 beta 版本,用于在我们将其与仓库其他部分集成之前进行测试/基准测试。

目前发布的内容:

  • FP16 前向和后向传播

即将在未来几天/下周发布:

  • BF16
  • 可变长度(FP16,BF16)
  • FP8 前向传播

要求:H100 / H800 GPU,CUDA >= 12.3。

安装方法:

cd hopper
python setup.py install

运行测试:

export PYTHONPATH=$PWD
pytest -q -s test_flash_attn.py

安装和功能

要求:

  • CUDA 工具包或 ROCm 工具包
  • PyTorch 1.12 及以上版本
  • packaging Python 包(pip install packaging
  • ninja Python 包(pip install ninja)*
  • Linux。从 v2.3.2 开始可能适用于 Windows(我们看到了一些正面报告),但 Windows 编译仍需更多测试。如果您有关于如何为 Windows 设置预构建 CUDA wheels 的想法,请通过 GitHub issue 联系我们。

* 确保正确安装并运行 ninja(例如,ninja --version 然后 echo $? 应返回退出码 0)。如果不是(有时 ninja --version 然后 echo $? 返回非零退出码),请卸载然后重新安装 ninjapip uninstall -y ninja && pip install ninja)。没有 ninja,编译可能需要很长时间(2小时),因为它不使用多个 CPU 核心。使用 ninja 在 64 核机器上使用 CUDA 工具包编译需要 3-5 分钟。

安装方法:

pip install flash-attn --no-build-isolation

或者您可以从源代码编译:

python setup.py install

如果您的机器 RAM 少于 96GB 且有很多 CPU 核心,ninja 可能会运行太多并行编译作业,可能耗尽 RAM。要限制并行编译作业的数量,您可以设置环境变量 MAX_JOBS

MAX_JOBS=4 pip install flash-attn --no-build-isolation

接口: src/flash_attention_interface.py

NVIDIA CUDA 支持

要求:

  • CUDA 11.6 及以上版本

我们推荐使用 Nvidia 的 Pytorch 容器,它包含安装 FlashAttention 所需的所有工具。

FlashAttention-2 与 CUDA 目前支持:

  1. Ampere、Ada 或 Hopper GPU(例如 A100、RTX 3090、RTX 4090、H100)。Turing GPU(T4、RTX 2080)的支持即将推出,目前请对 Turing GPU 使用 FlashAttention 1.x。
  2. fp16 和 bf16 数据类型(bf16 需要 Ampere、Ada 或 Hopper GPU)。
  3. 所有头维度最高到 256。头维度 > 192 的反向传播需要 A100/A800 或 H100/H800。从 flash-attn 2.5.5 开始,头维度 256 的反向传播现在可以在消费级 GPU 上工作(如果没有 dropout)。

AMD ROCm 支持

ROCm 版本使用 composable_kernel 作为后端。它提供了 FlashAttention-2 的实现。

要求:

  • ROCm 6.0 及以上版本

我们推荐使用 ROCm 的 Pytorch 容器,它包含安装 FlashAttention 所需的所有工具。

FlashAttention-2 与 ROCm 目前支持:

  1. MI200 或 MI300 GPU。
  2. fp16 和 bf16 数据类型
  3. 前向传播的头维度最高到 256。反向传播的头维度最高到 128。

如何使用 FlashAttention

主要函数实现了缩放点积注意力(softmax(Q @ K^T * softmax_scale) @ V):

from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
                          window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""在评估时 dropout_p 应设置为 0.0
如果 Q、K、V 已经堆叠成一个张量,这个函数比在 Q、K、V 上调用 flash_attn_func 更快,
因为反向传播避免了 Q、K、V 梯度的显式连接。
如果 window_size != (-1, -1),实现滑动窗口局部注意力。位置 i 的查询
只会关注 [i - window_size[0], i + window_size[1]] 范围内的键。
参数:
    qkv: (batch_size, seqlen, 3, nheads, headdim)
    dropout_p: float。dropout 概率。
    softmax_scale: float。应用 softmax 之前 QK^T 的缩放。
        默认为 1 / sqrt(headdim)。
    causal: bool。是否应用因果注意力掩码(例如,用于自回归建模)。
    window_size: (left, right)。如果不是 (-1, -1),实现滑动窗口局部注意力。
    alibi_slopes: (nheads,) 或 (batch_size, nheads),fp32。
        在查询 i 和键 j 的注意力分数上添加 (-alibi_slope * |i - j|) 的偏置。
    deterministic: bool。是否使用确定性实现的反向传播,
        稍微慢一些并使用更多内存。前向传播总是确定性的。
返回:
    out: (batch_size, seqlen, nheads, headdim)。
"""
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
                window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""在评估时 dropout_p 应设置为 0.0
通过传入头数少于 Q 的 KV 支持多查询和分组查询注意力(MQA/GQA)。
注意,Q 中的头数必须能被 KV 中的头数整除。
例如,如果 Q 有 6 个头,K、V 有 2 个头,Q 的头 0、1、2 将关注 K、V 的头 0,
Q 的头 3、4、5 将关注 K、V 的头 1。
如果 window_size != (-1, -1),实现滑动窗口局部注意力。位置 i 的查询
只会关注 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 范围内的键。

参数:
    q: (batch_size, seqlen, nheads, headdim)
    k: (batch_size, seqlen, nheads_k, headdim)
    v: (batch_size, seqlen, nheads_k, headdim)
    dropout_p: float。dropout 概率。
    softmax_scale: float。应用 softmax 之前 QK^T 的缩放。
        默认为 1 / sqrt(headdim)。
    causal: bool。是否应用因果注意力掩码(例如,用于自回归建模)。
    window_size: (left, right)。如果不是 (-1, -1),实现滑动窗口局部注意力。
    alibi_slopes: (nheads,) 或 (batch_size, nheads),fp32。
        在查询 i 和键 j 的注意力分数上添加 (-alibi_slope * |i + seqlen_k - seqlen_q - j|) 的偏置。
    deterministic: bool。是否使用确定性实现的反向传播,
        稍微慢一些并使用更多内存。前向传播总是确定性的。
返回:
    out: (batch_size, seqlen, nheads, headdim)。
"""
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 表示无限上下文窗口
    rotary_interleaved=True,
    alibi_slopes=None,
):
    """
    如果 k 和 v 不为 None,k_cache 和 v_cache 将会被 k 和 v 中的新值*原地*更新。这对增量解码很有用:
    你可以传入上一步缓存的键/值,用当前步骤的新键/值更新它们,并在一个内核中对更新后的缓存进行注意力计算。

    如果你传入 k / v,你必须确保缓存足够大以容纳新值。例如,KV 缓存可以预先分配最大序列长度,
    你可以使用 cache_seqlens 来跟踪批次中每个序列的当前序列长度。

    如果传入了 rotary_cos 和 rotary_sin,还会应用旋转嵌入。键 @k 将在索引 cache_seqlens、cache_seqlens + 1 等处被 rotary_cos 和 rotary_sin 旋转。
    如果是因果的或局部的(即 window_size != (-1, -1)),查询 @q 将在索引 cache_seqlens、cache_seqlens + 1 等处被 rotary_cos 和 rotary_sin 旋转。
    如果既不是因果的也不是局部的,查询 @q 将仅在索引 cache_seqlens 处被 rotary_cos 和 rotary_sin 旋转(即我们认为 @q 中的所有 token 都位于 cache_seqlens 位置)。

    参见 tests/test_flash_attn.py::test_flash_attn_kvcache 了解如何使用此函数的示例。

    通过传入比 Q 更少头数的 KV 来支持多查询和分组查询注意力(MQA/GQA)。注意,Q 中的头数必须能被 KV 中的头数整除。
    例如,如果 Q 有 6 个头而 K、V 有 2 个头,Q 的头 0、1、2 将注意 K、V 的头 0,Q 的头 3、4、5 将注意 K、V 的头 1。

    如果 causal=True,因果掩码将对齐到注意力矩阵的右下角。
    例如,如果 seqlen_q = 2 且 seqlen_k = 5,因果掩码(1 = 保留,0 = 屏蔽)为:
        1 1 1 1 0
        1 1 1 1 1
    如果 seqlen_q = 5 且 seqlen_k = 2,因果掩码为:
        0 0
        0 0
        0 0
        1 0
        1 1
    如果掩码的一行全为零,输出将为零。

    如果 window_size != (-1, -1),实现滑动窗口局部注意力。位置 i 的查询将只注意
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 范围内的键(包括边界)。

    注意:不支持反向传播。

    参数:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: 如果没有 block_table,形状为 (batch_size_cache, seqlen_cache, nheads_k, headdim),
            如果有 block_table(即分页 KV 缓存),形状为 (num_blocks, page_block_size, nheads_k, headdim)
            page_block_size 必须是 256 的倍数。
        v_cache: 如果没有 block_table,形状为 (batch_size_cache, seqlen_cache, nheads_k, headdim),
            如果有 block_table(即分页 KV 缓存),形状为 (num_blocks, page_block_size, nheads_k, headdim)
        k [可选]: (batch_size, seqlen_new, nheads_k, headdim)。如果不为 None,我们将 k 与 k_cache 连接,
            从 cache_seqlens 指定的索引开始。
        v [可选]: (batch_size, seqlen_new, nheads_k, headdim)。与 k 类似。
        rotary_cos [可选]: (seqlen_ro, rotary_dim / 2)。如果不为 None,我们对 k 和 q 应用旋转嵌入。
            仅在传入 k 和 v 时适用。rotary_dim 必须能被 16 整除。
        rotary_sin [可选]: (seqlen_ro, rotary_dim / 2)。与 rotary_cos 类似。
        cache_seqlens: int 或 (batch_size,),dtype torch.int32。KV 缓存的序列长度。
        block_table [可选]: (batch_size, max_num_blocks_per_seq),dtype torch.int32。
        cache_batch_idx: (batch_size,),dtype torch.int32。用于索引 KV 缓存的索引。
            如果为 None,我们假设批次索引为 [0, 1, 2, ..., batch_size - 1]。
            如果索引不是唯一的,且提供了 k 和 v,更新到缓存的值可能来自任何重复的索引。
        softmax_scale: float。在应用 softmax 之前对 QK^T 的缩放。
            默认为 1 / sqrt(headdim)。
        causal: bool。是否应用因果注意力掩码(例如,用于自回归建模)。
        window_size: (left, right)。如果不是 (-1, -1),实现滑动窗口局部注意力。
        rotary_interleaved: bool。仅在传入 rotary_cos 和 rotary_sin 时适用。
            如果为 True,旋转嵌入将组合维度 0 & 1,2 & 3 等。如果为 False,
            旋转嵌入将组合维度 0 & rotary_dim / 2,1 & rotary_dim / 2 + 1
            (即 GPT-NeoX 风格)。
        alibi_slopes: (nheads,) 或 (batch_size, nheads),fp32。
            一个偏置 (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            被添加到查询 i 和键 j 的注意力分数中。

    返回:
        out: (batch_size, seqlen, nheads, headdim)。
    """

要查看这些函数如何在多头注意力层中使用(包括 QKV 投影、输出投影),请参见 MHA 实现

更新日志

2.0:完全重写,速度提升 2 倍

从 FlashAttention (1.x) 升级到 FlashAttention-2

这些函数已被重命名:

  • flash_attn_unpadded_func -> flash_attn_varlen_func
  • flash_attn_unpadded_qkvpacked_func -> flash_attn_varlen_qkvpacked_func
  • flash_attn_unpadded_kvpacked_func -> flash_attn_varlen_kvpacked_func

如果同一批次中的输入具有相同的序列长度,使用这些函数更简单快速:

flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)

2.1:更改 causal 标志的行为

如果 seqlen_q != seqlen_k 且 causal=True,因果掩码将对齐到注意力矩阵的右下角,而不是左上角。

例如,如果 seqlen_q = 2 且 seqlen_k = 5,因果掩码(1 = 保留,0 = 屏蔽)为:
v2.0:
1 0 0 0 0
1 1 0 0 0
v2.1:
1 1 1 1 0
1 1 1 1 1

如果 seqlen_q = 5 且 seqlen_k = 2,因果掩码为:
v2.0:
1 0
1 1
1 1
1 1
1 1
v2.1:
0 0
0 0
0 0
1 0
1 1
如果掩码的一行全为零,输出将为零。

2.2:优化推理

优化推理(迭代解码)当查询具有非常小的序列长度时(例如,查询序列长度 = 1)。这里的瓶颈是尽可能快地加载 KV 缓存,我们将加载分散到不同的线程块中,并使用单独的内核来合并结果。

查看具有更多推理特性的函数 flash_attn_with_kvcache(执行旋转嵌入,原地更新 KV 缓存)。

感谢 xformers 团队,特别是 Daniel Haziza,对这次合作的贡献。

2.3:局部(即滑动窗口)注意力

实现滑动窗口注意力(即局部注意力)。感谢 Mistral AI 特别是 Timothée Lacroix 的这项贡献。滑动窗口在 Mistral 7B 模型中得到了应用。

2.4:ALiBi(带线性偏置的注意力),确定性反向传播。

实现 ALiBi(Press 等人,2021)。感谢来自 Kakao Brain 的 Sanghun Cho 的这项贡献。

实现确定性反向传播。感谢来自美团的工程师们的这项贡献。

2.5:分页 KV 缓存。

支持分页 KV 缓存(即 PagedAttention)。 感谢 @beginlner 的这项贡献。

2.6:软上限。

支持带软上限的注意力,如 Gemma-2 和 Grok 模型中使用的。 感谢 @Narsil 和 @lucidrains 的这项贡献。

性能

我们展示了使用FlashAttention相对于PyTorch标准注意力机制在不同序列长度和GPU上的预期加速(正向 + 反向传播组合)和内存节省情况(加速取决于内存带宽 - 我们在较慢的GPU内存上看到更多加速)。

目前我们有以下GPU的基准测试:

A100

我们使用以下参数展示FlashAttention的加速:

  • 头部维度64或128,隐藏维度2048(即32或16个头)。
  • 序列长度512、1k、2k、4k、8k、16k。
  • 批量大小设置为16k / 序列长度。

加速

A100 80GB SXM5上使用FP16/BF16的FlashAttention加速

内存

FlashAttention内存使用

我们在此图中展示了内存节省情况(注意,无论是否使用dropout或掩码,内存占用都相同)。 内存节省与序列长度成正比 -- 因为标准注意力的内存与序列长度呈二次方关系,而FlashAttention的内存与序列长度呈线性关系。 在序列长度为2K时,我们看到10倍的内存节省,在4K时为20倍。 因此,FlashAttention可以扩展到更长的序列长度。

H100

H100 SXM5上使用FP16/BF16的FlashAttention加速

完整模型代码和训练脚本

我们已发布完整的GPT模型实现。 我们还提供了其他层的优化实现(如MLP、LayerNorm、交叉熵损失、旋转嵌入)。总体上,与Huggingface的基线实现相比,这使训练速度提高了3-5倍,在每个A100上达到225 TFLOPs/秒,相当于72%的模型FLOPs利用率(我们不需要任何激活检查点)。

我们还包含了一个训练脚本,用于在Openwebtext上训练GPT2和在The Pile上训练GPT3。

FlashAttention的Triton实现

Phil Tillet(OpenAI)有一个FlashAttention在Triton中的实验性实现: https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py

由于Triton是比CUDA更高级的语言,可能更容易理解和实验。Triton实现中使用的符号也更接近我们论文中使用的符号。

我们还有一个支持注意力偏置(如ALiBi)的Triton实验性实现: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py

测试

我们测试FlashAttention是否在一定数值容差范围内产生与参考实现相同的输出和梯度。特别是,我们检查FlashAttention的最大数值误差是否最多为Pytorch基线实现数值误差的两倍(对于不同的头部维度、输入数据类型、序列长度、因果/非因果)。

运行测试:

pytest -q -s tests/test_flash_attn.py

遇到问题时

FlashAttention-2的这个新版本已在几个GPT风格的模型上进行了测试,主要在A100 GPU上。

如果遇到错误,请开启GitHub问题!

测试

运行测试:

pytest tests/test_flash_attn_ck.py

引用

如果您使用此代码库,或者认为我们的工作有价值,请引用:

@inproceedings{dao2022flashattention,
  title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
  author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2022}
}
@inproceedings{dao2023flashattention2,
  title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
  author={Dao, Tri},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2024}
}
项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号