Gemma 2B - 1000万上下文
Gemma 2B采用循环局部注意力机制,上下文长度可达1000万。我们的实现仅使用不到32GB的内存!
特性:
- Gemma 2B的序列长度达1000万。
- 运行内存少于32GB。
- 为cuda优化的本地推理。
- 循环局部注意力实现O(N)内存复杂度。
快速开始
**注意:**这是模型的一个非常早期的检查点。仅训练了200步。我们计划训练更多的token!
安装依赖:
pip install -r requirements.txt
从huggingface安装模型 - Huggingface模型。
python main.py
根据您的具体需求修改main.py
中的推理代码。
model_path = "./models/gemma-2b-10m"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = GemmaForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16
)
prompt_text = "总结这本哈利波特的书..."
with torch.no_grad():
generated_text = generate(
model, tokenizer, prompt_text, max_length=512, temperature=0.8
)
print(generated_text)
这是如何实现的?
对于大语言模型来说,最大的内存瓶颈是KV缓存。在普通的多头注意力机制中,它呈二次增长,因此限制了序列长度的大小。
我们的方法按照InfiniAttention中概述的方式将注意力分割为局部注意力块。我们对这些局部注意力块应用循环,最终实现了1000万上下文的全局注意力。
我们的许多想法灵感来自Transformer-XL论文。
更多细节
要了解更多关于我们的动机、实现细节和背后的理论,请查看我们在medium上的技术概述。
致谢
本项目由以下人员构建: