Würstchen-Prior 项目介绍
Würstchen-Prior 是一个基于扩散模型的项目,其创新点在于能在高度压缩的图像潜在空间中工作。这意味着它可以在保持图像细节的情况下,大幅降低训练和推断的计算成本。相比于传统方法仅能实现4倍到8倍的空间压缩,Würstchen 实现了惊人的42倍空间压缩。而这要得益于其独特的设计,包括两个压缩阶段:阶段A和阶段B。
阶段介绍
阶段A和阶段B
阶段A使用了VQGAN(矢量量化生成对抗网络),而阶段B则使用了扩散自编码器。这两个阶段的配合不仅实现了高效的压缩,还保证了图像的细节能够被精确重建。
阶段C(Prior)
Prior,即阶段C,是一个文本条件模型,它在由阶段A和B编码的图像潜在空间中进行工作。在推断过程中,Prior负责根据给定的文本生成图像的潜在表示,然后这些潜在表示会被送回阶段A和B,以解码出像素级的图像。
图像尺寸与适应性
Würstchen 被训练在1024x1024到1536x1536的图像分辨率上,有时甚至可以在1024x2048的分辨率下得到不错的输出。阶段C对新的分辨率适应极快,因此在2048x2048分辨率下进行微调计算代价很低。
如何运行
可以使用以下示例代码来运行 Würstchen 管道。需要注意的是,这需要配合 https://huggingface.co/warp-ai/wuerstchen
一起使用:
import torch
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
device = "cuda"
dtype = torch.float16
num_images_per_prompt = 2
prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
"warp-ai/wuerstchen-prior", torch_dtype=dtype
).to(device)
decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained(
"warp-ai/wuerstchen", torch_dtype=dtype
).to(device)
caption = "Anthropomorphic cat dressed as a fire fighter"
negative_prompt = ""
prior_output = prior_pipeline(
prompt=caption,
height=1024,
width=1536,
timesteps=DEFAULT_STAGE_C_TIMESTEPS,
negative_prompt=negative_prompt,
guidance_scale=4.0,
num_images_per_prompt=num_images_per_prompt,
)
decoder_output = decoder_pipeline(
image_embeddings=prior_output.image_embeddings,
prompt=caption,
negative_prompt=negative_prompt,
guidance_scale=0.0,
output_type="pil",
).images
模型细节
Würstchen 的开发者是 Pablo Pernias 和 Dominic Rampas。这是一个基于扩散的文本到图像生成模型,使用了固定的预训练文本编码器 (CLIP ViT-bigG/14)。该模型采用 MIT 许可证进行发布。
环境影响
根据机器学习影响计算器的估算,Würstchen v2在使用AWS的A100 PCIe 40GB硬件设备下,工作24602小时,预计产生了2275.68 kg的二氧化碳当量。
通过这一介绍,相信大家对 Würstchen-Prior 项目有了一个更直观与实际的理解。这个项目不仅在技术上具有创新性,而且在实际应用中也更具备高效性。