Project Icon

transformer_latent_diffusion

基于 PyTorch 的 Transformer 潜在扩散文本生图模型

Transformer Latent Diffusion 是一个基于 PyTorch 的开源项目,实现了文本到图像的潜在扩散模型。该模型体积小、生成速度快、性能合理,可在单 GPU 上快速训练。项目代码简洁,依赖少,注重数据质量。它提供数据处理工具,支持自定义训练,并进行了多项性能优化。项目展示了 256 分辨率随机样本和 CLIP 插值等生成示例。

Transformer 潜在扩散

在PyTorch中使用Transformer核心的自包含文本到图像潜在扩散。

尝试自己的输入: 在Colab中打开

以下是从一个从头开始训练260k迭代(在1个A100上约32小时)的100MM模型中生成的一些随机示例(256分辨率):

image

Clip插值示例:

一张猫的照片 → 一幅超级赛亚人猫的动漫绘画,artstation:

image

一只可爱的大灰猫头鹰 → 梵高的星夜:

image

请注意,该模型尚未收敛,还需要更多训练。

更高分辨率:

通过上采样位置编码,该模型还可以生成512或1024像素的图像,只需少量微调。以下是在额外100k张512像素图像和30k张1024像素图像上微调约2小时(在A100上)的模型示例。1024像素的图像有时缺乏全局连贯性 - 这里还会有更多内容:

image image

简介:

这个仓库的主要目标是构建一个可访问的PyTorch扩散模型,该模型:

  • 快速(接近实时生成)
  • 小巧(~100MM参数)
  • 合理良好(当然不是最先进的)
  • 可以在单个GPU上在合理的时间内训练(A100或同等设备上不到50小时)
  • 简单的自包含代码库(模型+训练循环约400行PyTorch代码,依赖很少)
  • 使用约100万张图像,注重数据质量而非数量,并提供下载和处理数据的代码

目录:

代码库:

代码使用纯PyTorch编写,尽可能减少依赖。

  • transformer_blocks.py - 与transformer去噪器相关的基本transformer构建块
  • denoiser.py - 去噪器transformer的架构
  • train.py。训练循环使用accelerate,因此可以根据需要扩展到多个GPU。
  • diffusion.py。使用反向扩散从噪声生成图像的类。简短(~60行)且自包含。
  • data.py。用于下载图像/文本并处理扩散模型所需特征的数据工具。

使用方法:

如果你有自己的URL+标题数据集,在数据上训练模型的过程包括两个步骤:

  1. 使用train.download_and_process_data获取潜在和文本编码作为numpy文件。参见在Colab中打开中的notebook示例,从这个HuggingFace数据集下载并处理2000张图像。

  2. 在accelerate notebook_launcher中使用train.main函数 - 参见在Colab中打开中的colab notebook,从头开始在10万张图像上训练模型。请注意,这会从这里下载已预处理的潜在变量和嵌入,但你也可以使用在步骤1中保存的任何.npy文件。

安装和依赖:

要安装软件包和依赖项,请运行: pip install git+https://github.com/apapiu/transformer_latent_diffusion.git

  • PyTorch numpy einops 用于模型构建
  • wandb tqdm 用于日志记录和进度条
  • accelerate 用于训练循环和多GPU支持
  • img2dataset webdataset torchvision 用于数据下载和图像处理
  • diffusers clip 用于预训练的VAE和CLIP文本模型

基本推理代码:

from tld.configs import LTDConfig, DenoiserConfig, TrainConfig
from tld.diffusion import DiffusionTransformer

denoiser_cfg = DenoiserConfig(n_channels=4) #在此配置你的模型
cfg = LTDConfig(denoiser_cfg=denoiser_cfg)

diffusion_transformer = DiffusionTransformer(cfg)

out = diffusion_transformer.generate_image_from_text(prompt="一只可爱的猫")

基本训练代码:

from tld.train import main
from tld.configs import ModelConfig, DataConfig

data_config = DataConfig(
    latent_path="latents.npy", text_emb_path="text_emb.npy", val_path="val_emb.npy"
)

model_cfg = ModelConfig(
    data_config=data_config,
    train_config=TrainConfig(n_epoch=100, save_model=False, compile=False, use_wandb=False),
)

main(model_cfg)

#或者在笔记本中在2个GPU上运行训练过程:
#notebook_launcher(main, model_cfg, num_processes=2)

测试:

test_diffuser.py中的测试是开始理解代码的好地方。你可以通过运行pytest -s来运行所有测试。

Github Actions:

我配置了一些github action来运行测试、检查代码风格并构建一些docker镜像 - 如果你只是探索代码,你可以注释掉这些或删除.github/workflows文件夹。

配置:

配置在tld/configs.py中以数据类的形式存在。默认值总是可以被覆盖。例如:DenoiserConfig(n_layers=16)保留所有默认值,除了n_layers。你也可以将配置保存为JSON并像这样加载:DenoiserConfig(**json.load(file))

速度:

我尽可能地加快训练和推理速度,通过:

  • 使用混合精度进行训练 + [sdpa]
  • 预计算所有潜在和文本嵌入
  • 使用float16精度进行推理
  • 使用[sdpa]进行闪光注意力2 + 在PyTorch 2.0+上使用torch.compile()
  • 使用高性能采样器(DPM-Solver++(2M)),可以在约15步内获得良好结果。

