JAX和序列建模的完美结合

Ray

seqax

seqax简介

seqax是一个专为小到中等规模的大型语言模型(LLM)预训练研究而设计的代码库。它巧妙地将序列建模与JAX框架相结合,为研究人员提供了一个强大而灵活的工具。seqax的核心优势在于其简洁性和高效性 - 整个训练程序,包括模型实现、优化器、多主机FSDP和张量并行分区,仅用500行代码就能实现。这种精简的设计不仅使代码易于理解和修改,还能在约100个GPU或TPU上实现良好的扩展性能。

seqax logo

seqax的核心特性

seqax的设计理念是将重要信息置于显著位置,而不是隐藏在抽象和间接引用之后,或通过自动和不可预测的方式推断。这种设计理念体现在以下几个方面:

  1. 数学透明性: seqax直接实现了训练步骤中的所有数学计算,而不是调用外部库。这意味着如果你想理解或修改数学逻辑,所有内容都清晰可见。

  2. 内存管理: 所有进入模型检查点的张量都是显式的。占用大量内存的张量,包括为反向传播保存的激活值,也都是显式的。你可以直接从源代码中读取内存占用情况。

  3. 分区和通信: 所有张量和操作的分区布局都是显式的。所有芯片间通信也是显式的。

这种设计方法使得seqax成为一个透明、可控且易于理解的工具,特别适合那些希望深入了解和定制LLM训练过程的研究人员。

快速入门指南

要开始使用seqax,首先需要进行安装和环境配置:

  1. 从系统包管理器安装graphviz,例如使用brew install graphvizapt install graphviz

  2. 安装Python依赖,通常在虚拟环境中执行:python -m pip install -r requirements-cpu.txt

注意:对于GPU或TPU安装,可能需要不同的JAX和jaxlib安装方式。请参考JAX安装文档以获取详细信息。

CPU本地开发

对于开发和测试,可以在CPU上运行seqax。通常会使用合成数据集或Huggingface数据加载器,并设置XLA标志以模拟多设备环境:

XLA_FLAGS=--xla_force_host_platform_device_count=8 python -m train --config-name=local_test_synthetic +paths.model_name=synthetic_000

paths.model_name标志指定了在磁盘上写入模型检查点的子目录(在/tmp内)。每次开始新的模型运行时,通常需要更改此设置。

GPU运行

seqax提供了一系列预配置的模型大小,可以在C4数据集上使用Llama分词器进行训练。你可以在configs/目录中浏览并选择合适的配置文件。每个配置文件的顶部都列出了运行说明。

建议为每个不同的训练运行设置唯一的paths.model_name。这个路径指定了在磁盘上写入模型检查点的子目录。

性能表现

seqax在A100集群上的最近基准测试结果令人印象深刻:

单主机A100x8:

模型大小MFU (模型FLOPS利用率)
84m14
270m24
540m35
1b41.6
2b50.66

4个A100x8主机(使用InfiniBand连接):

模型大小MFU
1b32.4
2b39.0

这些数据显示,seqax能够在各种模型规模上实现良好的性能,特别是在较大模型上表现出色。

数据加载器

seqax支持两种主要的数据加载方式:

  1. 直接从Huggingface流式传输训练数据(参见示例配置)。

  2. 先将训练数据转换为预分词的磁盘格式,称为flat-tokens(参见示例配置).

从Huggingface流式传输允许快速试验不同的数据集,但它不支持在作业中断后从检查点高效恢复训练,并且在批处理边界会浪费一些数据集中的标记。相比之下,flat-tokens格式支持从检查点高效恢复训练,使用100%的标记进行训练,并且在训练期间消耗更少的CPU时间。

要预先对训练数据进行分词,可以运行huggingface_to_flat_tokens.py脚本。在现代CPU上,此脚本每分钟可处理约1亿个标记。

shardlib: 表达分区和通信的新方法

seqax引入了一个名为shardlib的新库,用于在JAX中表达分区和通信。这个库借鉴了jaxtypingeinopsequinoxshard_map的思想和风格。

