mamba-minimal
在一个 PyTorch 文件中简单、最小化地实现 Mamba。
特点:
- 前向和反向传播的数值输出与官方实现等效
- 简化、可读、带注释的代码
不包括:
- 速度。官方实现经过大量优化,这些优化是 Mamba 论文的核心贡献。为了可读性,我保持了大部分实现的简单性。
- 适当的参数初始化(尽管可以在不牺牲可读性的情况下添加)
演示
查看 demo.ipynb 以获取提示补全的示例。
from model import Mamba
from transformers import AutoTokenizer
model = Mamba.from_pretrained('state-spaces/mamba-370m')
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
generate(model, tokenizer, 'Mamba is the')
Mamba 是世界上最长的毒蛇,估计长度超过 150 米。由于体型巨大和带有毒性的咬伤,Mamba 通过刺伤受害者来杀死猎物(这比单次咬伤更痛苦且效果较差)
150 米... 🫢 太可怕了!
参考文献
Mamba 架构由 Albert Gu 和 Tri Dao 在论文 Mamba: Linear-Time Sequence Modeling with Selective State Spaces 中提出。