FreeU: 扩散U-Net中的免费午餐
南洋理工大学S-Lab
CVPR2024 口头报告
我们提出了FreeU,一种可以显著提高扩散模型样本质量的方法,而且完全免费:无需训练,无需引入额外参数,也不会增加内存或采样时间。
:open_book: 更多视觉结果,请查看我们的项目主页
使用方法
- 演示也可在上使用(非常感谢AK和所有HF团队的支持)。
- 您可以通过运行
python demos/app.py
在本地使用gradio演示。
FreeU 代码
def Fourier_filter(x, threshold, scale):
# FFT
x_freq = fft.fftn(x, dim=(-2, -1))
x_freq = fft.fftshift(x_freq, dim=(-2, -1))
B, C, H, W = x_freq.shape
mask = torch.ones((B, C, H, W)).cuda()
crow, ccol = H // 2, W //2
mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
x_freq = x_freq * mask
# IFFT
x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
return x_filtered
class Free_UNetModel(UNetModel):
"""
:param b1: 解码器第一阶段块的主干因子。
:param b2: 解码器第二阶段块的主干因子。
:param s1: 解码器第一阶段块的跳跃因子。
:param s2: 解码器第二阶段块的跳跃因子。
"""
def __init__(
self,
b1,
b2,
s1,
s2,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.b1 = b1
self.b2 = b2
self.s1 = s1
self.s2 = s2
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""
将模型应用于输入批次。
:param x: [N x C x ...] 形状的输入张量。
:param timesteps: 1维时间步批次。
:param context: 通过交叉注意力插入的条件。
:param y: [N] 形状的标签张量,如果是类别条件模型。
:return: [N x C x ...] 形状的输出张量。
"""
assert (y is not None) == (
self.num_classes is not None
), "必须且仅当模型是类别条件时指定y"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
for module in self.output_blocks:
hs_ = hs.pop()
# --------------- FreeU 代码 -----------------------
# 仅对前两个阶段进行操作
if h.shape[1] == 1280:
hidden_mean = h.mean(1).unsqueeze(1)
B = hidden_mean.shape[0]
hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
h[:,:640] = h[:,:640] * ((self.b1 - 1 ) * hidden_mean + 1)
hs_ = Fourier_filter(hs_, threshold=1, scale=self.s1)
if h.shape[1] == 640:
hidden_mean = h.mean(1).unsqueeze(1)
B = hidden_mean.shape[0]
hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
h[:,:320] = h[:,:320] * ((self.b2 - 1 ) * hidden_mean + 1)
hs_ = Fourier_filter(hs_, threshold=1, scale=self.s2)
# ---------------------------------------------------------
h = th.cat([h, hs_], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)
else:
return self.out(h)
参数
根据您的模型、图像/视频风格或任务,可以自由调整这些参数。以下参数仅供参考。
SD1.4:(即将更新)
b1: 1.3, b2: 1.4, s1: 0.9, s2: 0.2
SD1.5:(即将更新)
b1: 1.5, b2: 1.6, s1: 0.9, s2: 0.2
SD2.1
b1: 1.1, b2: 1.2, s1: 0.9, s2: 0.2
b1: 1.4, b2: 1.6, s1: 0.9, s2: 0.2
SDXL
b1: 1.3, b2: 1.4, s1: 0.9, s2: 0.2 SDXL 结果
更多参数的范围
在尝试额外参数时,可以考虑以下范围:
- b1: 1 ≤ b1 ≤ 1.2
- b2: 1.2 ≤ b2 ≤ 1.6
- s1: s1 ≤ 1
- s2: s2 ≤ 1
社区结果
如果您尝试了FreeU并想分享您的结果,请告诉我,我们可以在这里放上链接。
- SDXL 来自 Nasir Khalid
- comfyUI 来自 Abraham
- SD2.1 来自 Justin DuJardin
- SDXL 来自 Sebastian
- SDXL 来自 tintwotin
- ComfyUI-FreeU(YouTube)
- ComfyUI-FreeU(中文)
- Rerender
- Collaborative-Diffusion
BibTeX
@inproceedings{si2023freeu,
title={FreeU: Free Lunch in Diffusion U-Net},
author={Si, Chenyang and Huang, Ziqi and Jiang, Yuming and Liu, Ziwei},
booktitle={CVPR},
year={2024}
}
:newspaper_roll: 许可证
根据MIT许可证分发。有关更多信息,请参阅LICENSE
。