Project Icon

flash-linear-attention

Triton实现的高效线性注意力模型库

Flash Linear Attention是一个基于Triton实现的线性注意力模型库。该项目集成了RetNet、GLA和Based等多种先进模型,实现了高效的token混合和文本生成。兼容Hugging Face Transformers库,提供预训练模型、评估工具和基准测试,为线性注意力技术的研究和应用提供了便利。

快速线性注意力

hf_model | Discord

本仓库旨在提供一系列基于Triton的高效实现,用于最先进的线性注意力模型。欢迎任何拉取请求!

image

模型

日期模型标题论文代码FLA实现
2023-07RetNet (@MSRA@THU)保留网络:大型语言模型的Transformer继任者[arxiv][官方] [RetNet]代码
2023-12GLA (@MIT@IBM)具有硬件高效训练的门控线性注意力Transformer[arxiv][官方]代码
2023-12Based (@Stanford@Hazyresearch)一个教育性且有效的序列混合器[博客][官方]代码
2024-01Rebased具有可学习核函数的线性Transformer是更好的上下文模型[arxiv][官方]代码
2021-02Delta Net线性Transformer实际上是快速权重编程器[arxiv][官方]代码
2023-09Hedgehog (@HazyResearch)刺猬和豪猪:具有Softmax模仿的表达性线性注意力openreview代码
2023-10PolySketchFormer (@CMU@Google)通过多项式核草图实现快速Transformerarxiv待完成
2023-07TransnormerLLM一种更快更好的大型语言模型,采用改进的TransNormer(@上海人工智能实验室)openreview arxiv[官方] [Lightning2]待完成
2023-05RWKV-v4 (@BlinkDL)为Transformer时代重新发明RNNarxiv[官方]待完成
2023-10GateLoop用于序列建模的完全数据控制的线性递归openreview arxiv[官方] [jax]待完成
2021-10ABC (@UW)具有有界内存控制的注意力arxiv代码
2023-09VQ-transformer通过向量量化实现线性时间Transformerarxiv[官方]待完成
2023-09HGRN用于序列建模的分层门控递归神经网络openreview[官方]代码
2024-04HGRN2HGRN2:具有状态扩展的门控线性RNNarxiv[官方]代码
2024-04RWKV6鹰和雀鹀:具有矩阵值状态和动态递归的RWKVarxiv[官方]代码
2024-06SambaSamba:用于高效无限上下文语言建模的简单混合状态空间模型arxiv[官方]代码
2024-05Mamba2Transformer是SSM:通过结构化状态空间对偶性实现广义模型和高效算法arxiv[官方]代码

安装

需满足以下要求:

  • PyTorch >= 2.0
  • Triton >=2.2
  • einops 由于fla目前正在积极开发中,暂时没有提供已发布的软件包。 如果您确实需要使用fla的操作/模块并考虑进一步探索,可以通过以下方式从源代码安装软件包
pip install -U git+https://github.com/sustcsonglin/flash-linear-attention

或者使用子模块管理fla

git submodule add https://github.com/sustcsonglin/flash-linear-attention.git 3rdparty/flash-linear-attention
ln -s 3rdparty/flash-linear-attention/fla fla

[!注意] 如果您没有使用Triton v2.2或其每夜版本,请注意FusedChunk实现可能存在潜在问题,详见此问题。 您可以运行测试python tests/test_fused_chunk.py来检查您的版本是否受到类似编译器问题的影响。 虽然我们为Triton<=2.1提供了一些修复方案,但请注意这些可能会导致性能下降。

对于Triton 2.2和更早版本(最高2.1),您可以可靠地使用Chunk版本(隐藏状态具体化到HBM中)。 经过仔细优化,这个版本在大多数情况下通常能提供高性能。

使用方法

令牌混合

我们在fla.layers中提供了"令牌混合"线性注意力层供您使用。 您可以用其他线性注意力层替换模型中的标准多头注意力层。 使用示例如下:

