Project Icon

gigagan-pytorch

最新生成对抗网络GigaGAN的实现,优化训练收敛和模型稳定性

gigagan-pytorch项目实现了Adobe最新的生成对抗网络GigaGAN,优化了跳层激励和辅助重建损失,以提升训练收敛速度和模型稳定性。项目支持高分辨率上采样器,具备混合精度和多GPU训练功能。适合寻求高效稳定GAN训练的开发者和研究人员。可加入Discord社区,与LAION合作获取更多支持。

GigaGAN - Pytorch

实现 GigaGAN (项目页面),Adobe 最新的 SOTA GAN。

我还会添加一些来自轻量级 gan的发现,以加快收敛(跳层激励)和更好的稳定性(判别器中的辅助重建损失)

它还将包含 1k - 4k 上采样器的代码,这是我认为这篇论文的亮点。

如果您有兴趣与 LAION 社区一起帮助复制,请加入 Join us on Discord

感谢

  • 感谢 StabilityAI🤗 Huggingface 的慷慨赞助,以及我的其他赞助商们,使我能够独立开源人工智能。

  • 感谢 🤗 Huggingface 的 accelerate 库

  • 感谢 OpenClip 的所有维护者,他们的 SOTA 开源对比学习文本-图像模型

  • 感谢 Xavier 的非常有帮助的代码审查,以及关于如何构建判别器中的尺度不变性进行的讨论!

  • 感谢 @CerebralSeed 提出生成器和上采样器初始采样代码的拉取请求!

  • 感谢 Keerth 的代码审查和指出与论文的一些差异!

安装

$ pip install gigagan-pytorch

使用方法

简单的无条件 GAN,供初学者使用

import torch

from gigagan_pytorch import (
    GigaGAN,
    ImageDataset
)

gan = GigaGAN(
    generator = dict(
        dim_capacity = 8,
        style_network = dict(
            dim = 64,
            depth = 4
        ),
        image_size = 256,
        dim_max = 512,
        num_skip_layers_excite = 4,
        unconditional = True
    ),
    discriminator = dict(
        dim_capacity = 16,
        dim_max = 512,
        image_size = 256,
        num_skip_layers_excite = 4,
        unconditional = True
    ),
    amp = True
).cuda()

# 数据集

dataset = ImageDataset(
    folder = '/path/to/your/data',
    image_size = 256
)

dataloader = dataset.get_dataloader(batch_size = 1)

# 在训练前必须为 GAN 设置数据加载器

gan.set_dataloader(dataloader)

# 交替训练判别器和生成器
# 在这个例子中训练 100 步,批量大小 1,梯度累积 8 次

gan(
    steps = 100,
    grad_accum_every = 8
)

# 经过大量训练后

images = gan.generate(batch_size = 4) # (4, 3, 256, 256)

对于无条件 Unet 上采样器

import torch
from gigagan_pytorch import (
    GigaGAN,
    ImageDataset
)

gan = GigaGAN(
    train_upsampler = True,     # 将其设置为 True
    generator = dict(
        style_network = dict(
            dim = 64,
            depth = 4
        ),
        dim = 32,
        image_size = 256,
        input_image_size = 64,
        unconditional = True
    ),
    discriminator = dict(
        dim_capacity = 16,
        dim_max = 512,
        image_size = 256,
        num_skip_layers_excite = 4,
        multiscale_input_resolutions = (128,),
        unconditional = True
    ),
    amp = True
).cuda()

dataset = ImageDataset(
    folder = '/path/to/your/data',
    image_size = 256
)

dataloader = dataset.get_dataloader(batch_size = 1)

gan.set_dataloader(dataloader)

# 交替训练判别器和生成器
# 在这个例子中训练 100 步,批量大小 1,梯度累积 8 次

gan(
    steps = 100,
    grad_accum_every = 8
)

# 经过大量训练后

lowres = torch.randn(1, 3, 64, 64).cuda()

images = gan.generate(lowres) # (1, 3, 256, 256)

损失

  • G - 生成器
  • MSG - 多尺度生成器
  • D - 判别器
  • MSD - 多尺度判别器
  • GP - 梯度惩罚
  • SSL - 判别器中的辅助重建(来自轻量级 GAN)
  • VD - 视觉辅助判别器
  • VG - 视觉辅助生成器
  • CL - 生成器对比损失
  • MAL - 匹配感知损失