生成36张图片批次(15次迭代)的时间在:

  • T4:约3.5秒
  • A100:约0.6秒 事实上,在A100上,VAE成为了瓶颈,尽管它只使用一次。

用于灵感的代码库:

示例:

更多使用100MM模型生成的示例 - 点击照片查看提示和其他参数,如cfg和种子: [图片链接]

外绘模型:

我还在原始101MM模型的基础上微调了一个外绘模型。我必须修改原始输入conv2d补丁为8通道,并将掩码通道参数初始化为零。其余架构保持不变。

下面我重复应用外绘模型,根据提示"一个赛博朋克市场"生成一个somewhat一致的场景:

[图片链接]

数据处理:

data.py中,我有一些helper函数来处理图像和标题。流程如下:

  • 使用img2dataset从包含URL和标题的数据框中下载图像。
  • 使用CLIP对提示进行编码,使用VAE在web2dataset数据生成器上对图像进行编码为潜在表示。
  • 保存潜在表示和文本嵌入以供未来训练使用。 这种方法有两个优点。一是VAE编码计算成本较高,如果每个epoch都进行编码会影响训练时间。二是我们可以在处理后丢弃图像。对于3256256的图像,潜在维度是43232,所以每个潜在变量大约4KB(使用uint8量化;参见这里)。这意味着100万个潜在变量的大小"仅"为4GB,即使在内存中也很容易处理。存储原始图像的大小会大48倍。

架构:

denoiser类的代码见这里。

denoiser模型是基于DiT和Pixart-Alpha架构的Transformer模型,尽管做了不少修改和简化。使用Transformer作为denoiser与大多数扩散模型不同,因为其他模型主要使用基于CNN的U-NET作为去噪主干。我决定使用Transformer有几个原因。一是我想从头开始实验和学习如何构建和训练Transformer。其次,Transformer在训练和推理上都很快,并且将从未来的性能进步(硬件和软件)中受益最多。

Transformer本身并不是为空间数据而设计的,起初我发现很多输出都很"斑驳"。为了改善这一点,我在transformer的FFN层中添加了深度卷积(这在Local ViT论文中引入)。这允许模型以很小的计算成本混合相邻的像素。

图像+文本+噪声编码:

图像潜在输入为43232,我们使用2的patch大小来构建256个展平的422=16维输入"像素"。然后将这些投影到嵌入维度并输入transformer块。

文本和噪声条件非常简单 - 我们将池化的CLIP文本嵌入(ViT/L14 - 768维)和正弦噪声嵌入连接起来,作为每个transformer块中交叉注意力层的输入。不使用未池化的CLIP嵌入。

训练:

基础模型有1.01亿参数,12层,嵌入维度为768。我在A100上使用256的批量大小和3e-4的学习率进行训练。预热使用1000步。由于计算限制,我没有对这个配置进行消融实验。

训练和扩散设置:

我们训练一个去噪transformer,它接受以下三个输入:

  • noise_level(从0到1采样,更多值集中在接近0处 - 我使用beta分布)
  • 用随机噪声污染的图像潜在变量(x)
    • 对于0到1之间的给定noise_level,污染如下:
      • x_noisy = x*(1-noise_level) + eps*noise_level,其中eps ~ np.random.normal(0, 1)
  • 文本提示的CLIP嵌入
    • 可以将其视为文本提示的数值表示。
    • 这里我们使用池化的文本嵌入(ViT/L14为768维)

输出是去噪后图像潜在变量的预测 - 称之为f(x_noisy)

模型被训练以最小化预测和实际图像之间的均方误差|f(x_noisy) - x| (这里也可以使用绝对误差)。注意,为了保持简单,我没有在这里重新参数化噪声的损失。

使用这个模型,我们然后通过以下方式迭代地从随机噪声生成图像:

for i in range(len(self.noise_levels) - 1):

  curr_noise, next_noise = self.noise_levels[i], self.noise_levels[i + 1]

  # 预测原始去噪图像:
  x0_pred = predict_x_zero(new_img, label, curr_noise)

  # next_noise级别的新图像是旧图像和预测x0的加权平均:
  new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise

predict_x_zero方法通过结合条件和无条件预测使用无分类器引导:x0_pred = class_guidance * x0_pred_conditional + (1 - class_guidance) * x0_pred_unconditional

一些数学:上述方法属于VDM参数化,见Kingma等人论文的3.1节:

$$z_t = \alpha_t x + \sigma_t \epsilon, \epsilon \sim \mathcal{N}(0,1)$$

其中$z_t$是时间t时x的噪声版本。

通常,$\alpha_t$被选为$\sqrt{1-\sigma_t^2}$,使得过程保持方差。这里,我选择$\alpha_t=1-\sigma_t$,以便在图像和随机噪声之间线性插值。为什么?首先,它大大简化了更新方程,更容易理解噪声信号比将是什么样子。我还发现模型能更快地生成清晰的图像。上面的更新方程是这种参数化的DDIM模型,简化为简单的加权平均。注意,DDIM模型确定性地将随机正态噪声映射到图像 - 这有两个好处:我们可以在随机正态潜在空间中进行插值,通常需要更少的步骤就能达到不错的图像质量。

待办事项:

  • [] 如何进一步加快生成速度 - LCMs?
  • [] 添加计算FID的脚本
  • 改进训练文件中的配置
  • 更快的采样 - DDPM
项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号