Project Icon

SpeeD

通过时间步长优化实现扩散模型训练加速

SpeeD是一种创新的扩散模型训练加速技术,通过对时间步长的深入分析和优化,将训练过程分为加速、减速和收敛三个区域。该方法采用重采样和重加权策略,实现了训练速度的显著提升。SpeeD易于与现有模型集成,能有效提高扩散模型的训练效率,为图像生成等任务提供了新的解决方案。

仔细审视时间步长可使扩散模型训练速度提升三倍

如果你喜欢SpeeD,请在GitHub上给我们一个星标⭐以获取最新更新。

论文 | 项目主页 | Hugging Face

本仓库包含了题为"仔细审视时间步长可使扩散模型训练速度提升三倍"的研究论文的代码和实现细节。在这篇论文中,我们介绍了SpeeD,一种用于加速扩散模型训练的新方法。

作者

😮 亮点

我们的方法易于兼容,可以加速扩散模型的训练。

比较

✒️ 动机

受以下对时间步长观察的启发,我们提出了重采样 + 重加权策略,如下所示。

仔细观察时间步长,我们发现时间步长可以分为三个区域:加速区、减速区和收敛区。收敛区对应时间步长的样本对训练的益处有限,而这些时间步长占用了最多的时间。从经验上看,这些样本的训练损失比其他两个区域的要低得多。

动机

非对称采样:抑制收敛区时间步长的出现。

变化感知加权:在扩散过程中变化更快的时间步长被赋予更多权重。

方法

在训练扩散模型时获取权重和采样t

def t_sample(self, n, device): if self.faster: t = torch.multinomial(self.p, n // 2 + 1, replacement=True).to(device) # 双重采样,可以平衡多步骤任务训练 dual_t = torch.where(t < self.meaningful_steps, self.meaningful_steps - t, t - self.meaningful_steps) t = torch.cat([t, dual_t], dim=0)[:n] weights = self.weights else: # 如果 t = torch.randint(0, self.num_timesteps, (n,), device=device) weights = None

return t, weights

你可以通过设置 diffusion.faster=True 来启用我们的加速模块。

# 配置文件
diffusion:
    timestep_respacing: '250'
    faster: true  #启用训练加速模块

🛠️ 要求和安装

这个代码库不使用硬件加速技术,实验环境并不复杂。

你可以创建一个新的conda环境:

conda env create -f environment.yml
conda activate speed

或者通过以下方式安装必要的包:

pip install -r requirements.txt

如有必要,我们将提供更多方法(如docker)来方便配置实验环境。

🗝️ 教程

我们提供了一个完整的生成任务流程,包括训练推理测试。目前的代码仅兼容类别条件图像生成任务。我们将在未来兼容更多关于扩散的生成任务。

我们重构了facebookresearch/DiT的代码,并使用OmegaConf加载配置。配置文件加载规则是递归的,以便更容易修改参数。简单来说,后面路径中的文件将覆盖base.yaml中的先前设置。

你可以通过修改配置文件和命令行来修改实验设置。关于配置读取的更多细节写在configs/README.md中。

对于每个实验,你必须通过命令提供两个参数,

-c: 配置路径;
-p: 阶段,包括['train', 'inference', 'sample']。

训练 & 推理

基线

使用256x256 ImageNet数据集和DiT-XL/2模型的类别条件图像生成任务。

# 训练:训练扩散模型并保存检查点
torchrun --nproc_per_node=8 main.py -c configs/image/imagenet_256/base.yaml -p train
# 推理:生成用于测试的样本
torchrun --nproc_per_node=8 main.py -c configs/image/imagenet_256/base.yaml -p inference
# 采样:为可视化采样一些图像
python main.py -c configs/image/imagenet_256/base.yaml -p sample

消融

你可以通过修改配置文件和命令行来修改实验设置。关于配置的更多细节在configs/README.md中。

例如,通过命令行更改采样时的无分类器引导比例:

python main.py -c configs/image/imagenet_256/base.yaml -p sample guidance_scale=1.5

测试

测试生成任务需要推理的结果。关于测试的更多细节在evaluations中。

🔒 许可

本项目的大部分内容根据LICENSE文件中的Apache 2.0许可发布。

✏️引用

如果你发现我们的代码在你的研究中有用,请考虑给一个星星⭐和引用📝。

@article{wang2024closer,
      title={A Closer Look at Time Steps is Worthy of Triple Speed-Up for Diffusion Model Training}, 
      author={Kai Wang, Yukun Zhou, Mingjia Shi, Zhihang Yuan, Yuzhang Shang, Xiaojiang Peng, Hanwang Zhang and Yang You},
      year={2024},
      journal={arXiv preprint arXiv:2405.17403},
}

👍 致谢

我们感谢Tianyi Li、Yuchen Zhang、Yuxin Li、Zhaoyang Zeng和Yanqing Liu对这项工作的评论。Kai Wang(想法、写作、故事、演示)、Yukun Zhou(实现)和Mingjia Shi(理论、写作、演示)对这项工作做出了同等贡献。Xiaojiang Peng、Hanwang Zhang和Yang You是平等的指导。Xiaojiang Peng是通讯作者。

我们感谢以下杰出的工作和对开源的慷慨贡献。

  • DiT:可扩展的基于Transformer的扩散模型。
  • Open-Sora:Open-Sora:为所有人民主化高效视频制作
  • OpenDiT:DiT训练的加速器。我们从OpenDiT采用了训练过程中有价值的加速策略。
项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号