一个健康的运行应该使 GMSGDMSD 的值保持在010之间,并且通常保持相当稳定。如果在 1k 训练步后这些值仍然保持在三位数,那就说明出了问题。生成器和判别器的值偶尔下降为负数是可以的,但它应该恢复到上述范围内。

GPSSL 应该朝0推动。GP可能会偶尔飙升;我喜欢把这想象成网络经历了一些顿悟

多 GPU 训练

GigaGAN 类现在配有 🤗 Accelerator。您可以使用 accelerate CLI 轻松完成两步多 GPU 训练

在项目的根目录下,训练脚本所在的位置,运行

$ accelerate config

然后,在同一目录下

$ accelerate launch train.py

待办事项

  • 确保它可以无条件训练

  • 阅读相关论文并完成所有 3 个辅助损失

    • 匹配感知损失
    • clip 损失
    • 视觉辅助判别器损失
    • 在判别器的任意阶段添加重建损失(轻量级 gan)
    • 弄清随机投影如何在投影 gan 中使用
    • 视觉辅助判别器需要从 CLIP 中提取 N 层
    • 弄清楚是丢弃 CLS 标记并重塑为图像尺寸用于卷积,还是坚持使用注意力并用自适应层归一化进行调节 - 还要在无条件情况下关闭视觉辅助 gan
  • unet 上采样器

    • 添加自适应卷积
    • 修改 unet 的后期输出 RGB 残差,并将 RGB 传递给判别器。使判别器对传入的 RGB 具有鲁棒性
    • 对 unet 进行像素洗牌上采样
  • 对多尺度输入和输出进行代码审查,因为论文有点模糊

  • 添加上采样网络架构

  • 确保无条件工作于基础生成器和上采样器

  • 确保文本条件训练工作于基础生成器和上采样器

  • 通过随机采样补丁使重建更高效

  • 确保生成器和判别器也能接受预编码的 CLIP 文本编码

  • 审查辅助损失

    • 为生成器添加对比损失
    • 添加视觉辅助损失
    • 为视觉辅助判别器添加梯度惩罚 - 可选
    • 添加匹配感知损失 - 弄清旋转文本条件是否足以用于错配(无需从数据加载器中抽取额外批次)
    • 确保与匹配感知损失一起工作的梯度累积
    • 匹配感知损失运行且稳定
    • 视觉辅助训练
  • 添加一些可微分增强技术,这是旧 GAN 时代的已验证技巧

    • 移除任何自动 RGB 处理的魔法,并显式传递它 - 为判别器提供处理真实图像到正确多尺度的函数
    • 首先添加水平翻转
  • 将所有调制投影移动到 adaptive conv2d 类中

  • 添加加速

    • 单机工作
    • 混合精度工作(确保梯度惩罚正确缩放),注意手动缩放器的保存和重新加载,从 imagen-pytorch 借用
    • 确保多 GPU 对于单机工作
    • 让别人尝试多台机器
  • clip 应该对所有模块都是可选的,并由 GigaGAN 管理,用一次处理的文本 -> 文本嵌入

  • 添加选择多尺度维度的随机子集的能力,以提高效率

  • 从 lightweight|stylegan2-pytorch 移植 CLI

  • 挂接 laion 数据集以进行文本-图像

引用

@misc{https://doi.org/10.48550/arxiv.2303.05511,
    url     = {https://arxiv.org/abs/2303.05511},
    author  = {Kang, Minguk and Zhu, Jun-Yan and Zhang, Richard and Park, Jaesik and Shechtman, Eli and Paris, Sylvain and Park, Taesung},  
    title   = {Scaling up GANs for Text-to-Image Synthesis},
    publisher = {arXiv},
    year    = {2023},
    copyright = {arXiv.org perpetual, non-exclusive license}
}
@article{Liu2021TowardsFA,
    title   = {Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis},
    author  = {Bingchen Liu and Yizhe Zhu and Kunpeng Song and A. Elgammal},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2101.04775}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@inproceedings{Karras2020ada,
    title     = {Training Generative Adversarial Networks with Limited Data},
    author    = {Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila},
    booktitle = {Proc. NeurIPS},
    year      = {2020}
}
@article{Xu2024VideoGigaGANTD,
    title   = {VideoGigaGAN: Towards Detail-rich Video Super-Resolution},
    author  = {Yiran Xu and Taesung Park and Richard Zhang and Yang Zhou and Eli Shechtman and Feng Liu and Jia-Bin Huang and Difan Liu},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2404.12388},
    url     ={https://api.semanticscholar.org/CorpusID:269214195}
}
项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号