Project Icon

deita

自动数据选择工具助力大语言模型指令调优

Deita是一个开源项目,为大型语言模型的指令调优提供自动数据选择工具。项目包含开源工具包、高质量轻量级数据集和高效训练模型。Deita模型使用仅十分之一的指令调优数据,就能达到其他先进聊天模型的性能水平。项目提供全面评估结果,展示了在多项基准测试中的表现。

Deita

🤗 HF仓库    📄 论文    📚 6K数据集    📚 10K数据集

欢迎来到Deita(Data-Efficient Instruction Tuning for Alignment,数据高效指令微调对齐)项目!

我们将持续更新,敬请关注!

Deita是什么?

Deita是一个开源项目,旨在为大语言模型(LLMs)的指令微调提供自动数据选择

它包括:

  • 用于指令微调自动数据选择的开源工具包
  • Deita数据集:一系列极其轻量级、高质量的对齐SFT数据。我们在首次发布中提供了6k规模和10k规模的数据集
  • Deita模型:一系列通过极其高效的指令微调过程达到与最先进聊天大语言模型相当水平的强大模型。与其他最先进的大语言模型相比,Deita模型只需使用十分之一的指令微调数据即可训练得到

新闻

性能

:bell: 还在好奇少量高质量数据能让大语言模型达到多远吗?

Deita可能为您提供一个答案:

🔦 亮点

模型对齐方式数据规模MT-BenchAlpacaEval(%)
Zephyr-7B-sftSFT200K5.3275.12
$\text{Zephyr-7B-}\beta$SFT + DPO200K SFT + 60K DPO7.3490.60
OpenChat-3.5C-RLFT>> 70K C-RLFT7.8188.51
Starling-7BC-RLFT + APA>> 70K C-RLFT + 183K APA8.0991.99
Tulu-2-13BSFT326K6.7078.90
Tulu-2-13B+DPOSFT + DPO326K SFT + 60K DPO7.0089.50
LLaMA2-13B-ChatSFT + PPO--6.6581.09
WizardLM-13B-v1.2SFT>70K7.0989.17
Vicuna-13B-v1.5SFT>125K6.5778.80
DEITA-7B-v1.0 (6K)SFT6K7.2280.78
DEITA-7B-v1.0-sftSFT10K7.3281.67
DEITA-7B-v1.0SFT + DPO6K SFT + 10K DPO7.5590.06

DEITA模型基于Mistral-7B-v0.1。:fire:

完整评估请参见此表,其中包括Open LLM排行榜以及基于LLaMA基础模型的DEITA模型和与其他数据选择方法的比较。

:chart_with_upwards_trend: 完整评估

查看完整评估 | 模型 | 对齐方式 | 数据规模 | MT-Bench评分 | AlpacaEval(%) | OpenLLM (平均) | |------------------------------------------------|-----------|------------|----------|---------------|----------------| | **专有模型** | | | | | | | GPT-4-Turbo | ? | -- | 9.32 | 97.70 | -- | | GPT-4 | SFT + PPO | -- | 8.99 | 95.03 | -- | | Claude-2 | SFT + PPO | -- | 8.06 | 91.36 | -- | | GPT-3.5-turbo | SFT + PPO | -- | 7.94 | 89.37 | -- | | **基于LLaMA-1-13B的开源模型** | | | | | | | LIMA | SFT | 1K SFT | 4.29 | 41.98 | 59.82 | | WizardLM-13B | SFT | 70K SFT | 6.35 | 75.31 | 58.96 | | Vicuna-13B-v1.3 | SFT | 125K SFT | 6.39 | 82.11 | 60.01 | | 随机 | SFT | 10K SFT | 6.03 | 71.52 | 60.14 | | DEITA-LLaMA1-13B-v1.0-sft | SFT | 10K SFT | 6.60 | 78.01 | 64.27 | | **基于LLaMA-2-13B的开源模型** | | | | | | | Tulu-2-13B | SFT | 326K SFT | 6.70 | 78.90 | -- | | Tulu-2-13B+DPO | SFT + DPO | 326K SFT + 60K DPO | 7.00 | 89.50 | -- | | LLaMA2-13B-Chat | SFT + PPO | -- | 6.65 | 81.09 | -- | | WizardLM-13B-v1.2 | SFT | >70K SFT | 7.09 | 89.17 | -- | | Vicuna-13B-v1.5 | SFT | 125K SFT | 6.57 | 78.80 | 61.63 | | 随机 | SFT | 10K SFT | 5.78 | 65.19 | 61.32 | | DEITA-LLaMA2-13B-v1.0-sft | SFT | 10K SFT | 6.79 | 81.09 | 62.71 | | **基于Mistral-7B的开源模型** | | | | | | | Mistral-7B-Instruct-v0.1 | -- | -- | 6.84 | 69.65 | 60.45 | | Zephyr-7B-sft | SFT | 200K SFT | 5.32 | 75.12 | 60.93 | | $\text{Zephyr-7B-}\beta$ | SFT + DPO | 200K SFT + 60K DPO | 7.34 | 90.60 | 66.36 | | OpenChat-3.5 | C-RLFT | >> 70K C-RLFT | 7.81 | 88.51 | -- | | Starling-7B | C-RLFT + APA | >>70K C-RLFT + 183K APA | 8.09 | 91.99 | -- | | 随机 | SFT | 10K SFT | 5.89 | 56.90 | 61.72 | | DEITA-7B-v1.0-sft (6K) | SFT | 6K SFT | 7.22 | 80.78 | 64.94 | | DEITA-7B-v1.0-sft (10K) | SFT | 10K SFT | 7.32 | 81.67 | 64.00 | | DEITA-7B-v1.0 | SFT + DPO | 6K SFT + 10K DPO | 7.55 | 90.06 | 69.86 |