>>> import torch
>>> from fla.layers import MultiScaleRetention
>>> batch_size, num_heads, seq_len, hidden_size,  = 32, 4, 2048, 1024
>>> device, dtype = 'cuda:0', torch.bfloat16
>>> retnet = MultiScaleRetention(hidden_size=hidden_size, num_heads=num_heads).to(device=device, dtype=dtype)
>>> x = torch.randn(batch_size, seq_len, hidden_size).to(device=device, dtype=dtype)
>>> y, *_ = retnet(x)
>>> y.shape
torch.Size([32, 2048, 1024])

我们提供了与🤗 Transformers库兼容的模型实现。 以下是如何从fla中的默认配置初始化GLA模型的示例:

>>> from fla.models import GLAConfig
>>> from transformers import AutoModel
>>> config = GLAConfig()
>>> config
GLAConfig {
  "attn_mode": "fused_chunk",
  "bos_token_id": 1,
  "clamp_min": null,
  "conv_size": 4,
  "eos_token_id": 2,
  "expand_k": 0.5,
  "expand_v": 1,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": null,
  "max_position_embeddings": 2048,
  "model_type": "gla",
  "num_heads": 4,
  "num_hidden_layers": 24,
  "rms_norm_eps": 1e-06,
  "share_conv_kernel": true,
  "tie_word_embeddings": false,
  "transformers_version": "4.39.1",
  "use_cache": true,
  "use_gk": true,
  "use_gv": false,
  "use_short_conv": false,
  "vocab_size": 32000
}

