本仓库旨在提供一系列基于Triton的高效实现,用于最先进的线性注意力模型。欢迎任何拉取请求!
模型
日期 | 模型 | 标题 | 论文 | 代码 | FLA实现 |
---|---|---|---|---|---|
2023-07 | RetNet (@MSRA@THU) | 保留网络:大型语言模型的Transformer继任者 | [arxiv] | [官方] [RetNet] | 代码 |
2023-12 | GLA (@MIT@IBM) | 具有硬件高效训练的门控线性注意力Transformer | [arxiv] | [官方] | 代码 |
2023-12 | Based (@Stanford@Hazyresearch) | 一个教育性且有效的序列混合器 | [博客] | [官方] | 代码 |
2024-01 | Rebased | 具有可学习核函数的线性Transformer是更好的上下文模型 | [arxiv] | [官方] | 代码 |
2021-02 | Delta Net | 线性Transformer实际上是快速权重编程器 | [arxiv] | [官方] | 代码 |
2023-09 | Hedgehog (@HazyResearch) | 刺猬和豪猪:具有Softmax模仿的表达性线性注意力 | openreview | 代码 | |
2023-10 | PolySketchFormer (@CMU@Google) | 通过多项式核草图实现快速Transformer | arxiv | 待完成 | |
2023-07 | TransnormerLLM | 一种更快更好的大型语言模型,采用改进的TransNormer(@上海人工智能实验室) | openreview arxiv | [官方] [Lightning2] | 待完成 |
2023-05 | RWKV-v4 (@BlinkDL) | 为Transformer时代重新发明RNN | arxiv | [官方] | 待完成 |
2023-10 | GateLoop | 用于序列建模的完全数据控制的线性递归 | openreview arxiv | [官方] [jax] | 待完成 |
2021-10 | ABC (@UW) | 具有有界内存控制的注意力 | arxiv | 代码 | |
2023-09 | VQ-transformer | 通过向量量化实现线性时间Transformer | arxiv | [官方] | 待完成 |
2023-09 | HGRN | 用于序列建模的分层门控递归神经网络 | openreview | [官方] | 代码 |
2024-04 | HGRN2 | HGRN2:具有状态扩展的门控线性RNN | arxiv | [官方] | 代码 |
2024-04 | RWKV6 | 鹰和雀鹀:具有矩阵值状态和动态递归的RWKV | arxiv | [官方] | 代码 |
2024-06 | Samba | Samba:用于高效无限上下文语言建模的简单混合状态空间模型 | arxiv | [官方] | 代码 |
2024-05 | Mamba2 | Transformer是SSM:通过结构化状态空间对偶性实现广义模型和高效算法 | arxiv | [官方] | 代码 |
安装
需满足以下要求:
- PyTorch >= 2.0
- Triton >=2.2
- einops
由于
fla
目前正在积极开发中,暂时没有提供已发布的软件包。 如果您确实需要使用fla
的操作/模块并考虑进一步探索,可以通过以下方式从源代码安装软件包
pip install -U git+https://github.com/sustcsonglin/flash-linear-attention
或者使用子模块管理fla
git submodule add https://github.com/sustcsonglin/flash-linear-attention.git 3rdparty/flash-linear-attention
ln -s 3rdparty/flash-linear-attention/fla fla
[!注意] 如果您没有使用Triton v2.2或其每夜版本,请注意
FusedChunk
实现可能存在潜在问题,详见此问题。 您可以运行测试python tests/test_fused_chunk.py
来检查您的版本是否受到类似编译器问题的影响。 虽然我们为Triton<=2.1提供了一些修复方案,但请注意这些可能会导致性能下降。对于Triton 2.2和更早版本(最高2.1),您可以可靠地使用
Chunk
版本(隐藏状态具体化到HBM中)。 经过仔细优化,这个版本在大多数情况下通常能提供高性能。
使用方法
令牌混合
我们在fla.layers
中提供了"令牌混合"线性注意力层供您使用。
您可以用其他线性注意力层替换模型中的标准多头注意力层。
使用示例如下:
>>> import torch
>>> from fla.layers import MultiScaleRetention
>>> batch_size, num_heads, seq_len, hidden_size, = 32, 4, 2048, 1024
>>> device, dtype = 'cuda:0', torch.bfloat16
>>> retnet = MultiScaleRetention(hidden_size=hidden_size, num_heads=num_heads).to(device=device, dtype=dtype)
>>> x = torch.randn(batch_size, seq_len, hidden_size).to(device=device, dtype=dtype)
>>> y, *_ = retnet(x)
>>> y.shape
torch.Size([32, 2048, 1024])
我们提供了与🤗 Transformers库兼容的模型实现。
以下是如何从fla
中的默认配置初始化GLA模型的示例:
>>> from fla.models import GLAConfig
>>> from transformers import AutoModel
>>> config = GLAConfig()
>>> config
GLAConfig {
"attn_mode": "fused_chunk",
"bos_token_id": 1,
"clamp_min": null,
"conv_size": 4,
"eos_token_id": 2,
"expand_k": 0.5,
"expand_v": 1,
"fuse_cross_entropy": true,
"fuse_norm": true,
"hidden_act": "swish",
"hidden_ratio": 4,
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": null,
"max_position_embeddings": 2048,
"model_type": "gla",
"num_heads": 4,
"num_hidden_layers": 24,
"rms_norm_eps": 1e-06,
"share_conv_kernel": true,
"tie_word_embeddings": false,
"transformers_version": "4.39.1",
"use_cache": true,
"use_gk": true,
"use_gv": false,
"use_short_conv": false,
"vocab_size": 32000
}
>>> AutoModel.from_config(config)
GLAModel(
(embed_tokens): Embedding(32000, 2048)
(layers): ModuleList(
(0-23): 24 x GLABlock(
(attn_norm): RMSNorm()
(attn): GatedLinearAttention(
(gate_fn): SiLU()
(q_proj): Linear(in_features=2048, out_features=1024, bias=False)
(k_proj): Linear(in_features=2048, out_features=1024, bias=False)
(v_proj): Linear(in_features=2048, out_features=2048, bias=False)
(g_proj): Linear(in_features=2048, out_features=2048, bias=False)
(gk_proj): Sequential(
(0): Linear(in_features=2048, out_features=16, bias=False)
(1): Linear(in_features=16, out_features=1024, bias=True)
)
(o_proj): Linear(in_features=2048, out_features=2048, bias=False)
(g_norm_swish_gate): FusedRMSNormSwishGate()
)
(mlp_norm): RMSNorm()
(mlp): GLAMLP(
(gate_proj): Linear(in_features=2048, out_features=11264, bias=False)
(down_proj): Linear(in_features=5632, out_features=2048, bias=False)
(act_fn): SiLU()
)
)
)
(norm): RMSNorm()
)
生成
成功预训练模型后,就可以使用🤗文本生成API来生成文本。 以下是一个生成示例:
>>> import fla
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> name = 'fla-hub/gla-1.3B-100B'
>>> tokenizer = AutoTokenizer.from_pretrained(name)
>>> model = AutoModelForCausalLM.from_pretrained(name).cuda()
>>> input_prompt = "Power goes with permanence. Impermanence is impotence. And rotation is castration."
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.cuda()
>>> outputs = model.generate(input_ids, max_length=64)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
我们还提供了一个简单的脚本这里用于基准测试生成速度。 只需运行:
$ python -m benchmarks.benchmark_generation \
--path 'fla-hub/gla-1.3B-100B' \
--repetition_penalty 2. \
--prompt="Hello everyone, I'm Songlin Yang"
提示:
Hello everyone, I'm Songlin Yang
生成:
Hello everyone, I'm Songlin Yang.
I am a 20 year old girl from China who is currently studying in the United States of America for my Master degree and also working as an English teacher at school here on campus since last summer (1st semester). My main goal to be able do well with this course so that we can have
提示长度:10,生成长度:64
总提示处理 + 解码时间:4593ms
所有当前可用的预训练模型都可以在fla-hub
中找到。
>>> from huggingface_hub import list_models
>>> for model in list_models(author='fla-hub'): print(model.id)
评估
lm-evaluation-harness库允许您轻松执行(零样本)模型评估。 按照以下步骤使用此库:
-
按照他们的说明安装
lm_eval
。 -
运行评估:
$ PATH='fla-hub/gla-1.3B-100B'
$ python -m evals.harness --model hf \
--model_args pretrained=$PATH,dtype=bfloat16 \
--tasks wikitext,lambada_openai,piqa,hellaswag,winogrande,arc_easy,arc_challenge,boolq,sciq,copa,openbookqa \
--batch_size 64 \
--num_fewshot 0 \
--device cuda \
--show_config
我们已经使fla
与hf风格的评估兼容,您可以调用evals.harness来完成评估。
运行上述命令将提供GLA论文中报告的任务结果。
[!提示] 如果您将
lm-evaluation-harness
作为外部库使用,却发现(几乎)没有可用的任务,在调用lm_eval.evaluate()
或lm_eval.simple_evaluate()
之前,只需运行以下命令来加载库的默认任务!
>>> from lm_eval.tasks import TaskManager; TaskManager().initialize_tasks()
基准测试
我们将基于Triton的RetNet实现与基于CUDA的FlashAttention2进行了比较,使用批量大小为8、32个头部和128的头部维度,在不同的序列长度下进行测试。这些测试在单个A100 80GB GPU上进行,如下图所示:
# 你可能需要先通过 `pip install -e .` 安装 `fla` 以启用其导入
$ python benchmark_retention.py
性能:
seq_len fused_chunk_fwd chunk_fwd parallel_fwd fused_chunk_fwdbwd chunk_fwdbwd parallel_fwdbwd flash_fwd flash_fwdbwd
0 128.0 0.093184 0.185344 0.067584 1.009664 1.591296 1.044480 0.041984 0.282624
1 256.0 0.165888 0.219136 0.126976 1.024000 1.596928 1.073152 0.074752 0.413696
2 512.0 0.308224 0.397312 0.265216 1.550336 1.603584 1.301504 0.156672 0.883712
3 1024.0 0.603136 0.747520 0.706560 3.044864 3.089408 3.529728 0.467968 2.342912
4 2048.0 1.191424 1.403904 2.141184 6.010880 6.059008 11.009024 1.612800 7.135232
5 4096.0 2.377728 2.755072 7.392256 11.932672 11.938816 37.792770 5.997568 24.435200
6 8192.0 4.750336 5.491712 26.402817 23.759359 23.952385 141.014023 22.682114 90.619904
7 16384.0 9.591296 10.870784 101.262337 47.666176 48.745472 539.853821 91.346947 346.318848
线性注意力的不同形式
关于线性注意力不同形式的硬件考虑,请参考GLA论文的第2.3节。
并行
:自注意力风格的计算,时间复杂度为O(L^2),具有序列并行性。融合递归
:递归计算,时间复杂度为O(L)。隐藏状态在共享内存中即时计算,无需物化到全局内存(详见此论文的算法1)。这节省了大量I/O成本,应该是速度比较的强基准。融合分块
:分块计算,时间复杂度为O(LC),其中C是块大小。隐藏状态同样即时计算,不物化到全局内存。这个版本通常比融合递归更好,因为可以使用张量核心进行序列级"归约",而融合递归完全无法使用张量核心。注意,此实现中没有序列级并行性,因此不适合非常小的批量大小设置。应比并行分块更节省内存。并行分块
:具有序列并行性的分块计算。需要为每个块将隐藏状态物化到全局内存。需要适当设置C以获得良好性能,因为当C小时,需要加载/存储到全局内存的隐藏状态太多;当C太大时,浮点运算量高。推荐的C值为[64, 128, 256]。
引用
如果您觉得这个仓库有用,请考虑引用我们的工作:
@article{yang2024delta,
title = {Parallelizing Linear Transformers with the Delta Rule over Sequence Length},
author = {Songlin Yang and Bailin Wang and Yu Zhang and Yikang Shen and Yoon Kim},
journal = {arXiv preprint arXiv:2406.06484},
year = {2024},
}
@article{yang2023gated,
title = {Gated Linear Attention Transformers with Hardware-Efficient Training},
author = {Yang, Songlin and Wang, Bailin and Shen, Yikang and Panda, Rameswar and Kim, Yoon},
journal = {arXiv preprint arXiv:2312.06635},
year = {2023}
}
@software{yang2024fla,
title = {FLA: A Triton-Based Library for Hardware-Efficient Implementations of Linear Attention Mechanism},
author = {Yang, Songlin and Zhang, Yu},
url = {https://github.com/sustcsonglin/flash-linear-attention},
month = jan,
year = {2024}
}