Medusa 项目介绍
Medusa 是一个简单的框架,旨在通过多解码头来加速大语言模型(LLM)的生成过程。该项目的目标是使加速技术更加普及,并在不引入新模型的情况下,通过同一模型训练多个解码头来实现这一目标。通过这一创新,Medusa 解决了普遍加速技术中的三大痛点:对良好草案模型的需求、系统复杂性以及使用基于采样的生成时的低效问题。
多解码头的实现
Medusa 的核心理念是为现有的 LLM 增加多个“解码头”,以同时预测多个未来词汇。在这一过程中,Medusa 不会对原始模型进行更改,而是仅在训练期间微调新的解码头。这些解码头在生成时,各自为相应位置生成多个可能的词汇选项。然后,这些选项通过一种基于树的注意力机制进行结合和处理,最终采用常规的接受方案选出最长的合理前缀,以用于后续解码。
优化与支持
在最初版本中,Medusa 重点优化了批量大小为1的环境,这是本地模型托管中最常用的设置。在此配置下,Medusa 在各种 Vicuna 模型上实现了大约2倍的速度提升。项目团队正在积极拓展 Medusa 的能力,计划将 Medusa 集成到更多的推理框架中,以实现更高的性能提升并将其推广到更广泛的应用场景。
新功能与版本
更新版本(即 Medusa-2)增加了对全模型训练的支持,提供了一种特别的训练方案,该方案具备在保持原始模型性能的同时,增加预测能力。此外,Medusa 也增加了自蒸馏支持,使其能够在无需原始训练数据的情况下,加入到任何经过微调的 LLM 中。
安装与使用
Medusa 的安装推荐从源代码进行,以保持最新版本。用户可以通过以下方法安装 Medusa:
- 方法 1:使用 pip 安装(可能不是最新版本)
pip install medusa-llm
- 方法 2:从源代码安装(推荐)
git clone https://github.com/FasterDecoding/Medusa.git cd Medusa pip install -e .
此外,Medusa 还提供了一些模型权重以及在单 GPU 上进行批量大小为1的推理支持。用户可根据需求通过命令行接口来进行推理,支持8bit或4bit的量化加载。
示例与贡献
Medusa 也被许多开源项目采用,如 NVIDIA 的 TensorRT-LLM、Hugging Face 的 TGI,以及阿里巴巴的 RTP-LLM。项目团队期待与更多的社区成员分享经验,欢迎有兴趣的开发者通过 GitHub 进行讨论和贡献。项目的路线图中列出了未来的计划,欢迎大家参与。
Medusa 受到了众多项目和机构支持,如 Together AI、MyShell AI 和 Chai AI,帮助加速了 LLM 的发展。Medusa 希望能为更多的研究者和开发者提供帮助,加速大语言模型的开发。