attention_sinks 项目介绍
项目概况
attention_sinks
是一个开源项目,致力于通过修改预训练的大型语言模型(LLM)的注意力机制,使它们能够无缝地产生流畅的文本。其核心思路是在传统滑动窗口注意力机制的基础上进行改进,保持内存使用效率的同时成功生成流畅的文本。
基准测试结果
困惑度
项目中对比了多种方法下模型困惑度的变化,困惑度越高意味着模型生成正常语言的能力下降。下图展示了不同方法下的困惑度表现:
- transformers:显存使用线性增长,在预训练长度之后性能严重下降。
- windowed(窗口化注意力):由于在1024个token时实行窗口化,显存使用固定,但当第一个token离开窗口时性能立即下降。
- attention_sinks:由于采用窗口化并使用4个注意力汇聚token和最近的1020个token,因此显存使用固定,性能不受影响。
持续生成时的流畅性
在无限生成的测试中,使用不同方法的Llama-2-7B模型表现如下:
- transformers:在生成约1900个token后流畅性丧失,开始生成损坏的Unicode字符。
- windowed:在生成约1000个token后流畅性丧失,生成大量无意义的字符。
- attention_sinks:在长达10000个token的测试中保持流畅。
聊天模型在连续提示下的流畅性
在一系列提示的基础上测试聊天风格LLMs的生成能力,attention_sinks
显著提高了模型跨越多个提示的流畅性。然而,在某些模型如Llama-2-7B-chat-hf中,仍然存在一些流畅性问题。
实现细节
attention_sinks
项目基于Efficient Streaming Language Models with Attention Sinks(高效流式语言模型与注意力汇聚点)的研究。其主要优点包括:
- 无需重新训练即可扩展现有LLM(如Llama 2)以无缝生成流畅文本,特别适用于多步骤的LLM如聊天助手。
- 使用
attention_sinks
无需重置缓存,显存使用恒定,推理不会因为序列长度过长而变得缓慢。 - 在执行从20行前回忆值的任务时表现优异,即使已经处理了数十万行,而普通密集或窗口化注意力在处理几千个token后性能降为0%。
安装与使用
您可以使用以下命令安装attention_sinks
:
pip install attention_sinks
该项目支持多种模型,如Llama、Mistral、Falcon、MPT、GPTNeoX(Pythia)、GPT-J、Qwen、StableLM_epoch、BTLM、Yi等。使用中只需将模型类从transformers
切换到attention_sinks
。
以下是一个简单的使用示例:
from attention_sinks import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("mosaicml/mpt-7b", device_map="auto")
应用场景
Attention Sink模型最适合于需要连续生成或多回合对话的流式应用场景,比如日常助理,能够基于最近的对话提供合理的响应而无需频繁刷新缓存。
结尾
attention_sinks
项目由StreamingLLM启发并进行改进。项目致力于为多种模型和使用场景提供高效、流畅的文本生成解决方案。