shardlib的核心思想是通过类型注解和特殊的语法来表达张量的分片和通信操作。例如,为了实现完全分片的数据并行(FSDP)处理一个简单的全连接神经网络,可以这样写:

@pytree_dataclass
class Weights:
  w1: f32['in hidden1/d']
  w2: f32['hidden1 hidden2/d']
  w3: f32['hidden2/d']

@typed_shard_map
def forward_pass(x: f32[b'batch/d in'], w: Weights) -> f32[b'batch/d']:
  w1 = shardops.all_gather('in hidden1/d -> in hidden1', w.w1)
  y = jax.nn.relu(shardops.einsum_unreduced('batch/d in, in hidden1 -> batch/d hidden1', x, w1))
  w2 = shardops.all_gather('hidden1 hidden2/d -> hidden1 hidden2', w.w2)
  z = jax.nn.relu(shardops.einsum_unreduced('batch/d hidden1, hidden1 hidden2 -> batch/d hidden2', y, w2))
  w3 = shardops.all_gather('hidden2/d -> hidden2', w.w3)
  return shardops.einsum_unreduced('batch/d hidden2, hidden2 -> batch/d', z, w3)

这种方法使得分片和通信操作变得清晰可见,同时保持了代码的简洁性和可读性。

使用save_for_backward表达激活检查点

seqax提供了一种简单的方法来控制前向传播中哪些中间计算结果应该保存到HBM以供后向传播使用。这是通过save_for_backward函数实现的:

@explicit_activation_checkpointing
def forward_pass(x, w1, w2):
  y = save_for_backward(x @ w1)
  z = jax.nn.relu(z)
  return z @ w2

使用@explicit_activation_checkpointing装饰器可以改变JAX的默认策略,只保存被注解函数的参数,以及任何使用save_for_backward标记的中间结果。这种方法给予了开发者更多的控制权,可以精确地指定哪些计算结果需要保存。

性能分析

seqax在每次训练运行中都会收集并报告性能信息:

  • 两个训练步骤的时间(包括中间的数据获取)。这会输出到标准输出。
  • 这些步骤的模型FLOPS利用率(MFU)效率。同样输出到标准输出。
  • XLA性能分析。保存在模型目录的<model_dir>/plugins/profile/<date>/perfetto_trace.json.gz中。
  • 优化后的XLA计算图的SVG渲染。保存在<model_dir>/training_step_optimized_hlo_<date>.svg中。

这些详细的性能信息使得研究人员能够深入了解模型的运行情况,并进行必要的优化。

文件格式

seqax使用基于zarr的简单文件格式来存储检查点和数据集。具体规范可以查看:

这些格式设计简单而高效,便于研究人员理解和使用。

结语

seqax为LLM预训练研究提供了一个强大、灵活且高效的工具。它的设计理念强调透明性和可控性,使研究人员能够深入理解和定制训练过程的每个方面。通过结合JAX的高性能计算能力和创新的分片、通信表达方式,seqax在保持代码简洁性的同时,实现了出色的性能表现。

对于那些希望在LLM预训练领域进行深入研究的人来说,seqax无疑是一个值得关注的项目。它不仅提供了必要的工具和功能,还通过其透明的设计哲学,鼓励研究人员深入理解和改进LLM训练的每个环节。

seqax的开发得到了多方面的支持和启发,包括来自JAX团队的持续支持和建议,以及Google TPU Research Cloud的部分支持。这种协作精神和开放态度,为推动LLM研究的进步做出了重要贡献。

随着AI和深度学习技术的不断发展,像seqax这样的工具将在推动研究边界方面发挥越来越重要的作用。我们期待看到更多研究人员利用seqax进行创新实验,为LLM领域带来新的突破和见解。

🚀 如果你对LLM预训练研究感兴趣,不妨尝试使用seqax,探索其强大功能,为你的研究注入新的活力! 💡

avatar
0
0
0
相关项目
Project Cover

dopamine

Dopamine是一个用于快速原型设计强化学习算法的研究框架,旨在便于用户进行自由实验。其设计原则包括易于实验、灵活开发、紧凑可靠和结果可重复。支持的算法有DQN、C51、Rainbow、IQN和SAC,主要实现于jax。Dopamine提供了Docker容器及源码安装方法,适用于Atari和Mujoco环境,并推荐使用虚拟环境。更多信息请参阅官方文档。

