Würstchen项目介绍
项目概述
Würstchen是一个创新的扩散模型,专门用于根据文本生成图像。它的独特之处在于其在高度压缩的潜在图像空间中工作。这种高效的数据压缩大大降低了训练和推断的计算成本。传统的图像分辨率如1024x1024的训练需要很高的资源消耗,而Würstchen通过其独特的设计实现了惊人的42倍空间压缩。通常,其他方法在超过16倍空间压缩时会出现图像细节的丢失,而Würstchen通过两阶段压缩策略成功克服了这一难题。第一阶段(Stage A)和第二阶段(Stage B)分别是VQGAN和扩散自动编码器。第三个模型(Stage C)则在高度压缩的潜在空间中进行训练。这种结构使得模型比当前的顶尖模型使用更少的计算资源,实现更快速和经济的推断。
解码器
解码器是Würstchen中的关键组件,它包括Stage A和Stage B。当解码器接收到图像嵌入(这些嵌入可能是Prior阶段生成的,也可能是从真实图像中提取的),它能将这些潜在编码还原成像素图像。具体来说,Stage B负责将图像嵌入解码到VQGAN空间,而Stage A会将潜在编码解码为最终的像素图像。两者结合,实现了42倍的空间压缩。
注意
目前图像重建过程中会存在损耗,尤其在人脸、手部等细节上,我们肉眼可能会明显察觉这一点。项目团队正在努力提高未来的重建质量。
图像尺寸
Würstchen模型在1024x1024到1536x1536之间的图像分辨率上进行了训练,我们也观察到在1024x2048等分辨率下能产生理想输出。用户可以根据需要自行尝试不同分辨率。Prior(Stage C)对新分辨率适应非常迅速,因此对2048x2048的细化调整应该是计算有效的。
如何使用
要运行这个流水线,你需要结合前阶段组件:
import torch
from diffusers import AutoPipelineForText2Image
device = "cuda"
dtype = torch.float16
pipeline = AutoPipelineForText2Image.from_pretrained(
"warp-diffusion/wuerstchen", torch_dtype=dtype
).to(device)
caption = "Anthropomorphic cat dressed as a fire fighter"
output = pipeline(
prompt=caption,
height=1024,
width=1024,
prior_guidance_scale=4.0,
decoder_guidance_scale=0.0,
).images
图像采样时间
在不同批量大小下,Würstchen和Stable Diffusion XL的推断时间进行对比(使用A100显卡)。图示左侧显示使用torch > 2.0时的推断时间,右侧则是在预先应用torch.compile
后的推断时间。
模型详情
- 开发者: Pablo Pernias, Dominic Rampas
- 模型类型: 基于扩散的文本到图像生成模型
- 语言: 英语
- 许可证: MIT
- 模型描述: 该模型可用于根据文本提示生成和修改图像。它是Würstchen论文中Stage C风格的扩散模型,使用预训练的文本编码器。
环境影响
Würstchen v2 估计排放
基于提供的信息,使用机器学习影响计算器估算出大约的二氧化碳排放:
- 硬件类型: A100 PCIe 40GB
- 使用时长: 24602小时
- 云提供商: AWS
- 计算区域: 美国东部
- 二氧化碳排放: 2275.68 kg CO2 eq.