注意力体操馆
注意力体操馆是一组用于处理flex-attention的有用工具和示例集合
🎯 特性 | 🚀 入门 | 💻 使用 | 🛠️ 开发 | 🤝 贡献 | ⚖️ 许可证
📖 概述
本仓库旨在提供一个使用FlexAttention API实验各种注意力机制的练习场。它包括不同注意力变体的实现、性能比较以及帮助研究人员和开发人员探索和优化模型中注意力机制的实用功能。
🎯 特性
- 使用FlexAttention实现各种注意力机制
- 用于创建和组合注意力掩码的实用函数
- 在实际场景中使用FlexAttention的示例
🚀 入门
先决条件
- PyTorch(2.5版或更高)
安装
git clone https://github.com/pytorch-labs/attention-gym.git
cd attention-gym
pip install .
💻 使用
使用注意力体操馆有两种主要方式:
-
运行示例脚本:项目中的许多文件可以直接执行以演示其功能:
python attn_gym/masks/document_mask.py
这些脚本通常会生成可视化效果,帮助你理解注意力机制。
-
在你的项目中导入:你可以通过导入注意力体操馆的组件在自己的工作中使用它们:
from torch.nn.attention.flex_attention import flex_attention, create_block_mask from attn_gym.masks import generate_sliding_window # 在代码中使用导入的函数 sliding_window_mask = generate_sliding_window(window_size=1024) block_mask = create_block_mask(mask_mod, 1, 1, S, S, device=device) out = flex_attention(query, key, value, block_mask=block_mask)
要查看在实际场景中使用FlexAttention的全面示例,请浏览examples/
目录。这些端到端实现展示了如何将各种注意力机制集成到你的模型中。
注意
注意力体操馆正在积极开发中,目前我们不提供任何向后兼容性保证。API和功能可能会在不同版本之间发生变化。我们建议在你的项目中固定使用特定版本,并在升级时仔细审查变更。
📁 结构
注意力体操馆的组织方式便于探索注意力机制:
🔍 关键位置
attn_gym.masks
:创建BlockMasks
的示例attn_gym.mods
:创建score_mods
的示例examples/
:使用FlexAttention的详细实现
🛠️ 开发
安装开发依赖
pip install -e ".[dev]"
安装pre-commit钩子
pre-commit install
🤝 贡献
我们欢迎对注意力体操馆的贡献,尤其是新的掩码或分数修改器!以下是如何贡献的方法:
贡献修改器
- 在attn_gym/masks/目录中为mask_mods创建新文件,或在attn_gym/mods/目录中为score_mods创建新文件。
- 实现你的函数,并添加一个简单的主函数来展示你的新函数。
- 更新
attn_gym/*/__init__.py
文件以包含你的新函数。 - [可选] 在examples/目录中添加一个使用你新函数的端到端示例。
更多详情请参阅CONTRIBUTING.md。
⚖️ 许可证
attention-gym根据BSD 3-Clause许可证发布。