项目介绍:Distill-SD
概述
Distill-SD 项目通过知识蒸馏技术,打造了更小更快的 Stable Diffusion 模型。这些缩小版的模型能够在保持接近质量的前提下,显著提高图像生成的速度,并优化存储空间使用。此项目的实施基于 BK-SDM 的研究论文中的理论。
项目组成
- data.py:包含下载训练数据的脚本。
- distill_training.py:用于根据研究论文中的方法训练 U-net,用户可以调整配置以训练不同尺寸的模型(如 sd_small 或 sd_tiny)、批次大小及其他超参数。这部分代码源于 Huggingface 的 diffusers 库。
- LoRA 训练与从检查点开始的训练可以通过标准的 diffusers 脚本完成。
训练详情
知识蒸馏训练类似于一个大模型(教师模型)和小模型(学生模型)的学习过程。教师模型利用大规模数据进行训练,而学生模型在较小规模数据集上进行训练,目标是模仿教师模型的输出。Distill-SD 使用了 SG161222/Realistic_Vision_V4.0 的 U-net 模型作为教师模型,并从 recastai/LAION-art-EN-improved-captions 数据集中选择了一部分数据进行训练。
最终的训练损失包含:教师模型和学生模型之间的噪声预测 MSE 损失、实际噪声和预测噪声之间的任务损失、以及 U-net 中每个块后的预测损失总和。
参数情况
- 常规 Stable Diffusion U-net 参数:859,520,964
- SD_Small U-net 参数:579,384,964
- SD_Tiny U-net 参数:323,384,964
使用方法
可以通过 Python 脚本在 GPU 上运行图像生成。用户需要提供提示语和负面提示语,以及适当的模型路径等配置参数。
训练模型
Distill-SD 的训练脚本类似于 diffusers 的文本到图像微调脚本,但增加了一些参数,例如:
--distill_level
:指定使用的模型类型,如 "sd_small" 或 "sd_tiny"。--output_weight
和--feature-weight
:用于缩放输出和特征级别的知识蒸馏损失。
速度优势
相较于完整版本,Distill-SD 模型可实现高达 100% 的推断速度提升和 30% 的显存使用减少,其训练和 LORA 训练也具备加速性能。
局限性
虽然 Distill-SD 模型在许多方面都已经很优秀,但其效果尚未达到生产质量水平。在多概念或组合性上仍有进一步提高的空间。
研究路线图
团队计划在未来开发 SDXL 蒸馏模型、进一步优化 SD-1.5 基础模型、应用 Flash Attention-2 和 TensorRT 等技术以提高模型的训练和推断速度。
致谢
项目得到了 Nota AI 在模型压缩领域研究的启发和支持。在此向他们的研究工作表达谢意。
通过以上的改进和计划,Distill-SD 项目继续致力于将模型压缩及快速高效的图像生成推向新高。