ThunderKittens
快速内核的瓦片原语
ThunderKittens 是一个框架,旨在使用 CUDA 轻松编写快速的深度学习内核(不久后还将支持 ROCm 等其他平台)。
ThunderKittens 基于三个关键原则构建:
- 简单性。ThunderKittens 编写起来异常简单。
- 可扩展性。ThunderKittens 原生嵌入,如果你需要的功能超出了 ThunderKittens 的能力范围,它不会妨碍你自行构建。
- 速度。使用 ThunderKittens 编写的内核应该至少与从头编写的内核一样快 —— 特别是因为 ThunderKittens 可以在底层以"正确"的方式处理事情。我们认为我们的 Flash Attention 2 实现证明了这一点。
ThunderKittens 是从硬件层面构建的 —— 我们按照硅芯片的指示行事。现代 GPU 告诉我们,它们希望处理相当小的数据瓦片。GPU 并不真的是一个 1000x1000 矩阵乘法机器(即使它经常被这样使用);它是一个多核处理器,每个核心可以高效地执行约 16x16 的矩阵乘法。因此,ThunderKittens 围绕操作不小于 16x16 值的数据瓦片构建。
ThunderKittens 让一些棘手的事情变得简单,从而在现代硬件上实现高利用率。
- 张量核心。ThunderKittens 可以调用快速的张量核心函数,包括在 H100 GPU 上的异步 WGMMA 调用。
- 共享内存。我有九十九个问题,但银行冲突不是其中之一。
- 加载和存储。通过异步复制隐藏延迟,通过 TMA 进行地址生成。
- 分布式共享内存。L2 已经是过去式了。
示例:一个简单的注意力内核
以下是使用 ThunderKittens 为 RTX 4090 编写的简单 FlashAttention-2 内核示例。
#define NUM_WORKERS 16 // 此内核每个块并行使用16个工作线程,以帮助更快地发出指令。
using namespace kittens; // 为简单起见,此内核仅处理 headdim=64。此外,n 应该是 256 的倍数。
__global__ void attend_ker64(int n, const bf16* __restrict__ __q__, const bf16* __restrict__ __k__, const bf16* __restrict__ __v__, bf16* __o__) {
auto warpid = kittens::warpid();
auto block_start = blockIdx.x*(n*64);
const bf16 *_q = __q__ + block_start, *_k = __k__ + block_start, *_v = __v__ + block_start;
bf16 *_o = __o__ + block_start;
extern __shared__ alignment_dummy __shm[]; // 这是 CUDA 共享内存
shared_allocator al((int*)&__shm[0]);
// K 和 V 存储在共享内存中 —— 这几乎是所能容纳的全部。
st_bf_1x4<ducks::st_layout::swizzle> (&k_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>();
st_bf_1x4<ducks::st_layout::swizzle> (&v_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>();
// 初始化所有寄存器瓦片。
rt_bf_1x4<> q_reg, k_reg, v_reg; // v_reg 需要交换到 col_l
rt_fl_1x1<> att_block;
rt_bf_1x1<> att_block_mma;
rt_fl_1x4<> o_reg;
rt_fl_1x1<>::col_vec max_vec_last, max_vec; // 这些是注意力块的列向量
rt_fl_1x1<>::col_vec norm_vec_last, norm_vec; // 这些是注意力块的列向量
int qo_blocks = n / (q_reg.rows*NUM_WORKERS), kv_blocks = n / (q_reg.rows*NUM_WORKERS);
for(auto q_blk = 0; q_blk < qo_blocks; q_blk++) {
// 每个线程束加载自己的 16x64 的 Q 瓦片,然后乘以 1/sqrt(d)
load(q_reg, _q + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
mul(q_reg, q_reg, __float2bfloat16(0.125f)); // 温度调整
// 将 flash 注意力 L、M 和 O 寄存器置零。
neg_infty(max_vec); // 为 Q 块清零寄存器
zero(norm_vec);
zero(o_reg);
// 对已加载的这些 q 迭代 k、v
for(auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++) {
// 每个线程束将自己的 k、v 块加载到共享内存中
load(v_smem[warpid], _v + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
load(k_smem[warpid], _k + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
__syncthreads(); // 我们需要确保在开始计算阶段之前所有内存都已加载
// 现在每个线程束遍历所有子瓦片,加载它们,然后执行 flash 注意力内部算法。
for(int subtile = 0; subtile < NUM_WORKERS; subtile++) {
load(k_reg, k_smem[subtile]); // 从共享内存加载 k 到寄存器
zero(att_block); // 将 16x16 注意力瓦片置零
mma_ABt(att_block, q_reg, k_reg, att_block); // Q@K.T
copy(norm_vec_last, norm_vec);
copy(max_vec_last, max_vec);
row_max(max_vec, att_block, max_vec); // 累积到 max_vec
sub_row(att_block, att_block, max_vec); // 从注意力中减去最大值 —— 现在所有值 <=0
exp(att_block, att_block); // 原地对块进行指数运算。
sub(max_vec_last, max_vec_last, max_vec); // 从旧的最大值中减去新的最大值以找到新的归一化。
exp(max_vec_last, max_vec_last); // 对这个向量进行指数运算 —— 这是我们需要用来归一化的。
mul(norm_vec, norm_vec, max_vec_last); // norm_vec 现在已归一化。
row_sum(norm_vec, att_block, norm_vec); // 将新的注意力块累积到现在已重新缩放的 norm_vec 上
div_row(att_block, att_block, norm_vec); // 现在注意力块已正确归一化
mul(norm_vec_last, norm_vec_last, max_vec_last); // 根据新的最大值归一化先前的 norm vec
div(norm_vec_last, norm_vec_last, norm_vec); // 根据新的范数归一化先前的 norm vec
copy(att_block_mma, att_block); // 转换为 bf16 以用于 mma_AB
load(v_reg, v_smem[subtile]); // 从共享内存加载 v 到寄存器。
rt_bf_1x4<ducks::rt_layout::col> &v_reg_col = swap_layout_inplace(v_reg); // 这是一个引用,该调用使 v_reg 失效
mul_row(o_reg, o_reg, norm_vec_last); // 在进行 mma_AB 之前预先归一化 o_reg
mma_AB(o_reg, att_block_mma, v_reg_col, o_reg); // 使用局部注意力@V 矩阵乘法对 o_reg 进行 mfma。
}
__syncthreads(); // 我们需要确保所有线程束都完成后才能开始加载下一个 kv 块
}
store(_o + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, o_reg, q_reg.cols); // 写出 o。如果 d 设为 constexpr q_reg.rows,编译器在寄存器使用上会有问题 :/
}
}
总的来说,这是 58 行代码(不包括空白行),在 RTX 4090 上可以达到约 122 TFLOPs。(理论最大值的 74%)我们将在下一节 ThunderKittens 手册中更仔细地介绍这些原语。
库安装
要使用 Thunderkittens,你不需要对 TK 本身做太多操作。它是一个仅包含头文件的库,所以只需克隆仓库,并包含 kittens.cuh。轻松搞定。
但 ThunderKittens 确实使用了许多现代功能,因此它有相当严格的要求。
- CUDA 12.3+。CUDA 12.1 之后的任何版本可能都能工作,但由于这些早期 CUDA 版本中的一个 bug,你可能会遇到串行化的 wgmma 管道。
- 广泛使用 C++20 —— TK 基于概念运行。
sudo apt update
sudo apt install gcc-10 g++-10
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 100 --slave /usr/bin/g++ g++ /usr/bin/g++-10
sudo apt update
sudo apt install clang-10
如果你找不到 nvcc,或者遇到环境指向错误 CUDA 版本的问题:
export CUDA_HOME=/usr/local/cuda-12/
export PATH=${CUDA_HOME}/bin:${PATH}
export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:$LD_LIBRARY_PATH
最后,感谢 Jordan Juravsky 整理了一份关于设置兼容 kittens 的 conda 环境的快速文档。
内核安装
要试验我们现有的 TK 内核,请在 config.py
文件中指定你感兴趣的内核,然后运行 python setup.py install
。
欢迎贡献新的内核!
测试
要验证你的安装,并运行 TK 相当全面的单元测试套件,只需在 tests 文件夹中运行 make -j
。注意:这可能会在编译数千个内核时占用你的电脑一两分钟。
示例
要编译示例,请在根目录运行 source env.src
,然后进入 examples 目录。(许多示例使用 $THUNDERKITTENS_ROOT
环境变量来定位自己并找到 src 目录。)
ThunderKittens 手册
ThunderKittens 实际上是一个相当小的库,就其提供的功能而言。
- 数据类型:(寄存器 + 共享)*(瓦片 + 向量),所有这些都由布局、类型和大小参数化。
- 用于操作这些对象的操作。
尽管它很简单,但如果你不了解底层的工作原理,仍可能会遇到一些棘手的问题。因此,我们建议你在开始编写内核之前好好阅读这份手册 —— 我们保证它不会太长!
NVIDIA 的编程模型
为了理解ThunderKittens,首先回顾一下NVIDIA的编程模型如何工作会有所帮助,因为NVIDIA在编写并行代码时提供了几个不同的"作用域"供考虑。
- 线程 -- 这是在单个数据位上执行工作的级别,如浮点乘法。一个线程每个周期可以访问最多256个32位寄存器。
- 线程束 -- 32个线程组成一个线程束。这是硬件发出指令的级别。它也是ThunderKittens操作的基本(和默认)作用域;大多数ThunderKittens编程都发生在这个级别。
- 线程束组 -- 4个线程束组成一个线程束组。这是发出异步线程束组矩阵乘累加指令的级别。(我们真希望能忽略这个级别,但不幸的是H100需要它。)相应地,许多矩阵乘法和内存操作都在线程束组级别得到支持。
- 块 -- N个线程束组成一个块,这是在CUDA编程模型中共享"共享内存"的级别。在ThunderKittens中,N通常是8。
- 网格 -- M个块组成一个网格,其中M应该等于(或略小于)GPU上SM数量的倍数,以避免尾部效应。ThunderKittens不直接操作网格作用域,除了通过帮助初始化TMA描述符。
"寄存器"对象存在于线程束级别 -- 它们的内容分布在线程束的各个线程中。寄存器对象包括:
- 寄存器瓦片,在
src/register_tile/rt.cuh
中声明为kittens::rt
结构。Kittens提供了一些有用的包装器 -- 例如,可以将32x16行布局的bfloat16寄存器瓦片声明为kittens::rt_bf_2x1;
-- 默认情况下行布局是隐含的。 - 寄存器向量,与寄存器瓦片相关联。它们有两种形式:列向量和行向量。列向量用于在瓦片行上进行归约或映射,而行向量在瓦片列上进行归约和映射。例如,要保存上面声明的瓦片行的和,我们可以创建一个
kittens::rt_bf_2x1<>::col_vec;
相比之下,"共享"对象存在于块级别,仅位于共享内存中。
所有ThunderKittens函数都遵循一个通用的签名。与汇编语言类似(ThunderKittens本质上是一个抽象的面向瓦片的RISC指令集),每个函数的目标是第一个操作数,源操作数按顺序传递。
例如,如果我们有三个32x64浮点寄存器瓦片:kittens::rt_fl_2x4 a, b, c;
,我们可以对a
和b
进行元素级乘法并将结果存储在c
中,调用如下:kittens::mul(c, a, b);
。
同样,如果我们想将结果存储到共享瓦片__shared__ kittens:st_bf_2x4 s;
中,我们可以类似地写函数:kittens::store(s, c);
。
类型系统
ThunderKittens努力保护你免受自身错误的影响。特别是,ThunderKittens希望在编译时知道对象的布局,并确保它们在允许你进行操作之前是兼容的。这很重要,因为某些操作的允许布局有微妙之处,如果没有静态检查,很容易出现令人痛苦的静默失败。例如,普通的矩阵乘法要求B操作数采用列布局,而外积则要求B操作数采用行布局。
如果你被告知你认为存在的操作不存在,请仔细检查你的布局 -- 这是最常见的错误。只有在确认后才报告bug :)
作用域
默认情况下,ThunderKittens操作存在于线程束级别。换句话说,每个函数期望只由单个线程束调用,并且该单个线程束将完成函数的所有工作。如果将多个线程束分配给相同的工作,将导致未定义行为。(如果操作涉及内存移动,很可能会完全崩溃。)通常,你应该期望你的编程模式涉及在内核开始时使用kittens::warpid()
实例化一个warpid
,并基于该id将任务分配给数据。
然而,并非所有ThunderKittens函数都在线程束级别操作。许多重要操作,特别是WGMMA指令,需要线程束的协作组。这些操作存在于模板kittens::group<collaborative size>
中。例如,wgmma指令可通过kittens::group<4>::mma_AB
(或其别名kittens::warpgroup::mma_AB
)获得。线程束组还可以协作加载共享内存或在共享内存中进行归约。
其他限制
ThunderKittens中的大多数操作都是纯函数式的。然而,一些操作确实有特殊限制;ThunderKittens试图通过给它们起显眼的名字来警告你。例如,寄存器瓦片转置需要可分离的参数:如果给它相同的底层寄存器作为源和目标,它会静默失败。因此,它被命名为transpose_sep
。