TriForce: 使用分层推测解码实现长序列生成的无损加速
无需训练,加速长序列生成
环境设置
conda create -n TriForce python=3.9
conda activate TriForce
pip install -r requirements.txt
pip install flash-attn --no-build-isolation # 安装flash-attn
评估
目前仅支持长上下文Llama模型(包括Llama2-7B-128K、Llama2-13B-128K、LWM-Text-128K、LWM-Text-Chat-128K)。
片上
通过运行以下命令,可以在A100上复现片上结果。--prefill
指定提示的上下文长度,--budget
指定检索缓存的预算。chunk_size
指定KV缓存的块大小。top_p
和temp
是采样超参数,默认设置为0.9和0.6。gamma
是推测解码步数。在单个A100上运行以下命令,您应该能观察到2.2倍的加速。gs
包含来自PG-19的20个样本,128k
包含128K个样本,lwm
包含来自NarrativeQA的样本。
# TriForce,在A100上
CUDA_VISIBLE_DEVICES=0 python test/on_chip.py --prefill 124928 --budget 4096 \
--chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
卸载
使用张量并行的卸载
我们的框架支持卸载设置的张量并行。--nproc_per_node
应设置为用于卸载的GPU数量。以下命令演示了如何使用2个GPU进行张量并行。需要注意的是,RTX 4090不支持张量并行的CUDA Graph(而A100支持)。因此,我们在此设置中禁用了CUDA Graph。--on_chip
指定片上KV缓存的层数,可以根据硬件进行调整。卸载的性能在很大程度上取决于PCIE的带宽。为了获得准确的结果,最好确保带宽不被其他程序使用。
# TriForce,在2张RTX 4090 GPU上
CUDA_VISIBLE_DEVICES=0,1 OMP_NUM_THREADS=48 torchrun --nproc_per_node=2 \
test/offloading_TP.py --budget 12288 --prefill 130048 --dataset gs \
--target llama-7B-128K --on_chip 9 --gamma 16
不使用张量并行的卸载
我们建议使用2张RTX 4090进行卸载,因为编码时间更短,生成延迟更低。但如果您只有1张RTX 4090,仍可以运行以下命令。由于预算较小,平均接受的令牌长度会更短。
# TriForce,CUDA Graph
# Huggingface后端,cuda graph可能会占用一些额外的HBM
CUDA_VISIBLE_DEVICES=0 python test/offloading.py --prefill 130048 \
--chunk_size 8 --temp 0.6 --top_p 0.9 --gamma 12 --dataset gs \
--budget 8192 --target llama-7B-128K
# TriForce,计算和加载重叠
# 重叠可能会占用一些额外的HBM
CUDA_VISIBLE_DEVICES=0,1 OMP_NUM_THREADS=48 torchrun --nproc_per_node=1 \
test/offloading_TP.py --budget 8192 --prefill 130048 --dataset gs \
--target llama-7B-128K --on_chip 0 --gamma 12
基准
对于卸载,我们提供了自回归基准的实现以供比较。如果TriForce的性能不符合预期,这可能是由于PCIE带宽低,我们建议在相同的硬件上评估基准的性能。为了演示如何在不同的硬件配置上执行基准,这里是在两张RTX 4090 GPU上和单独一张RTX 4090 GPU上运行基准的命令。
# 基准,2张RTX 4090
CUDA_VISIBLE_DEVICES=0,1 OMP_NUM_THREADS=48 torchrun --nproc_per_node=2 \
test/offloading_TP.py --budget 0 --prefill 130048 --dataset demo \
--target lwm-128K --on_chip 12 --baseline
# 基准,1张RTX 4090
CUDA_VISIBLE_DEVICES=0,1 OMP_NUM_THREADS=48 torchrun --nproc_per_node=1 \
test/offloading_TP.py --budget 0 --prefill 130048 --dataset demo \
--target lwm-128K --on_chip 2 --baseline
引用
如果您发现TriForce对您的项目和研究有用或相关,请引用我们的论文:
@article{sun2024triforce,
title={Triforce: Lossless acceleration of long sequence generation with hierarchical speculative decoding},
author={Sun, Hanshi and Chen, Zhuoming and Yang, Xinyu and Tian, Yuandong and Chen, Beidi},
journal={arXiv preprint arXiv:2404.11912},
year={2024}
}
常见问题
-
环境问题
确保您使用的是
transformers==4.37.2
,因为在更新版本的transformers中apply_rotary_pos_emb
API发生了变化。此外,一些环境问题(例如与最新flash-attn的不兼容)可以通过设置torch==2.2.1+cu121
和flash_attn==2.5.7
来解决。更多详情,请参考issue #7。