Project Icon

ThunderKittens

高效瓦片原语框架助力深度学习内核开发

ThunderKittens是一个用于开发高性能CUDA深度学习内核的框架。它基于现代GPU架构设计,通过操作16x16及以上的数据瓦片实现高效计算。框架支持张量核心、共享内存优化和异步数据传输等特性,充分利用GPU性能。ThunderKittens以简洁、可扩展和高速为设计原则,适用于各类深度学习算法的高效实现。

ThunderKittens

快速内核的瓦片原语

ThunderKittens 标志


ThunderKittens 是一个框架,旨在使用 CUDA 轻松编写快速的深度学习内核(不久后还将支持 ROCm 等其他平台)。

ThunderKittens 基于三个关键原则构建:

  1. 简单性。ThunderKittens 编写起来异常简单。
  2. 可扩展性。ThunderKittens 原生嵌入,如果你需要的功能超出了 ThunderKittens 的能力范围,它不会妨碍你自行构建。
  3. 速度。使用 ThunderKittens 编写的内核应该至少与从头编写的内核一样快 —— 特别是因为 ThunderKittens 可以在底层以"正确"的方式处理事情。我们认为我们的 Flash Attention 2 实现证明了这一点。
Flash Attention 2,但是带有小猫!

ThunderKittens 是从硬件层面构建的 —— 我们按照硅芯片的指示行事。现代 GPU 告诉我们,它们希望处理相当小的数据瓦片。GPU 并不真的是一个 1000x1000 矩阵乘法机器(即使它经常被这样使用);它是一个多核处理器,每个核心可以高效地执行约 16x16 的矩阵乘法。因此,ThunderKittens 围绕操作不小于 16x16 值的数据瓦片构建。

ThunderKittens 让一些棘手的事情变得简单,从而在现代硬件上实现高利用率。

  1. 张量核心。ThunderKittens 可以调用快速的张量核心函数,包括在 H100 GPU 上的异步 WGMMA 调用。
  2. 共享内存。我有九十九个问题,但银行冲突不是其中之一。
  3. 加载和存储。通过异步复制隐藏延迟,通过 TMA 进行地址生成。
  4. 分布式共享内存。L2 已经是过去式了。

示例:一个简单的注意力内核

以下是使用 ThunderKittens 为 RTX 4090 编写的简单 FlashAttention-2 内核示例。

#define NUM_WORKERS 16 // 此内核每个块并行使用16个工作线程,以帮助更快地发出指令。

