Medusa: 用于加速LLM生成的多解码头简单框架
新闻 🔥
- [2024/1] Medusa技术报告现已在arXiv发布。我们添加了多个新功能,包括用于全模型训练的Medusa-2配方、自蒸馏以将Medusa添加到任何微调的LLM等。新结果显示在一系列LLM上相比原始模型有2.2-3.6倍的加速。
介绍
Medusa是一个简单的框架,使多解码头的加速技术在LLM生成中更具民主化。
我们旨在解决流行加速技术如推测解码的三个痛点:
- 需要一个好的草稿模型。
- 系统复杂性。
- 使用基于采样的生成时的低效性。
我们通过以下思路解决与推测解码相关的挑战:
- 不是引入一个新模型,而是在同一个模型上训练多个解码头。
- 训练是参数高效的,即使是“GPU匮乏”也能进行。而且由于没有额外的模型,不需要调整分布计算设置。
- 放宽与原始模型匹配分布的要求,使非贪婪生成比贪婪解码更快。
在初始版本中,我们主要关注的是优化单批大小为1的Medusa——这是本地模型托管常用的设置。在这种配置下,Medusa在一系列Vicuna模型上提供了约2倍的速度提升。我们正积极致力于将Medusa集成到更多的推理框架中,以期实现更大的性能提升,并扩展到更广泛的设置。
在更新版本中,我们增加了对全模型训练的支持,称为Medusa-2(相比于只训练新头部的Medusa-1),它需要一个特殊的配方来增加推测预测的能力,同时保持原始模型的性能。
我们还增加了对自蒸馏的支持,使我们可以将Medusa添加到任何微调的LLM中,而不需要原始训练数据的可用性。
目录
安装
方法1:使用pip(可能不是最新版本)
pip install medusa-llm
方法2:从源码安装(推荐)
git clone https://github.com/FasterDecoding/Medusa.git
cd Medusa
pip install -e .
模型权重
Medusa-1
大小 | 聊天命令 | Hugging Face仓库 |
---|---|---|
7B | python -m medusa.inference.cli --model FasterDecoding/medusa-vicuna-7b-v1.3 | FasterDecoding/medusa-vicuna-7b-v1.3 |
13B | python -m medusa.inference.cli --model FasterDecoding/medusa-vicuna-13b-v1.3 | FasterDecoding/medusa-vicuna-13b-v1.3 |
33B | python -m medusa.inference.cli --model FasterDecoding/medusa-vicuna-33b-v1.3 | FasterDecoding/medusa-vicuna-33b-v1.3 |
Medusa-2
大小 | 聊天命令 | Hugging Face仓库 |
---|---|---|
Zephyr-7B-Beta | python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-zephyr-7b-beta | FasterDecoding/medusa-1.0-zephyr-7b-beta |
Vicuna-7B-v1.5 | python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-vicuna-7b-v1.5 | FasterDecoding/medusa-1.0-vicuna-7b-v1.5 |
Vicuna-13B-v1.5 | python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-vicuna-13b-v1.5 | FasterDecoding/medusa-1.0-vicuna-13b-v1.5 |
Vicuna-33B-v1.5 | python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-vicuna-33b-v1.5 | FasterDecoding/medusa-1.0-vicuna-33b-v1.5 |
推理
我们目前支持单GPU推理,批大小为1,这是本地模型托管最常见的设置。我们正积极致力于将Medusa集成到其他推理框架中;如果您有兴趣参与此工作,请随时联系我们。
您可以使用以下命令启动CLI界面:
CUDA_VISIBLE_DEVICES=0 python -m medusa.inference.cli --model [Medusa模型路径]
您也可以通过传递 --load-in-8bit
或 --load-in-4bit
将基础模型加载为量化格式。如果您在其他地方下载了基础模型,您可以使用--base-model [基础模型路径]
覆盖基础模型名称或路径。
训练
在更新版本中,我们使用了惊人的axolotl库来管理训练过程。请参阅我们的fork的训练代码。主要代码修改在src/axolotl/utils/models.py
。训练配置可以在examples/medusa
中找到。一个典型的训练命令如下:
accelerate launch -m axolotl.cli.train examples/medusa/your_config.yml
自蒸馏的数据准备代码可以在当前仓库的data_generation
文件夹中找到。对于其他数据集,您可以直接从相应的Hugging Face数据集仓库下载数据。
在各种架构上训练
以下说明适用于Medusa的初始发布版本,它提供了如何训练Medusa-1模型的最小示例。有关更新版本,请参阅上一节。
要进行训练,请安装:
pip install -e ".[train]"
准备数据
我们使用公开版本的ShareGPT数据集,这是Vicuna训练数据的一个子集。对于其他模型,您可以使用相应的训练数据集。
git clone https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered
备注:如果您尚未安装git-lfs
,请在克隆之前安装它:
git lfs install
将数据调整为您希望启用Medusa的模型。
首先启动您喜欢的推理服务器,该服务器将运行您要训练的模型。我们以mistralai/Mistral-7B-Instruct-v0.2为例。
例如,您可以使用text-generation-inference,在训练完Medusa头部后也可以使用它。
model=mistralai/Mistral-7B-Instruct-v0.2
volume=$PWD/data # 与Docker容器共享一个卷以避免每次运行时下载权重
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --input-length 4000 --max-total-tokens 4096 --max-batch-prefill-tokens 4000
ShareGPT中的一些序列相对较长,所以请确保可以对这些序列进行推理。如果空间不足,脚本将简单地忽略那些长对话。这不应对后续性能产生太大影响,但更多数据总是更好。您可以使用各种权衡来加速推理,但默认设置在大多数情况下已经足够好。
python create_data.py --input-filename ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json --output-filename mistral.json
训练模型
我们遵循FastChat的训练设置,但采用更大的学习率,因为我们冻结了原始模型,只训练新的头部。下面是4个GPU上训练Vicuna-7b模型的训练命令。由于我们只训练新的头部,训练不需要很多内存,只需要数据并行。您可以调整脚本以适应您的设置。对于较大的模型,我们使用相同的设置。您还可以使用--load_in_8bit
或--load_in_4bit
以量化格式加载基础模型。
torchrun --nproc_per_node=4 medusa/train/train_legacy.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
--data_path mistral.json \
--bf16 True \
--output_dir test \
--num_train_epochs 2 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--gradient_accumulation_steps 4 \
--evaluation_strategy "no" \
--save_strategy "no" \
--learning_rate 1e-3 \
--weight_decay 0.0 \
--warmup_ratio 0.1 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--lazy_preprocess True \
--medusa_num_heads 3 \
--medusa_num_layers 1 \
--deepspeed deepspeed.json
推送到Hugging Face Hub
您可以使用以下命令将模型推送到Hugging Face Hub:
python -m medusa.hf_utils --folder [模型文件夹路径] --repo [仓库名称]
引用
@article{cai2024medusa,
title = {Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads},
author = {Tianle Cai and Yuhong Li and Zhengyang Geng and Hongwu Peng and Jason D. Lee and Deming Chen and Tri Dao},
year = {2024},
journal = {arXiv preprint arXiv: 2401.10774}
}
代码库指南
medusa/model/medusa_model.py
是Medusa的关键文件。它包含MedusaModel
类,这是原始模型和新头部的包装器。该类还实现了一个流式生成方法。如果您想深入了解Medusa的细节,这里是一个好的起点。
我们还在notebooks/
中提供了一些说明性笔记本,以帮助您理解代码库。
社区采用
我们很高兴看到许多开源项目已经采用了Medusa。以下是一个(不完整的)列表:
我们感谢作者们对社区的贡献,并真诚希望Medusa能帮助加速LLM的发展。如果您在项目中使用了Medusa,请告诉我们,我们将把您的项目添加到列表中。
贡献
我们欢迎社区对Medusa的贡献。如果您有改进的想法,请开启一个问题与我们讨论。在提交拉取请求时,请确保您的改动已经过良好测试。请将每一个重大变更分为一个单独的拉取请求。我们还拥有一个路线图,总结了我们对Medusa的未来计划。如果您对任何路线图上的条目感兴趣,请随时联系我们。
致谢
这个代码库受到了LLM社区一些出色项目的影响,包括FastChat、TinyChat、vllm、axolotl。
这个项目得到了Together AI、MyShell AI、Chai AI的支持。