flash-attention-minimal
使用CUDA和PyTorch对Flash Attention的最小化重新实现。 对CUDA初学者(如我自己)来说,官方实现可能相当令人生畏,因此本仓库旨在保持简小并具有教育意义。
- 整个前向传播在
flash.cu
中仅用约100行代码编写。 - 变量名遵循原始论文中的符号。
使用方法
前提条件
- PyTorch(带CUDA)
- 用于加载C++的
Ninja
基准测试
比较手动注意力和最小化flash注意力之间的实际运行时间:
python bench.py
在T4上的样例输出:
=== 分析手动注意力 ===
...
Self CPU总时间:52.389毫秒
Self CUDA总时间:52.545毫秒
=== 分析最小化flash注意力 ===
...
Self CPU总时间:11.452毫秒
Self CUDA总时间:3.908毫秒
成功提速!
我没有GPU
尝试这个在线Colab演示。
注意事项
- 没有反向传播!老实说,我发现它比前向传播复杂得多,而前向传播已足以展示使用共享内存来避免大量N^2读/写操作。
- 在内部循环中,我为输出矩阵的每一行分配一个线程。这与原始实现不同。
- 这种每行一个线程的简化使矩阵乘法变得非常慢。这可能是为什么对于更长的序列和更大的块大小,这比手动实现更慢的原因。
- Q、K、V使用float32,不同于原始实现中使用的float16。
- 块大小在编译时固定为32。
待办事项
- 添加反向传播
- 加速矩阵乘法
- 动态设置块大小