LoftQ: 支持LoRA微调的量化方法
LoftQ帮助您使用有限的GPU资源微调大型语言模型。🚀 LoftQ可以找到足够好的量化LoRA初始化:给定预训练权重W,得到量化的主干网络Q和LoRA适配器A和B。
本仓库实现了论文🔗:LoftQ: 支持大型语言模型LoRA微调的量化方法。
我们的模型可在🤗 LoftQ Huggingface Hub上获取
新闻
-
[2024/04/20] 在GSM8K上的新
LLAMA-3-8B
结果。查看结果在此。查看🦙 LLAMA-3、CodeLLAMA-7b、CodeLLAMA=13b的LoftQ在Huggingface Hub上。 -
[2024/04/13] 在GSM8K上的新
phi-2
结果。查看结果在此。查看Phi-2的LoftQ在Huggingface Hub上。 -
[2024/04/13] 更新
script/train_gsm8k.sh
以支持量化模型的数据并行。
快速开始
要求
我们使用bitsandbytes来实现量化。 该包仅支持CUDA >= 11.0,不支持CPU。 但是,如果GPU资源充足,我们也提供假量化以实现快速并行训练。
pip install -r requirements.txt
步骤
- 将LoftQ应用于全精度预训练权重并保存。
- 加载LoftQ初始化并训练。
对于步骤1,我们已在Huggingface Hub LoftQ中提供了现成的LoftQ初始化(参见支持的模型列表)。 如果您想自己操作,请跳转至LoftQ DIY。
对于步骤2,以下是从Huggingface Hub加载4位Mistral-7B和64秩LoRA适配器的示例。
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
# 在https://huggingface.co/LoftQ获取MODEL_ID
MODEL_ID = "LoftQ/Mistral-7B-v0.1-4bit-64rank"
base_model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16, # 对于不同的模型可能需要更改
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16, # 推荐使用bfloat16
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type='nf4',
),
)
peft_model = PeftModel.from_pretrained(
base_model,
MODEL_ID,
subfolder="loftq_init",
is_trainable=True,
)
# 使用peft_model进行训练 ...
LoftQ DIY
应用LoftQ并保存
我们提供了quantize_save.py作为示例,用于应用具有不同位数(--bits
)、秩(--rank
)和交替步骤(--iter
,LoftQ中的一个超参数,参见LoftQ论文中的算法1)的LoftQ。目前,此示例支持
llama-2
、falcon
、mistral
、bart
、t5
、deberta
、bert
、roberta
。
以下是通过5个交替步骤获取4位LLAMA-2-7b和16秩LoRA适配器的示例。
SAVE_DIR="model_zoo/loftq/"
python quantize_save_load.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \ # HF中的高精度模型ID
--token HF_TOKEN \ # 如果模型是私有的(如llama-2),则需要您的HF令牌
--bits 4 \
--iter 5 \
--rank 16 \
--save_dir $SAVE_DIR
上述命令最终会在$SAVE_DIR
下创建模型目录。
具体来说,模型目录的命名方式为
MODEL_DIR = SAVE_DIR + f"{args.model_name_or_path.split('/')[-1]}-{args.bits}bits-{args.rank}rank"
在这个例子中,MODEL_DIR="model_zoo/loftq/Llama-2-7b-hf-4bit-16rank"
,其中主干网络存储在$MODEL_DIR
中,
LoRA适配器位于子文件夹$MODEL_DIR/loftq_init
中。
加载和训练
与从Huggingface Hub加载类似,我们只需将MODEL_ID
更改为MODEL_DIR
。
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
MODEL_DIR = "model_zoo/loftq/Llama-2-7b-hf-4bit-16rank"
base_model = AutoModelForCausalLM.from_pretrained(
MODEL_DIR,
torch_dtype=torch.bfloat16,
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type='nf4',
),
)
peft_model = PeftModel.from_pretrained(
base_model,
MODEL_DIR,
subfolder="loftq_init",
is_trainable=True,
)
# 使用peft_model进行训练 ...
LoftQ微调
我们还提供了一个在GSM8K上使用LoftQ微调LLAMA-7b的示例。
python train_gsm8k.py \
--model_name_or_path LoftQ/Llama-2-7b-hf-4bit-64rank \
--learning_rate 3e-4 \
--seed 11 \
--expt_name gsm8k_llama2_7b_4bit_64rank_loftq \
--output_dir exp_results/ \
--num_train_epochs 6 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "epoch" \
--weight_decay 0.1 \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 10 \
--do_train \
--report_to tensorboard
其他训练文件
- GLUE:
glue/run_glue.py
- 问答:
glue/run_qa.py
- 摘要:
train_summarization.py
- WikiText-2:
train_clm.py
- GSM8K:
train_gsm8k.py
更多示例脚本在scripts中。
快速评估
以下是使用我们已微调的适配器测试GSM8K的命令。它存储在LoftQ Huggingface hub中目标模型的subfolder='gsm8k'
中。
python test_gsm8k.py \
--model_name_or_path LoftQ/Llama-2-7b-hf-4bit-64rank \
--batch_size 16
python test_gsm8k.py \
--model_name_or_path LoftQ/phi-2-4bit-64rank \
--batch_size 16
您可以根据机器情况自由调整batch_size
。
主要结果
LLAMA-2在WikiText-2和GSM8K上的结果
位数 | WikiText-2 | WikiText-2 | GSM8K | GSM8K |
---|---|---|---|---|
LLAMA-2-7b | LLAMA-2-13b | LLAMA-2-7b | LLAMA-2-13b | |
16 | 5.08 | 5.12 | 36.9 | 43.1 |
4 | 5.24 | 5.16 | 35.0 | 45.0 |
3 | 5.63 | 5.13 | 32.9 | 44.4 |
2.5 | 5.78 | 5.22 | 31.1 | 41.1 |
2.25 | 6.13 | 5.45 | 26.5 | 38.1 |
2 | 7.85 | 7.69 | 20.9 | 25.4 |
模型通过在训练集上进行因果语言建模微调,并在验证/测试集上进行测试。
Phi-2在GSM8K上的结果
模型 | 位数 | 秩 | LoRA初始化 | GSM8K |
---|---|---|---|---|
Phi-2 | 16 | - | 全模型微调 | 66.8±1.2 |
Phi-2 | 16 | 64 | 高斯分布 + 0 | 64.8±0.5 |
Phi-2 | 4 | 64 | 高斯分布 + 0 (QLoRA) | 60.2±0.6 |
Phi-2 | 4 | 64 | LoftQ | 64.1±0.7 |
LLAMA-3在GSM8K上的结果
模型 | 位数 | 秩 | LoRA初始化 | GSM8K |
---|---|---|---|---|
LLAMA-3-8B | 16 | - | 全模型微调 | 70.4±0.7 |
LLAMA-3-8B | 16 | 64 | 高斯分布 + 0 (LoRA) | 69.3±1.5 |
LLAMA-3-8B | 4 | 64 | 高斯分布 + 0 (QLoRA) | 67.4±1.0 |
LLAMA-3-8B | 4 | 64 | LoftQ | 68.0±0.6 |
模型通过在(重新格式化的)训练集上进行因果语言建模微调,并在验证/测试集上进行测试。
BART-large在CNN/DailyMail和XSum上的结果
位数 | 秩 | XSum | CNN/DailyMail |
---|---|---|---|
Lead-3* | 16.30/1.60/11.95 | 40.42/17.62/36.67 | |
16 | 16 | 43.95/20.72/35.68 | 45.03/21.84/42.15 |
4 | 16 | 44.51/21.14/36.18 | 43.96/21.06/40.96 |
2 | 16 | 40.81/17.85/32.80 | 42.52/19.81/39.51 |
16 | 8 | 43.40/20.20/35.20 | 44.72/21.58/41.84 |
4 | 8 | 44.08/20.72/35.89 | 43.81/20.95/40.84 |
2 | 8 | 39.63/16.65/31.62 | 42.24/19.44/29.04 |
*: 使用文档中的前3个句子作为摘要
DeBERTa-V3-base在GLUE上使用普通浮点数据类型的结果
位数 | 秩 | MNLI | QNLI | RTE | SST | MRPC | CoLA | QQP | STSB | SQuAD | ANLI |
---|---|---|---|---|---|---|---|---|---|---|---|
m / mm | 准确率 | 准确率 | 准确率 | 准确率 | 准确率 | Mcc | P/S 相关 | EM/F1 | 准确率 | ||
16 | 16 | 90.5/90.6 | 94.0 | 82.0 | 95.3 | 89.5/93.3 | 69.2 | 92.4/89.8 | 91.6/91.1 | 88.5/92.8 | 59.8 |
2 | 16 | 84.7/85.1 | 86.6 | 61.4 | 90.2 | 83.8/88.6 | 37.4 | 90.3/86.9 | 87.1/86.9 | 81.5/88.6 | 47.1 |
2 | 32 | 86.0/86.1 | 89.9 | 61.7 | 92.0 | 83.6/87.2 | 47.5 | 91.0/87.9 | 87.5/87.0 | 82.9/89.8 | 49.0 |
DeBERTa-V3-base在GLUE上使用均匀量化数据类型的结果
位数 | 秩 | MNLI | QNLI | RTE | SST | MRPC | CoLA | QQP | STSB | SQuAD |
---|---|---|---|---|---|---|---|---|---|---|
m / mm | 准确率 | 准确率 | 准确率 | 准确率 | 准确率 | Mcc | P/S 相关 | Em/F1 | ||
16 | 16 | 90.5/90.6 | 94.0 | 82.0 | 95.3 | 89.5/93.3 | 69.2 | 92.4/89.8 | 91.6/91.1 | 88.5/92.8 |
2 | 16 | 87.3/87.1 | 90.6 | 61.1 | 94.0 | 87.0/90.6 | 59.1 | 90.9/88.0 | 87.9/87.6 | 84.4/91.2 |
2 | 32 | 88.0/88.1 | 92.2 | 63.2 | 94.7 | 87.5/91.2 | 60.5 | 91.3/88.3 | 89.5/89.2 | 85.2/91.6 |
引用
@article{li2023loftq,
title={Loftq: Lora-fine-tuning-aware quantization for large language models},
author={Li, Yixiao and Yu, Yifan and Liang, Chen and He, Pengcheng and Karampatziakis, Nikos and Chen, Weizhu and Zhao, Tuo},
journal={arXiv preprint arXiv:2310.08659},
year={2023}
}
附录:现成模型列表
模型名称 | 位数 | 秩 |
---|---|---|
LLAMA-3-8B | 4 | 64 |
CodeLLAMA-7b | 4 | 64 |
CodeLLAMA-13b | 4 | 64 |
Phi-2 | 4 | 64 |
LLAMA-2-7b | 4 | 64 |
LLAMA-2-13b | 4 | 64 |
LLAMA-2-70b | 4 | 64 |
Mistral | 4 | 64 |
Mistral | 4 | 32 |
BART-large | 4 | 8 |
BART-large | 4 | 16 |
BART-large | 4 | 32 |
BART-large | 2 | 8 |