Project Icon

ringattention

创新注意力机制大幅提升Transformer上下文处理能力

ringattention项目实现Ring Attention和Blockwise Transformers技术,显著提升Transformer模型上下文处理能力。通过跨设备分布式计算和通信重叠,模型可处理长达数千万个token的序列,无需增加开销。该技术支持causal block和cache index,为大规模语言模型训练提供高效解决方案,特别适用于超长上下文处理场景。

GPU/TPU Jax实现的环形注意力机制

本代码库提供了带有块状变换器的环形注意力机制的实现。该模型在论文《使用块状变换器的环形注意力机制实现近乎无限的上下文》和《用于大上下文模型的块状并行变换器》中有所描述。

带有块状并行变换器的环形注意力机制使得训练序列长度可以达到"设备数量"倍于BPT可能实现的长度。这是通过在多个设备上分配注意力和前馈计算,并将通信与计算重叠来实现的。由于注意力和前馈网络的块状计算,可以在不增加任何通信或计算开销的情况下,使用数千万个标记作为上下文大小进行训练。

示例用法及代码片段

首先,安装软件包:

pip install ringattention

然后,可以按如下方式导入和使用ringattentionblockwise_feedforward

from ringattention import ringattention, blockwise_feedforward

您可以通过将ringattention函数包装在shard_map中来跨多个设备分片计算。以下是如何使用带分片的ringattention函数的示例:

ring_attention_sharded = shard_map(
    partial(
        ringattention,
        axis_name="sp",
        float32_logits=True,
        cache_idx=None,
        blockwise_kwargs=dict(
            causal_block_size=1,
            deterministic=True,
            dropout_rng=None,
            attn_pdrop=0.0,
            query_chunk_size=512,
            key_chunk_size=512,
            policy=jax.checkpoint_policies.nothing_saveable,
            dtype=jax.numpy.float32,
            precision=None,
            prevent_cse=True,
        )
    ),
    mesh=LLaMAConfig.get_jax_mesh(self.config.mesh_dim),
    in_specs=(
        PS(("dp", "fsdp"), "sp", "tp", None),
        PS(("dp", "fsdp"), "sp", "tp", None),
        PS(("dp", "fsdp"), "sp", "tp", None),
        PS(("dp", "fsdp"), None, None, None),
        PS(("dp", "fsdp"), None),
    ),
    out_specs=PS(("dp", "fsdp"), "sp", "tp", None),
    check_rep=False
)
attn_output = ring_attention_sharded(xq, xk, xv, attention_bias, segment_ids)

参数说明:

  • query_chunk_sizekey_chunk_size是查询和键的块大小。选择尽可能大的值以加速计算,直到内存不足为止。

  • policy是注意力权重的检查点策略,使用jax.checkpoint_policies.nothing_saveable来启用检查点。

  • causal_block_size是块因果注意力的块大小。causal_block_size=1等同于因果注意力。

  • cache_idx是推理时的缓存索引。如果cache_idx不为None,注意力权重将被缓存并在下一次推理中重复使用。

环形注意力机制在大世界模型(LWM)中用于百万长度的视觉语言训练,在那里可以找到使用环形注意力和块状变换器的完整示例:LWM代码库

参考文献

如果您觉得我们的工作与您的研究相关,请引用:

@article{liu2023blockwise,
    title={Blockwise Parallel Transformer for Large Context Models},
    author={Liu, Hao and Abbeel, Pieter},
    journal={Advances in neural information processing systems},
    year={2023}
}
@article{liu2023ring,
    title={Ring Attention with Blockwise Transformers for Near-Infinite Context},
    author={Liu, Hao and Zaharia, Matei and Abbeel, Pieter},
    journal={arXiv preprint arXiv:2310.01889},
    year={2023}
}
项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号