seqax简介
seqax是一个专为小到中等规模的大型语言模型(LLM)预训练研究而设计的代码库。它巧妙地将序列建模与JAX框架相结合,为研究人员提供了一个强大而灵活的工具。seqax的核心优势在于其简洁性和高效性 - 整个训练程序,包括模型实现、优化器、多主机FSDP和张量并行分区,仅用500行代码就能实现。这种精简的设计不仅使代码易于理解和修改,还能在约100个GPU或TPU上实现良好的扩展性能。
seqax的核心特性
seqax的设计理念是将重要信息置于显著位置,而不是隐藏在抽象和间接引用之后,或通过自动和不可预测的方式推断。这种设计理念体现在以下几个方面:
-
数学透明性: seqax直接实现了训练步骤中的所有数学计算,而不是调用外部库。这意味着如果你想理解或修改数学逻辑,所有内容都清晰可见。
-
内存管理: 所有进入模型检查点的张量都是显式的。占用大量内存的张量,包括为反向传播保存的激活值,也都是显式的。你可以直接从源代码中读取内存占用情况。
-
分区和通信: 所有张量和操作的分区布局都是显式的。所有芯片间通信也是显式的。
这种设计方法使得seqax成为一个透明、可控且易于理解的工具,特别适合那些希望深入了解和定制LLM训练过程的研究人员。
快速入门指南
要开始使用seqax,首先需要进行安装和环境配置:
-
从系统包管理器安装
graphviz
,例如使用brew install graphviz
或apt install graphviz
。 -
安装Python依赖,通常在虚拟环境中执行:
python -m pip install -r requirements-cpu.txt
。
注意:对于GPU或TPU安装,可能需要不同的JAX和jaxlib安装方式。请参考JAX安装文档以获取详细信息。
CPU本地开发
对于开发和测试,可以在CPU上运行seqax。通常会使用合成数据集或Huggingface数据加载器,并设置XLA标志以模拟多设备环境:
XLA_FLAGS=--xla_force_host_platform_device_count=8 python -m train --config-name=local_test_synthetic +paths.model_name=synthetic_000
paths.model_name
标志指定了在磁盘上写入模型检查点的子目录(在/tmp
内)。每次开始新的模型运行时,通常需要更改此设置。
GPU运行
seqax提供了一系列预配置的模型大小,可以在C4数据集上使用Llama分词器进行训练。你可以在configs/
目录中浏览并选择合适的配置文件。每个配置文件的顶部都列出了运行说明。
建议为每个不同的训练运行设置唯一的paths.model_name
。这个路径指定了在磁盘上写入模型检查点的子目录。
性能表现
seqax在A100集群上的最近基准测试结果令人印象深刻:
单主机A100x8:
模型大小 | MFU (模型FLOPS利用率) |
---|---|
84m | 14 |
270m | 24 |
540m | 35 |
1b | 41.6 |
2b | 50.66 |
4个A100x8主机(使用InfiniBand连接):
模型大小 | MFU |
---|---|
1b | 32.4 |
2b | 39.0 |
这些数据显示,seqax能够在各种模型规模上实现良好的性能,特别是在较大模型上表现出色。
数据加载器
seqax支持两种主要的数据加载方式:
-
直接从Huggingface流式传输训练数据(参见示例配置)。
-
先将训练数据转换为预分词的磁盘格式,称为flat-tokens(参见示例配置).
从Huggingface流式传输允许快速试验不同的数据集,但它不支持在作业中断后从检查点高效恢复训练,并且在批处理边界会浪费一些数据集中的标记。相比之下,flat-tokens格式支持从检查点高效恢复训练,使用100%的标记进行训练,并且在训练期间消耗更少的CPU时间。
要预先对训练数据进行分词,可以运行huggingface_to_flat_tokens.py脚本。在现代CPU上,此脚本每分钟可处理约1亿个标记。
shardlib: 表达分区和通信的新方法
seqax引入了一个名为shardlib的新库,用于在JAX中表达分区和通信。这个库借鉴了jaxtyping、einops、equinox和shard_map的思想和风格。
shardlib的核心思想是通过类型注解和特殊的语法来表达张量的分片和通信操作。例如,为了实现完全分片的数据并行(FSDP)处理一个简单的全连接神经网络,可以这样写:
@pytree_dataclass
class Weights:
w1: f32['in hidden1/d']
w2: f32['hidden1 hidden2/d']
w3: f32['hidden2/d']
@typed_shard_map
def forward_pass(x: f32[b'batch/d in'], w: Weights) -> f32[b'batch/d']:
w1 = shardops.all_gather('in hidden1/d -> in hidden1', w.w1)
y = jax.nn.relu(shardops.einsum_unreduced('batch/d in, in hidden1 -> batch/d hidden1', x, w1))
w2 = shardops.all_gather('hidden1 hidden2/d -> hidden1 hidden2', w.w2)
z = jax.nn.relu(shardops.einsum_unreduced('batch/d hidden1, hidden1 hidden2 -> batch/d hidden2', y, w2))
w3 = shardops.all_gather('hidden2/d -> hidden2', w.w3)
return shardops.einsum_unreduced('batch/d hidden2, hidden2 -> batch/d', z, w3)
这种方法使得分片和通信操作变得清晰可见,同时保持了代码的简洁性和可读性。
使用save_for_backward
表达激活检查点
seqax提供了一种简单的方法来控制前向传播中哪些中间计算结果应该保存到HBM以供后向传播使用。这是通过save_for_backward
函数实现的:
@explicit_activation_checkpointing
def forward_pass(x, w1, w2):
y = save_for_backward(x @ w1)
z = jax.nn.relu(z)
return z @ w2
使用@explicit_activation_checkpointing
装饰器可以改变JAX的默认策略,只保存被注解函数的参数,以及任何使用save_for_backward
标记的中间结果。这种方法给予了开发者更多的控制权,可以精确地指定哪些计算结果需要保存。
性能分析
seqax在每次训练运行中都会收集并报告性能信息:
- 两个训练步骤的时间(包括中间的数据获取)。这会输出到标准输出。
- 这些步骤的模型FLOPS利用率(MFU)效率。同样输出到标准输出。
- XLA性能分析。保存在模型目录的
<model_dir>/plugins/profile/<date>/perfetto_trace.json.gz
中。 - 优化后的XLA计算图的SVG渲染。保存在
<model_dir>/training_step_optimized_hlo_<date>.svg
中。
这些详细的性能信息使得研究人员能够深入了解模型的运行情况,并进行必要的优化。
文件格式
seqax使用基于zarr的简单文件格式来存储检查点和数据集。具体规范可以查看:
这些格式设计简单而高效,便于研究人员理解和使用。
结语
seqax为LLM预训练研究提供了一个强大、灵活且高效的工具。它的设计理念强调透明性和可控性,使研究人员能够深入理解和定制训练过程的每个方面。通过结合JAX的高性能计算能力和创新的分片、通信表达方式,seqax在保持代码简洁性的同时,实现了出色的性能表现。
对于那些希望在LLM预训练领域进行深入研究的人来说,seqax无疑是一个值得关注的项目。它不仅提供了必要的工具和功能,还通过其透明的设计哲学,鼓励研究人员深入理解和改进LLM训练的每个环节。
seqax的开发得到了多方面的支持和启发,包括来自JAX团队的持续支持和建议,以及Google TPU Research Cloud的部分支持。这种协作精神和开放态度,为推动LLM研究的进步做出了重要贡献。
随着AI和深度学习技术的不断发展,像seqax这样的工具将在推动研究边界方面发挥越来越重要的作用。我们期待看到更多研究人员利用seqax进行创新实验,为LLM领域带来新的突破和见解。
🚀 如果你对LLM预训练研究感兴趣,不妨尝试使用seqax,探索其强大功能,为你的研究注入新的活力! 💡