:rocket: Deita资源

资源链接许可证
Deita数据集
deita-6k-v0:hugs: HF仓库MIT许可证
deita-10k-v0:hugs: HF仓库MIT许可证
deita-complexity-scorer-data:hugs: HF仓库MIT许可证
deita-quality-scorer-data:hugs: HF仓库MIT许可证
deita-redundant-pool (100K):hugs: HF仓库MIT许可证
deita-sota-pool (300K):hugs: HF仓库MIT许可证
评分器
deita-complexity-scorer:hugs: HF仓库LLaMA许可证
deita-quality-scorer:hugs: HF仓库LLaMA许可证
Deita模型
DEITA-7B-v1.0-sft:hugs: HF仓库Apache-2.0
DEITA-7B-v1.0:hugs: HF仓库Apache-2.0
DEITA-LLaMA2-13B-v1.0-sft:hugs: HF仓库LLaMA 2许可证
DEITA-LLaMA1-13B-v1.0-sft:hugs: HF仓库LLaMA许可证

:running_man: 如何开始?

安装

  git clone https://github.com/hkust-nlp/deita.git
  cd deita
  pip install -e .

数据样本评分

如果你想评估单个样本回复的质量,可以按以下步骤操作:

from deita.selection.scorer import Llama_Scorer

model_name_or_path = "hkust-nlp/deita-quality-scorer"

scorer = Llama_Scorer(model_name_or_path)

# 示例输入
input_text = "描述带有有用提示的UI的词" # 示例输入
output_text = "用户友好或直观的UI" # 示例输出
quality_score = scorer.infer_quality(input_text, output_text)
print(quality_score)
# 2.0230105920381902

Deita 还支持使用 VLLM 进行更快的推理。如果你想使用 VLLM 进行推理,

pip install vllm

并在初始化评分器时设置 is_vllm = True

scorer = Llama_Scorer(model_name_or_path, is_vllm = True)

要评估数据样本的其他维度,请参考 examples/scoring

Deita 管道

你可以使用 deita 管道通过一行代码和配置对数据集执行各种操作。

  • 数据集评分
from deita.pipeline import Pipeline

pipeline = Pipeline("score_pipeline", 
                    data_path = args.data_path,   # sharegpt 格式的 json 文件
                    scorer = args.scorer,   # [mistral, llama]
                    scorer_name_or_path = args.scorer_name_or_path,  # 评分器名称或路径,例如 hkust-nlp/deita-complexity-scorer
                    is_vllm = args.is_vllm,  # 使用 vllm 启动 [True, False]
                    score_type = args.score_type, # [complexity, quality]
                    output_path = args.output_path)  # 输出路径(json 格式)

pipeline.run()
  • 获取嵌入

我们使用 Huggingface Accelerate 来提高效率:

from deita.pipeline import Pipeline

embed_pipeline = Pipeline("embed_pipeline", 
                          data_path = args.data_path,   # sharegpt 格式的 json 文件
                          output_path = args.output_path,  # 输出路径(pickle 格式)
                          model_name_or_path = args.model_name_or_path,  # 模型名称或路径,例如 mistralai/Mistral-7B-v0.1
                          max_length = args.max_length,
                          use_flash_attention = args.use_flash_attention,  
                          batch_size_per_device = args.batch_size_per_device,
                          conv_template = args.conv_template,
                          only_answer = args.only_answer,
                          random_shuffle = args.random_shuffle,
                          bfloat16 = True
                          )

