TinyLlama项目旨在预训练一个1.1B参数的Llama模型,训练数据量为3万亿个token。通过一些适当的优化,我们可以在"仅仅"90天内使用16个A100-40G GPU完成这一目标🚀🚀。训练已于2023年9月1日开始。
我们采用了与Llama 2完全相同的架构和分词器。这意味着TinyLlama可以即插即用于许多基于Llama的开源项目。此外,TinyLlama仅有1.1B参数,非常紧凑。这种紧凑性使其能够满足许多对计算和内存占用有严格限制的应用需求。
新闻
- 2023-12-18:添加了两个说明 1,2 解释了训练曲线的变化、项目进度和错误修复。
- 2023-10-03:添加了使用llama.cpp进行推测解码的示例。请查看 speculative_decoding/README.md。
- 2023-10-02:1. 刚刚发布了1T token的检查点。2. 我们在这里记录了所有中间检查点。
- 2023-09-28:添加了Discord服务器。
- 2023-09-18:1. 我们添加了一个聊天演示,让您可以立即试用TinyLlama-Chat-V0.1。
- 2023-09-16:1. 我们发布了在503B token上训练的中间检查点。2. 我们发布了一个在OpenAssistant上微调的聊天模型,并添加了简单的微调脚本。3. 在EVAL.md中添加并记录了更多评估基准。
评估
您可以在EVAL.md中找到TinyLlama的评估结果。
发布计划
我们将按照以下计划发布中间检查点。
基础模型:
日期 | HF 检查点 | Token数 | 步骤 | 常识平均分 |
---|---|---|---|---|
2023-09-01 | Pythia-1.0B | 300B | 143k | 48.30 |
2023-09-04 | TinyLlama-1.1B-intermediate-step-50k-105b | 105B | 50k | 46.11 |
2023-09-16 | TinyLlama-1.1B-intermediate-step-240k-503b | 503B | 240K | 48.28 |
2023-10-01 | TinyLlama-1.1B-intermediate-step-480k-1T | 1T | 480k | 50.22 |
2023-11-04 | TinyLlama-1.1B-intermediate-step-715k-1.5T | 1.5T | 715k | 51.28 |
2023-11-20 | TinyLlama-1.1B-intermediate-step-955k-2T | 2T | 955k | 51.64 |
2023-12-11 | TinyLlama-1.1B-intermediate-step-1195k-2.5T | 2.5T | 1195k | 53.86 |
2023-12-28 | TinyLlama-1.1B-intermediate-step-1431k-3T | 3T | 1431k | 52.99 |
我们正在撰写一份说明,提供从2T到2.5T检查点显著改进的可能解释(这与bos_id问题有关)
聊天模型:
日期 | HF 检查点 | Token数 | 步骤 | 常识平均分 |
---|---|---|---|---|
2023-09-16 | TinyLlama-1.1B-Chat-V0.1 | 503B | 240K | 49.57 |
2023-10-1 | TinyLlama-1.1B-Chat-V0.3 | 1T | 480K | 51.36 |
2023-11-04 | TinyLlama-1.1B-Chat-V0.4 | 1.5T | 715K | 52.30 |
请注意,基础模型的学习率尚未降低,因此我们建议您也使用微调后的聊天模型。
同时,您可以在这里实时跟踪交叉熵损失。
潜在用途
小而强大的语言模型在许多应用中都很有用。以下是一些潜在用途:
- 辅助更大模型的推测解码。(参见Andrej Karpathy的教程)
- 部署在内存和计算能力受限的边缘设备上,用于诸如无需互联网连接的实时机器翻译等功能(4位量化的TinyLlama-1.1B的权重仅占用637 MB)。
- 在视频游戏中实现实时对话生成。
此外,我们的代码可以作为对预训练5亿以下参数语言模型感兴趣的爱好者的参考,无需过早深入研究Megatron-LM。
训练细节
以下是我们训练设置的一些详细信息:
设置 | 描述 |
---|---|
参数 | 1.1B |
注意力变体 | 分组查询注意力 |
模型大小 | 层数: 22, 头数: 32, 查询组数: 4, 嵌入大小: 2048, 中间大小 (Swiglu): 5632 |
序列长度 | 2048 |
批量大小 | 200万个token (2048 * 1024) |
学习率 | 4e-4 |
学习率调度 | 余弦退火,2000步预热。参见Issue 27了解一个小bug |
训练数据 | Slimpajama & Starcoderdata |
数据预处理 | 排除了Slimpajama的GitHub子集;从Starcoderdata中抽样所有代码 |
合并数据集大小 | 约950B个token |
训练期间的总token数 | 3万亿(略多于3个周期/1430k步) |
自然语言与代码比例 | 7:3 |
硬件 | 16个A100-40G GPU |
极速快
我们的代码库支持以下功能:
- 使用FSDP进行多GPU和多节点分布式训练。
- flash attention 2。
- 融合层归一化。
- 融合swiglu。
- 融合交叉熵损失。
- 融合旋转位置编码。
致谢:flash attention 2、融合层归一化、融合交叉熵损失和融合旋转位置编码来自FlashAttention仓库。融合swiglu来自xformers。
得益于这些优化,我们在每个A100-40G GPU上实现了每秒24k个token的吞吐量,这相当于56%的模型浮点运算利用率,且不使用激活检查点(我们预计在A100-80G上MFU会更高)。这意味着您可以在32小时内用8个A100训练一个符合chinchilla最优的TinyLlama(1.1B参数,22B token)。这些优化还大大减少了内存占用,让我们能够将1.1B参数的模型塞进40GB的GPU内存,并以每GPU 16k token的批量大小进行训练。您也可以在3090/4090 GPU上预训练TinyLlama,只需使用更小的每GPU批量大小。 下面是我们代码库与Pythia和MPT训练速度的比较。
模型 | 在300B token上花费的A100 GPU小时数 |
---|---|
TinyLlama-1.1B | 3456 |
Pythia-1.0B | 4830 |
MPT-1.3B | 7920 |
Pythia的数据来自他们的论文。MPT的数据来自这里,其中他们说MPT-1.3B"在440个A100-40GB上训练了大约半天",处理了200B个token。
TinyLlama是一个相对较小的模型,使用分组查询注意力,这意味着它在推理时也很快。以下是我们测量的一些吞吐量:
预训练
有关如何预训练TinyLlama的说明,请参阅PRETRAIN.md。
微调
我们在 sft 中包含了一个简单的全参数微调和推理脚本。我们的 V0.1 聊天模型就是使用这个脚本进行微调的。我们使用的微调数据集是 openassistant-guanaco。 对于 RAM 小于 4GB 的微调,我们建议您参考 Qlora 和 bitsandbytes 仓库。 我们没有进行广泛的超参数调优,也没有选择更高性能的微调数据集。我们希望社区能够探索 TinyLlama 的微调,并开发出更好的聊天模型。我将在此仓库中包含社区微调的模型。
待办事项
本项目仍在积极开发中。我们是一个非常小的团队。非常感谢社区的反馈和贡献。以下是我们计划进行的一些工作:
- 添加在其他数据集上预训练的脚本。
- 序列长度外推。
- 测试 Llama-2-7B 的推测性解码。
- 测试 RTX 3090/4090 的吞吐量。
- 添加微调脚本。
- 对下游任务进行适当的模型评估。
- 在手机上运行的演示。
- 探索检索增强。
致谢
本仓库基于 lit-gpt 和 flash-attention 构建。如果您还不了解这些优秀的开源项目,一定要去探索一下!
@online{lit-gpt,
author = {Lightning AI},
title = {Lit-GPT},
url = {https://github.com/Lightning-AI/lit-gpt},
year = {2023},
}
@article{dao2023flashattention2,
title ={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
author ={Dao, Tri},
year ={2023}
}
引用
本项目目前由新加坡科技设计大学 StatNLP 研究组的 Peiyuan Zhang *、Guangtao Zeng *、Tianduo Wang 和 Wei Lu 共同贡献。
如果您认为我们的工作有价值,请引用:
@misc{zhang2024tinyllama,
title={TinyLlama: An Open-Source Small Language Model},
author={Peiyuan Zhang and Guangtao Zeng and Tianduo Wang and Wei Lu},
year={2024},
eprint={2401.02385},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
常见问题
1. 为什么对一个 1.1B 的模型进行如此长时间的预训练是有意义的?这是否与 Chinchilla 缩放定律相矛盾?
上图是 Llama 2 论文中的训练损失曲线。这里我引用该论文的一段话:"我们观察到,在 2T 个 Token 的预训练之后,模型仍然没有显示出任何饱和的迹象"。这就是为什么我们认为对一个 1.1B 的模型进行 3T 个 Token 的预训练是合理的。即使损失曲线最终不会下降,我们仍然可以研究饱和现象并从中学到一些东西。
2. "饱和"是什么意思?
这张图来自 Pythia 论文,显示了 LAMBADA 准确率随总训练 Token 数(300B)的变化。"饱和"一词特指 70M 和 160M 模型。值得注意的是,即使是 410M 模型在 300B Token 时也没有饱和,它仍然呈现上升趋势,类似于更大模型的趋势。