Project Cover

EasyDeL

EasyDeL是一个开源框架,用于通过Jax/Flax优化机器学习模型的训练,特别适合在TPU/GPU上进行大规模部署。它支持多种模型架构和量化方法,包括Transformers、Mamba等,并提供高级训练器和API引擎。EasyDeL的架构完全可定制和透明,允许用户修改每个组件,并促进实验和社区驱动的开发。不论是前沿研究还是生产系统构建,EasyDeL都提供灵活强大的工具以满足不同需求。最新更新包括性能优化、KV缓存改进和新模型支持。

Project Cover

keras-nlp

KerasNLP 是一个兼容 TensorFlow、JAX 和 PyTorch 的自然语言处理库,提供预训练模型和低级模块。基于 Keras 3,支持 GPU 和 TPU 的微调,并可跨框架训练和序列化。设置 KERAS_BACKEND 环境变量即可切换框架,安装方便,立即体验强大 NLP 功能。

Project Cover

EasyLM

EasyLM提供了一站式解决方案,用于在JAX/Flax中预训练、微调、评估和部署大规模语言模型。通过JAX的pjit功能,可以扩展到数百个TPU/GPU加速器。基于Hugginface的transformers和datasets,EasyLM代码库易于使用和定制。支持Google Cloud TPU Pods上的多TPU/GPU和多主机训练,兼容LLaMA系列模型。推荐加入非官方的Discord社区,了解更多关于Koala聊天机器人和OpenLLaMA的详细信息及安装指南。

Project Cover

dm_pix

PIX是一个基于JAX的开源图像处理库,具备优化和并行化能力。支持通过jax.jit、jax.vmap和jax.pmap进行加速与并行处理,适用于高性能计算需求。安装便捷,只需通过pip安装后即可使用。提供丰富的示例代码,易于上手操作,同时配备完整的测试套件,确保开发环境的可靠性,并接受社区贡献。

Project Cover

penzai

Penzai是一个基于JAX的库,专为通过函数式pytree数据结构编写模型而设计,并提供丰富的工具用于可视化、修改和分析。适用于反向工程、模型组件剥离、内部激活检查、模型手术和调试等领域。Penzai包括Treescope交互式Python打印工具、JAX树和数组操作工具、声明式神经网络库及常见Transformer架构的模块化实现。该库简化了模型处理过程,为研究神经网络的内部机制与训练动态提供了支持。

Project Cover

GradCache

Gradient Cache技术突破了GPU/TPU内存限制,可以无限扩展对比学习的批处理大小。仅需一个GPU即可完成原本需要8个V100 GPU的训练,并能够用更具成本效益的高FLOP低内存系统替换大内存GPU/TPU。该项目支持Pytorch和JAX框架,并已整合至密集段落检索工具DPR。

Project Cover

dm-haiku

Haiku是一个为JAX设计的简洁神经网络库,具备面向对象编程模型和纯函数转换功能。由Sonnet的开发者创建,Haiku能简化模型参数和状态管理,并与其他JAX库无缝集成。虽然Google DeepMind建议新项目使用Flax,Haiku仍将在维护模式下持续支持,专注于修复bug和兼容性更新。

Project Cover

keras

Keras 3 提供高效的模型开发,支持计算机视觉、自然语言处理等任务。选择最快的后端(如JAX),性能提升高达350%。无缝扩展,从本地到大规模集群,适合企业和初创团队。安装简单,支持GPU,兼容tf.keras代码,避免框架锁定。

最新项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

稿定AI

稿定设计 是一个多功能的在线设计和创意平台,提供广泛的设计工具和资源,以满足不同用户的需求。从专业的图形设计师到普通用户,无论是进行图片处理、智能抠图、H5页面制作还是视频剪辑,稿定设计都能提供简单、高效的解决方案。该平台以其用户友好的界面和强大的功能集合,帮助用户轻松实现创意设计。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号