using namespace kittens; // 为简单起见,此内核仅处理 headdim=64。此外,n 应该是 256 的倍数。
__global__ void attend_ker64(int n, const bf16* __restrict__ __q__, const bf16* __restrict__ __k__, const bf16* __restrict__ __v__, bf16* __o__) {

    auto warpid        = kittens::warpid();
    auto block_start   = blockIdx.x*(n*64);
    const bf16 *_q = __q__ + block_start, *_k = __k__ + block_start, *_v = __v__ + block_start;
          bf16 *_o = __o__ + block_start;

    extern __shared__ alignment_dummy __shm[]; // 这是 CUDA 共享内存
    shared_allocator al((int*)&__shm[0]);
    
    // K 和 V 存储在共享内存中 —— 这几乎是所能容纳的全部。
    st_bf_1x4<ducks::st_layout::swizzle> (&k_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>();
    st_bf_1x4<ducks::st_layout::swizzle> (&v_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>();

    // 初始化所有寄存器瓦片。
    rt_bf_1x4<> q_reg, k_reg, v_reg; // v_reg 需要交换到 col_l
    rt_fl_1x1<> att_block;
    rt_bf_1x1<> att_block_mma;
    rt_fl_1x4<> o_reg;
    rt_fl_1x1<>::col_vec max_vec_last, max_vec; // 这些是注意力块的列向量
    rt_fl_1x1<>::col_vec norm_vec_last, norm_vec; // 这些是注意力块的列向量
    
    int qo_blocks = n / (q_reg.rows*NUM_WORKERS), kv_blocks = n / (q_reg.rows*NUM_WORKERS);

    for(auto q_blk = 0; q_blk < qo_blocks; q_blk++) {

        // 每个线程束加载自己的 16x64 的 Q 瓦片,然后乘以 1/sqrt(d)
        load(q_reg, _q + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
        mul(q_reg, q_reg, __float2bfloat16(0.125f)); // 温度调整

        // 将 flash 注意力 L、M 和 O 寄存器置零。
        neg_infty(max_vec); // 为 Q 块清零寄存器
        zero(norm_vec);
        zero(o_reg);

        // 对已加载的这些 q 迭代 k、v
        for(auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++) {

            // 每个线程束将自己的 k、v 块加载到共享内存中
            load(v_smem[warpid], _v + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
            load(k_smem[warpid], _k + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
            __syncthreads(); // 我们需要确保在开始计算阶段之前所有内存都已加载

            // 现在每个线程束遍历所有子瓦片,加载它们,然后执行 flash 注意力内部算法。
            for(int subtile = 0; subtile < NUM_WORKERS; subtile++) {

                load(k_reg, k_smem[subtile]); // 从共享内存加载 k 到寄存器

                zero(att_block); // 将 16x16 注意力瓦片置零
                mma_ABt(att_block, q_reg, k_reg, att_block); // Q@K.T

                copy(norm_vec_last, norm_vec);
                copy(max_vec_last,  max_vec);

                row_max(max_vec, att_block, max_vec); // 累积到 max_vec
                sub_row(att_block, att_block, max_vec); // 从注意力中减去最大值 —— 现在所有值 <=0
                exp(att_block, att_block); // 原地对块进行指数运算。

                sub(max_vec_last, max_vec_last, max_vec); // 从旧的最大值中减去新的最大值以找到新的归一化。
                exp(max_vec_last, max_vec_last); // 对这个向量进行指数运算 —— 这是我们需要用来归一化的。
                mul(norm_vec, norm_vec, max_vec_last); // norm_vec 现在已归一化。

                row_sum(norm_vec, att_block, norm_vec); // 将新的注意力块累积到现在已重新缩放的 norm_vec 上
                div_row(att_block, att_block, norm_vec); // 现在注意力块已正确归一化

                mul(norm_vec_last, norm_vec_last, max_vec_last); // 根据新的最大值归一化先前的 norm vec
                div(norm_vec_last, norm_vec_last, norm_vec); // 根据新的范数归一化先前的 norm vec

                copy(att_block_mma, att_block); // 转换为 bf16 以用于 mma_AB

                load(v_reg, v_smem[subtile]); // 从共享内存加载 v 到寄存器。
                rt_bf_1x4<ducks::rt_layout::col> &v_reg_col = swap_layout_inplace(v_reg); // 这是一个引用,该调用使 v_reg 失效

                mul_row(o_reg, o_reg, norm_vec_last); // 在进行 mma_AB 之前预先归一化 o_reg
                mma_AB(o_reg, att_block_mma, v_reg_col, o_reg); // 使用局部注意力@V 矩阵乘法对 o_reg 进行 mfma。
            }
            __syncthreads(); // 我们需要确保所有线程束都完成后才能开始加载下一个 kv 块
        }

        store(_o + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, o_reg, q_reg.cols); // 写出 o。如果 d 设为 constexpr q_reg.rows,编译器在寄存器使用上会有问题 :/
    }
}

总的来说,这是 58 行代码(不包括空白行),在 RTX 4090 上可以达到约 122 TFLOPs。(理论最大值的 74%)我们将在下一节 ThunderKittens 手册中更仔细地介绍这些原语。

库安装

要使用 Thunderkittens,你不需要对 TK 本身做太多操作。它是一个仅包含头文件的库,所以只需克隆仓库,并包含 kittens.cuh。轻松搞定。

但 ThunderKittens 确实使用了许多现代功能,因此它有相当严格的要求。

  • CUDA 12.3+。CUDA 12.1 之后的任何版本可能都能工作,但由于这些早期 CUDA 版本中的一个 bug,你可能会遇到串行化的 wgmma 管道。
  • 广泛使用 C++20 —— TK 基于概念运行。
sudo apt update
sudo apt install gcc-10 g++-10

sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 100 --slave /usr/bin/g++ g++ /usr/bin/g++-10

sudo apt update
sudo apt install clang-10

如果你找不到 nvcc,或者遇到环境指向错误 CUDA 版本的问题:

export CUDA_HOME=/usr/local/cuda-12/
export PATH=${CUDA_HOME}/bin:${PATH} 
export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:$LD_LIBRARY_PATH

最后,感谢 Jordan Juravsky 整理了一份关于设置兼容 kittens 的 conda 环境的快速文档。

内核安装

要试验我们现有的 TK 内核,请在 config.py 文件中指定你感兴趣的内核,然后运行 python setup.py install

欢迎贡献新的内核!

测试

要验证你的安装,并运行 TK 相当全面的单元测试套件,只需在 tests 文件夹中运行 make -j。注意:这可能会在编译数千个内核时占用你的电脑一两分钟。

示例

要编译示例,请在根目录运行 source env.src,然后进入 examples 目录。(许多示例使用 $THUNDERKITTENS_ROOT 环境变量来定位自己并找到 src 目录。)

ThunderKittens 手册

ThunderKittens 实际上是一个相当小的库,就其提供的功能而言。

  • 数据类型:(寄存器 + 共享)*(瓦片 + 向量),所有这些都由布局、类型和大小参数化。
  • 用于操作这些对象的操作。

尽管它很简单,但如果你不了解底层的工作原理,仍可能会遇到一些棘手的问题。因此,我们建议你在开始编写内核之前好好阅读这份手册 —— 我们保证它不会太长!

NVIDIA 的编程模型

为了理解ThunderKittens,首先回顾一下NVIDIA的编程模型如何工作会有所帮助,因为NVIDIA在编写并行代码时提供了几个不同的"作用域"供考虑。

  1. 线程 -- 这是在单个数据位上执行工作的级别,如浮点乘法。一个线程每个周期可以访问最多256个32位寄存器。
  2. 线程束 -- 32个线程组成一个线程束。这是硬件发出指令的级别。它也是ThunderKittens操作的基本(和默认)作用域;大多数ThunderKittens编程都发生在这个级别。
  3. 线程束组 -- 4个线程束组成一个线程束组。这是发出异步线程束组矩阵乘累加指令的级别。(我们真希望能忽略这个级别,但不幸的是H100需要它。)相应地,许多矩阵乘法和内存操作都在线程束组级别得到支持。
  4. 块 -- N个线程束组成一个块,这是在CUDA编程模型中共享"共享内存"的级别。在ThunderKittens中,N通常是8。
  5. 网格 -- M个块组成一个网格,其中M应该等于(或略小于)GPU上SM数量的倍数,以避免尾部效应。ThunderKittens不直接操作网格作用域,除了通过帮助初始化TMA描述符。

"寄存器"对象存在于线程束级别 -- 它们的内容分布在线程束的各个线程中。寄存器对象包括:

  • 寄存器瓦片,在src/register_tile/rt.cuh中声明为kittens::rt结构。Kittens提供了一些有用的包装器 -- 例如,可以将32x16行布局的bfloat16寄存器瓦片声明为kittens::rt_bf_2x1; -- 默认情况下行布局是隐含的。
  • 寄存器向量,与寄存器瓦片相关联。它们有两种形式:列向量和行向量。列向量用于在瓦片行上进行归约或映射,而行向量在瓦片列上进行归约和映射。例如,要保存上面声明的瓦片行的和,我们可以创建一个kittens::rt_bf_2x1<>::col_vec; 相比之下,"共享"对象存在于块级别,仅位于共享内存中。

所有ThunderKittens函数都遵循一个通用的签名。与汇编语言类似(ThunderKittens本质上是一个抽象的面向瓦片的RISC指令集),每个函数的目标是第一个操作数,源操作数按顺序传递。

例如,如果我们有三个32x64浮点寄存器瓦片:kittens::rt_fl_2x4 a, b, c;,我们可以对ab进行元素级乘法并将结果存储在c中,调用如下:kittens::mul(c, a, b);

同样,如果我们想将结果存储到共享瓦片__shared__ kittens:st_bf_2x4 s;中,我们可以类似地写函数:kittens::store(s, c);

类型系统

ThunderKittens努力保护你免受自身错误的影响。特别是,ThunderKittens希望在编译时知道对象的布局,并确保它们在允许你进行操作之前是兼容的。这很重要,因为某些操作的允许布局有微妙之处,如果没有静态检查,很容易出现令人痛苦的静默失败。例如,普通的矩阵乘法要求B操作数采用列布局,而外积则要求B操作数采用行布局。

如果你被告知你认为存在的操作不存在,请仔细检查你的布局 -- 这是最常见的错误。只有在确认后才报告bug :)

作用域

默认情况下,ThunderKittens操作存在于线程束级别。换句话说,每个函数期望只由单个线程束调用,并且该单个线程束将完成函数的所有工作。如果将多个线程束分配给相同的工作,将导致未定义行为。(如果操作涉及内存移动,很可能会完全崩溃。)通常,你应该期望你的编程模式涉及在内核开始时使用kittens::warpid()实例化一个warpid,并基于该id将任务分配给数据。

然而,并非所有ThunderKittens函数都在线程束级别操作。许多重要操作,特别是WGMMA指令,需要线程束的协作组。这些操作存在于模板kittens::group<collaborative size>中。例如,wgmma指令可通过kittens::group<4>::mma_AB(或其别名kittens::warpgroup::mma_AB)获得。线程束组还可以协作加载共享内存或在共享内存中进行归约。

其他限制

ThunderKittens中的大多数操作都是纯函数式的。然而,一些操作确实有特殊限制;ThunderKittens试图通过给它们起显眼的名字来警告你。例如,寄存器瓦片转置需要可分离的参数:如果给它相同的底层寄存器作为源和目标,它会静默失败。因此,它被命名为transpose_sep

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号