项目介绍:DAAM(针对稳定扩散的跨注意力解释)
项目概述
在这项名为“Diffusion Attentive Attribution Maps”(DAAM)的项目中,研究人员提出了一种基于跨注意力机制的方法,以解释“稳定扩散”(Stable Diffusion)图像生成模型。它通过生成归因地图来展示模型在生成过程中对于不同单词的重要性,帮助用户理解模型是如何从文本描述生成出对应的图像的。
功能与特性
支持复杂的图像生成:该项目已经升级以支持Stable Diffusion XL (SDXL)和Diffusers 0.21.1,能够处理更加复杂和细致的图像生成任务。
便捷的安装与使用:用户可以通过简单的命令来安装DAAM:pip install daam
。项目的灵活性也允许用户克隆代码库进行深入的研究和开发。
交互式演示:用户可以通过本地运行daam-demo
,在浏览器中体验DAAM的功能,与HuggingFace Spaces的在线演示效果一致。
命令行实用工具:DAAM提供了简单的命令行工具让用户快速体验。使用示例如下:
$ mkdir -p daam-test && cd daam-test
$ daam "A dog running across the field."
生成的结果包括图片和每个关键字的热力图。
库使用:开发者可以通过Python代码导入DAAM以进行更复杂的操作。用户可以按照以下方式使用DAAM:
from daam import trace, set_seed
from diffusers import DiffusionPipeline
from matplotlib import pyplot as plt
import torch
model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
device = 'cuda'
pipe = DiffusionPipeline.from_pretrained(model_id, use_auth_token=True, torch_dtype=torch.float16, use_safetensors=True, variant='fp16')
pipe = pipe.to(device)
prompt = 'A dog runs across the field'
gen = set_seed(0)
with torch.no_grad():
with trace(pipe) as tc:
out = pipe(prompt, num_inference_steps=50, generator=gen)
heat_map = tc.compute_global_heat_map()
heat_map = heat_map.compute_word_heat_map('dog')
heat_map.plot_overlay(out.images[0])
plt.show()
实验数据管理:支持DAAM地图的序列化和反序列化,方便实验数据的保存与加载。
相关资源
- DAAM-i2i:这是DAAM应用于图像到图像归因的扩展项目。
- 学习资源:有多个视频教程帮助用户快速上手和了解DAAM的使用,提供不同版本的代码演示和Colab笔记本。
引用
如果您在研究中使用了DAAM,可以参考以下文献进行引用:
@inproceedings{tang2023daam,
title = "What the {DAAM}: Interpreting Stable Diffusion Using Cross Attention",
author = "Tang, Raphael and
Liu, Linqing and
Pandey, Akshat and
Jiang, Zhiying and
Yang, Gefei and
Kumar, Karun and
Stenetorp, Pontus and
Lin, Jimmy and
Ture, Ferhan",
booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
year = "2023",
}
结论
DAAM项目为用户提供了理解稳定扩散模型生成机制的全新视角,结合跨注意力机制的应用,用户可以更直观地了解模型的行为和单词在生成过程中的重要性。项目还提供易于使用的工具和教程资源,以帮助更多用户和研究人员快速上手。