embed_pipeline.run()
CUDA_VISIBLE_DEVICES=$GPUIDX accelerate launch \
    --mixed_precision bf16 \
    --num_processes $NUMPROCESS \
    --num_machines 1 \
    examples/pipelines/embed_datasets.py \
    --use_flash_attention true \
    --data_path $DATAPATH \
    --output_path $OUTPUTPATH \
    --batch_size_per_device $BSZ
  • 先评分,再多样性感知选择
from deita.pipeline import Pipeline

filter_pipeline = Pipeline("filter_pipeline", 
                          data_path = args.data_path,  # sharegpt 格式的 json 文件
                          other_data_path = args.other_data_path,  # 嵌入文件路径(pickle 格式)
                          threshold = args.threshold,  # 过滤阈值 默认: 0.9 
                          data_size = args.data_size,  # 选择的数据大小
                          chunk_size = args.chunk_size,  # 用于更高效的 GPU 计算 默认: 100000
                          sort_key = args.sort_key,  # 默认: "complexity_scores,quality_scores"
                          output_path = args.output_path,  # json 格式输出路径
                          distance_metric = args.distance_metric,  # 默认: cosine
                          embedding_field = args.embedding_field,  # 默认: embedding
                          is_compression = args.is_compression,  # 默认: False
                          device = args.device  # GPU 序号, 默认: 0
                          )

filter_pipeline.run()

你可以参考 examples/pipelines 获取更多详细信息。文档也将很快推出。

SFT 训练

请参考 examples/train/sft.sh

deepspeed --include localhost:${DEVICES} --master_port 29501 src/deita/alignment/train.py \
    --model_name_or_path ${MODELPATH} \
    --data_path ${DATAPATH} \
    --output_dir ${OUTPUTPATH}/${RUNNAME} \
    --num_train_epochs 6 \
    --per_device_train_batch_size ${BSZPERDEV} \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps ${GRADACC} \
    --eval_steps 50 \
    --save_strategy "no" \
    --save_steps 100 \
    --save_total_limit 10 \
    --learning_rate 2e-5 \
    --warmup_ratio 0.1 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --do_eval False \
    --evaluation_strategy "no" \
    --model_max_length 2048 \
    --lazy_preprocess True \
    --conv_template "vicuna_v1.1" \
    --mask_user True \
    --report_to "wandb" \
    --run_name ${RUNNAME} \
    --bf16 True \
    --deepspeed src/deita/ds_configs/deepspeed_config_zero2_no_offload.json

DPO 训练

请参考 examples/train/dpo.sh

deepspeed --include localhost:${DEVICES} --master_port 29502 src/deita/alignment/dpo_train.py \
    --model_name_or_path ${MODELPATH} \
    --json_path ${JSONPATH} \
    --data_split ${DATASPLIT} \
    --output_dir ${OUTPUTPATH}/${RUNNAME} \
    --num_train_epochs ${DPOEPOCH} \
    --beta 0.1 \
    --per_device_train_batch_size ${BSZPERDEV} \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps ${GRADACC} \
    --save_global_steps False \
    --eval_steps 50 \
    --save_strategy "no" \
    --save_steps 500 \
    --save_total_limit 1 \
    --learning_rate 5e-7 \
    --warmup_ratio 0.1 \
    --lr_scheduler_type "linear" \
    --logging_steps 1 \
    --do_eval False \
    --evaluation_strategy "no" \
    --model_max_length 2048 \
    --conv_template "vicuna_v1.1" \
    --report_to "wandb" \
    --run_name ${RUNNAME} \
    --bf16 True \
    --gradient_checkpointing True \
    --deepspeed src/deita/ds_configs/stage3_no_offloading_accelerate.json

评估

:muscle: 更多内容?

这是 Deita 项目的预览版本。我们将继续更新,包括

  • 发布带有高效实现的数据选择管道
  • 更多自动数据选择策略
  • 支持命令行界面
  • 在线演示

引用

如果你发现这个项目的内容对你有帮助,请按以下方式引用我们的论文:

@inproceedings{
liu2024what,
title={What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning},
author={Wei Liu and Weihao Zeng and Keqing He and Yong Jiang and Junxian He},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=BTKAeLqLMw}
}

致谢

对于训练代码,我们使用了 fastchat 的代码模板。

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

稿定AI

稿定设计 是一个多功能的在线设计和创意平台,提供广泛的设计工具和资源,以满足不同用户的需求。从专业的图形设计师到普通用户,无论是进行图片处理、智能抠图、H5页面制作还是视频剪辑,稿定设计都能提供简单、高效的解决方案。该平台以其用户友好的界面和强大的功能集合,帮助用户轻松实现创意设计。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号