邻域注意力扩展
给您附近的邻域带来关注力!
NATTEN是一个开源项目,致力于为邻域注意力提供快速实现,这是一种滑动窗口自注意力机制。
如果您还不熟悉邻域注意力,请参考我们的论文,或观看我们在CVPR 2023的YouTube视频。
要了解更多关于我们基于GEMM和融合的邻域注意力内核的信息,请参考我们的新预印本Faster Neighborhood Attention。
新功能:融合的邻域注意力现在支持反向传播!
我们发布了融合的邻域注意力(FNA)反向内核和接口,这意味着您现在可以更快更有效地训练基于邻域注意力的模型。
FNA可以被视为Flash Attention和FMHA等方法的泛化,从back-to-back矩阵乘法扩展到back-to-back张量-张量收缩,并内置了邻域注意力掩码。这通过从不存储注意力张量到全局内存来加速邻域注意力,这不仅减少了全局内存占用,还减少了内存带宽瓶颈。
我们强烈建议您在开始使用FNA之前,请参考FNA快速入门或融合与非融合NA指南,因为接口、内存布局和功能集可能与NATTEN中的所有非融合操作有所不同。
入门指南
NATTEN支持PyTorch 2.0及以后版本,以及Python 3.8及以上版本。 Python 3.12仅在torch >= 2.2.0时支持。
较早版本的NATTEN支持python >= 3.7 and torch >= 1.8。
请参考安装说明了解您的操作系统和硬件加速器是否与NATTEN兼容。
功能可用性
问题领域 | CPU后端 | CUDA后端 |
---|---|---|
1D | 朴素 | 朴素、GEMM、FNA |
2D | 朴素 | 朴素、GEMM、FNA |
3D | 朴素 | 朴素、FNA |
CPU
问题领域 | CPU后端 | 因果掩码 | 变化参数 | 相对位置偏差 | 自动梯度支持 |
---|---|---|---|---|---|
1D | 朴素 | ✔ | ✔ | ✔ | 前向和反向 |
2D | 朴素 | ✔ | ✔ | ✔ | 前向和反向 |
3D | 朴素 | ✔ | ✔ | ✔ | 前向和反向 |
注意:
- 前向自动梯度还不支持相对位置偏差和因果掩码。
- 当任意轴启用因果掩码时,还不支持相对位置偏差。
CUDA
问题空间 | CUDA 后端 | 因果遮蔽 | 参数变化 | 相对位置偏置 | Autograd 支持 | 最小架构 |
---|---|---|---|---|---|---|
1D | naive | ✓ | ✓ | ✓ | 正向和反向模式 | SM35 |
2D | naive | ✓ | ✓ | ✓ | 正向和反向模式 | SM35 |
3D | naive | ✓ | ✓ | ✓ | 正向和反向模式 | SM35 |
1D | gemm | - | - | ✓ | 正向和反向模式 | SM70 |
2D | gemm | - | - | ✓ | 正向和反向模式 | SM70 |
1D | fna | ✓ | ✓ | ✓ | 反向模式 | SM50 |
2D | fna | ✓ | ✓ | ✓ | 反向模式 | SM50 |
3D | fna | ✓ | ✓ | ✓ | 反向模式 | SM50 |
注意:
- FP16 内核只在 SM50 及以上可用*,BF16 需要 SM80 及以上。
- Naive FP16 内核只在 SM60 及以上可用。
- FNA FP16 内核只在 SM50 及以上可用。
- GEMM 后端在 SM70 和 SM75 上只能执行 FP16。
- Tiled 仅实现了三分之一的操作,仅针对 2D 问题实现,并且要求头维度 = 32。
- 正向模式 autograd 尚不支持相对位置偏置和因果遮蔽。
- 当任何轴启用因果遮蔽时,相对位置偏置也不受支持。
- 反向传播过程中 FNA 不支持相对位置偏置。
可能不再持续开发或改进的功能:
- 相对位置偏置
- 有更好的替代方案,不需要显式偏置注意力权重矩阵,同时性能更佳,准确性也更好。
- GEMM based 内核
- 因为 FNA 覆盖了比我们无融合的 GEMM 内核更多的功能,而且我们知道它是更好的解决方案(请参阅《Faster Neighborhood Attention》了解详情),我们不打算扩展或改进这些内核。
- 这包括对参数变化、因果遮蔽和 3D 问题的支持。
许可证
NATTEN 发布在 MIT 许可证之下。
引用
@misc{hassani2024faster,
title = {Faster Neighborhood Attention: Reducing the O(n^2) Cost of Self Attention at the Threadblock Level},
author = {Ali Hassani and Wen-Mei Hwu and Humphrey Shi},
year = 2024,
url = {https://arxiv.org/abs/2403.04690},
eprint = {2403.04690},
archiveprefix = {arXiv},
primaryclass = {cs.CV}
}
@inproceedings{hassani2023neighborhood,
title = {Neighborhood Attention Transformer},
author = {Ali Hassani and Steven Walton and Jiachen Li and Shen Li and Humphrey Shi},
year = 2023,
booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}
}
@misc{hassani2022dilated,
title = {Dilated Neighborhood Attention Transformer},
author = {Ali Hassani and Humphrey Shi},
year = 2022,
url = {https://arxiv.org/abs/2209.15001},
eprint = {2209.15001},
archiveprefix = {arXiv},
primaryclass = {cs.CV}
}
致谢
我们感谢 NVIDIA 以及 CUTLASS 项目及其团队在创建和开源 CUTLASS 方面的努力。我们也要感谢 Haicheng Wu 提供的宝贵反馈和意见,这促成了基于 GEMM 的 NA 的创建。 我们也感谢 Meta 和 xFormers 团队提供的 FMHA 内核,这正是我们融合邻域注意力内核的基础。 我们感谢 PyTorch 项目及其团队。