>>> AutoModel.from_config(config)
GLAModel(
  (embed_tokens): Embedding(32000, 2048)
  (layers): ModuleList(
    (0-23): 24 x GLABlock(
      (attn_norm): RMSNorm()
      (attn): GatedLinearAttention(
        (gate_fn): SiLU()
        (q_proj): Linear(in_features=2048, out_features=1024, bias=False)
        (k_proj): Linear(in_features=2048, out_features=1024, bias=False)
        (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (g_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (gk_proj): Sequential(
          (0): Linear(in_features=2048, out_features=16, bias=False)
          (1): Linear(in_features=16, out_features=1024, bias=True)
        )
        (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (g_norm_swish_gate): FusedRMSNormSwishGate()
      )
      (mlp_norm): RMSNorm()
      (mlp): GLAMLP(
        (gate_proj): Linear(in_features=2048, out_features=11264, bias=False)
        (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
        (act_fn): SiLU()
      )
    )
  )
  (norm): RMSNorm()
)

生成

成功预训练模型后,就可以使用🤗文本生成API来生成文本。 以下是一个生成示例:

>>> import fla
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> name = 'fla-hub/gla-1.3B-100B'
>>> tokenizer = AutoTokenizer.from_pretrained(name)
>>> model = AutoModelForCausalLM.from_pretrained(name).cuda()
>>> input_prompt = "Power goes with permanence. Impermanence is impotence. And rotation is castration."
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.cuda()
>>> outputs = model.generate(input_ids, max_length=64)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

我们还提供了一个简单的脚本这里用于基准测试生成速度。 只需运行:

$ python -m benchmarks.benchmark_generation \
  --path 'fla-hub/gla-1.3B-100B' \
  --repetition_penalty 2. \
  --prompt="Hello everyone, I'm Songlin Yang"

提示:
Hello everyone, I'm Songlin Yang
生成:
Hello everyone, I'm Songlin Yang.
I am a 20 year old girl from China who is currently studying in the United States of America for my Master degree and also working as an English teacher at school here on campus since last summer (1st semester). My main goal to be able do well with this course so that we can have

提示长度:10,生成长度:64
总提示处理 + 解码时间:4593ms

所有当前可用的预训练模型都可以在fla-hub中找到。

>>> from huggingface_hub import list_models
>>> for model in list_models(author='fla-hub'): print(model.id)

评估

lm-evaluation-harness库允许您轻松执行(零样本)模型评估。 按照以下步骤使用此库:

  1. 按照他们的说明安装lm_eval

  2. 运行评估:

$ PATH='fla-hub/gla-1.3B-100B'
$ python -m evals.harness --model hf \
    --model_args pretrained=$PATH,dtype=bfloat16 \
    --tasks wikitext,lambada_openai,piqa,hellaswag,winogrande,arc_easy,arc_challenge,boolq,sciq,copa,openbookqa \
    --batch_size 64 \
    --num_fewshot 0 \
    --device cuda \
    --show_config                  

我们已经使fla与hf风格的评估兼容,您可以调用evals.harness来完成评估。 运行上述命令将提供GLA论文中报告的任务结果。

[!提示] 如果您将lm-evaluation-harness作为外部库使用,却发现(几乎)没有可用的任务,在调用lm_eval.evaluate()lm_eval.simple_evaluate()之前,只需运行以下命令来加载库的默认任务!

>>> from lm_eval.tasks import TaskManager; TaskManager().initialize_tasks()

基准测试

我们将基于Triton的RetNet实现与基于CUDA的FlashAttention2进行了比较,使用批量大小为8、32个头部和128的头部维度,在不同的序列长度下进行测试。这些测试在单个A100 80GB GPU上进行,如下图所示:

# 你可能需要先通过 `pip install -e .` 安装 `fla` 以启用其导入
$ python benchmark_retention.py
性能:
   seq_len  fused_chunk_fwd  chunk_fwd  parallel_fwd  fused_chunk_fwdbwd  chunk_fwdbwd  parallel_fwdbwd  flash_fwd  flash_fwdbwd
0    128.0         0.093184   0.185344      0.067584            1.009664      1.591296         1.044480   0.041984      0.282624
1    256.0         0.165888   0.219136      0.126976            1.024000      1.596928         1.073152   0.074752      0.413696
2    512.0         0.308224   0.397312      0.265216            1.550336      1.603584         1.301504   0.156672      0.883712
3   1024.0         0.603136   0.747520      0.706560            3.044864      3.089408         3.529728   0.467968      2.342912
4   2048.0         1.191424   1.403904      2.141184            6.010880      6.059008        11.009024   1.612800      7.135232
5   4096.0         2.377728   2.755072      7.392256           11.932672     11.938816        37.792770   5.997568     24.435200
6   8192.0         4.750336   5.491712     26.402817           23.759359     23.952385       141.014023  22.682114     90.619904
7  16384.0         9.591296  10.870784    101.262337           47.666176     48.745472       539.853821  91.346947    346.318848

性能

线性注意力的不同形式

关于线性注意力不同形式的硬件考虑,请参考GLA论文的第2.3节。

  • 并行:自注意力风格的计算,时间复杂度为O(L^2),具有序列并行性。
  • 融合递归:递归计算,时间复杂度为O(L)。隐藏状态在共享内存中即时计算,无需物化到全局内存(详见此论文的算法1)。这节省了大量I/O成本,应该是速度比较的强基准。
  • 融合分块:分块计算,时间复杂度为O(LC),其中C是块大小。隐藏状态同样即时计算,不物化到全局内存。这个版本通常比融合递归更好,因为可以使用张量核心进行序列级"归约",而融合递归完全无法使用张量核心。注意,此实现中没有序列级并行性,因此不适合非常小的批量大小设置。应比并行分块更节省内存。
  • 并行分块:具有序列并行性的分块计算。需要为每个块将隐藏状态物化到全局内存。需要适当设置C以获得良好性能,因为当C小时,需要加载/存储到全局内存的隐藏状态太多;当C太大时,浮点运算量高。推荐的C值为[64, 128, 256]。

引用

如果您觉得这个仓库有用,请考虑引用我们的工作:

@article{yang2024delta,
  title   = {Parallelizing Linear Transformers with the Delta Rule over Sequence Length}, 
  author  = {Songlin Yang and Bailin Wang and Yu Zhang and Yikang Shen and Yoon Kim},
  journal = {arXiv preprint arXiv:2406.06484},
  year    = {2024},
}

@article{yang2023gated,
  title   = {Gated Linear Attention Transformers with Hardware-Efficient Training},
  author  = {Yang, Songlin and Wang, Bailin and Shen, Yikang and Panda, Rameswar and Kim, Yoon},
  journal = {arXiv preprint arXiv:2312.06635},
  year    = {2023}
}

@software{yang2024fla,
  title  = {FLA: A Triton-Based Library for Hardware-Efficient Implementations of Linear Attention Mechanism},
  author = {Yang, Songlin and Zhang, Yu},
  url    = {https://github.com/sustcsonglin/flash-linear-attention},
  month  = jan,
  year   = {2024}
}
项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号