Project Icon

gan-compression

条件生成对抗网络的高效压缩技术

GAN Compression项目提出了一种通用的条件生成对抗网络压缩方法,可将pix2pix、CycleGAN等模型的计算量减少9-29倍,同时保持视觉质量。该方法适用于多种生成器架构和学习目标,支持配对和非配对数据。项目开源了预训练模型、演示和教程,便于研究和应用。

GAN 压缩

项目主页 | 论文 | 视频 | 幻灯片

[新消息!] GAN 压缩已被 T-PAMI 接收! 我们在 arXiv v4 发布了 T-PAMI 版本!

[新消息!] 我们发布了交互式演示的代码,并包含了用 TVM 调优的模型。现在在 Jetson Nano GPU 上可以达到 8FPS 的速度!

[新消息!] 新增对 MUNIT 的支持,这是一种多模态无监督图像到图像转换方法! 请按照测试命令测试预训练模型,并参考教程训练您自己的模型!

示意图 我们提出了 GAN 压缩,这是一种压缩条件 GAN 的通用方法。我们的方法可以将广泛使用的条件 GAN 模型(包括 pix2pix、CycleGAN、MUNIT 和 GauGAN)的计算量减少 9-29 倍,同时保持视觉保真度。我们的方法对各种生成器架构、学习目标以及配对和非配对设置都非常有效。

GAN 压缩: 面向交互式条件 GAN 的高效架构
李牧阳林吉丁瑶瑶刘志坚朱俊彦韩松
麻省理工学院、Adobe 研究院、上海交通大学
CVPR 2020 会议论文。

演示

概述

概述GAN 压缩框架: ① 给定一个预训练的教师生成器 G',我们通过权重共享蒸馏出一个更小的"一劳永逸"学生生成器 G,其中包含所有可能的通道数。我们在每个训练步骤中为学生生成器 G 选择不同的通道数。② 然后,我们从"一劳永逸"生成器中提取许多子生成器并评估它们的性能。无需重新训练,这是"一劳永逸"生成器的优势。③ 最后,我们根据压缩比目标和性能目标(FID 或 mIoU),使用暴力搜索或进化搜索方法选择最佳子生成器。可选地,我们进行额外的微调,得到最终的压缩模型。

性能

性能

GAN 压缩将 pix2pix、cycleGAN 和 GauGAN 的计算量减少了 9-21 倍,模型大小减少了 4.6-33 倍。

Colab 笔记本

PyTorch Colab 笔记本: CycleGANpix2pix

先决条件

  • Linux
  • Python 3
  • CPU 或 NVIDIA GPU + CUDA CuDNN

入门指南

安装

  • 克隆此仓库:

    git clone git@github.com:mit-han-lab/gan-compression.git
    cd gan-compression
    
  • 安装 PyTorch 1.4 和其他依赖项(如 torchvision)。

    • 对于 pip 用户,请输入命令 pip install -r requirements.txt
    • 对于 Conda 用户,我们提供了安装脚本 scripts/conda_deps.sh。或者,您可以使用 conda env create -f environment.yml 创建新的 Conda 环境。

CycleGAN

设置

  • 下载 CycleGAN 数据集(例如,horse2zebra)。

    bash datasets/download_cyclegan_dataset.sh horse2zebra
    
  • 获取数据集真实图像的统计信息以计算 FID。我们为几个数据集提供了预准备的真实统计信息。例如,

    bash datasets/download_real_stat.sh horse2zebra A
    bash datasets/download_real_stat.sh horse2zebra B
    

应用预训练模型

  • 下载预训练模型。

    python scripts/download_model.py --model cycle_gan --task horse2zebra --stage full
    python scripts/download_model.py --model cycle_gan --task horse2zebra --stage compressed
    
  • 测试原始完整模型。

    bash scripts/cycle_gan/horse2zebra/test_full.sh
    
  • 测试压缩模型。

    bash scripts/cycle_gan/horse2zebra/test_compressed.sh
    
  • 测量两个模型的延迟。

    bash scripts/cycle_gan/horse2zebra/latency_full.sh
    bash scripts/cycle_gan/horse2zebra/latency_compressed.sh
    
  • 由于我们重新训练了模型,上述模型的结果可能与论文中的结果略有不同。我们还发布了论文中的压缩模型。如果存在这样的不一致,您可以尝试以下命令来测试我们的论文模型:

    python scripts/download_model.py --model cycle_gan --task horse2zebra --stage legacy
    bash scripts/cycle_gan/horse2zebra/test_legacy.sh
    bash scripts/cycle_gan/horse2zebra/latency_legacy.sh
    

Pix2pix

设置

  • 下载 pix2pix 数据集(例如,edges2shoes)。

    bash datasets/download_pix2pix_dataset.sh edges2shoes-r
    
  • 获取数据集真实图像的统计信息以计算 FID。我们为几个数据集提供了预准备的真实统计信息。例如,

    bash datasets/download_real_stat.sh edges2shoes-r B
    bash datasets/download_real_stat.sh edges2shoes-r subtrain_B
    

应用预训练模型

  • 下载预训练模型。

    python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage full
    python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage compressed
    
  • 测试原始完整模型。

    bash scripts/pix2pix/edges2shoes-r/test_full.sh
    
  • 测试压缩模型。

    bash scripts/pix2pix/edges2shoes-r/test_compressed.sh
    
  • 测量两个模型的延迟。

    bash scripts/pix2pix/edges2shoes-r/latency_full.sh
    bash scripts/pix2pix/edges2shoes-r/latency_compressed.sh
    
  • 由于我们重新训练了模型,上述模型的结果可能与论文中的结果略有不同。我们还发布了论文中的压缩模型。如果存在这样的不一致,您可以尝试以下命令来测试我们的论文模型:

    python scripts/download_model.py --model pix2pix --task edges2shoes-r --stage legacy
    bash scripts/pix2pix/edges2shoes-r/test_legacy.sh
    bash scripts/pix2pix/edges2shoes-r/latency_legacy.sh
    

GauGAN

设置

  • 准备 cityscapes 数据集。查看此处以准备 cityscapes 数据集。

  • 获取数据集真实图像的统计信息以计算 FID。我们为几个数据集提供了预准备的真实统计信息。例如,

    bash datasets/download_real_stat.sh cityscapes A
    

应用预训练模型

  • 下载预训练模型。

    python scripts/download_model.py --model gaugan --task cityscapes --stage full
    python scripts/download_model.py --model gaugan --task cityscapes --stage compressed
    
  • 测试原始完整模型。

    bash scripts/gaugan/cityscapes/test_full.sh
    
  • 测试压缩模型。

    bash scripts/gaugan/cityscapes/test_compressed.sh
    
  • 测量两个模型的延迟。

    bash scripts/gaugan/cityscapes/latency_full.sh
    bash scripts/gaugan/cityscapes/latency_compressed.sh
    
  • 由于我们重新训练了模型,上述模型的结果可能与论文中的结果略有不同。我们还发布了论文中的压缩模型。如果存在这样的不一致,您可以尝试以下命令来测试我们的论文模型:

    python scripts/download_model.py --model gaugan --task cityscapes --stage legacy
    bash scripts/gaugan/cityscapes/test_legacy.sh
    bash scripts/gaugan/cityscapes/latency_legacy.sh
    

MUNIT

设置

  • 准备数据集(如edges2shoes-r)。

    bash datasets/download_pix2pix_dataset.sh edges2shoes-r
    python datasets/separate_A_and_B.py --input_dir database/edges2shoes-r --output_dir database/edges2shoes-r-unaligned
    python datasets/separate_A_and_B.py --input_dir database/edges2shoes-r --output_dir database/edges2shoes-r-unaligned --phase val
    
  • 获取数据集真实图像的统计信息以计算FID。我们为多个数据集提供了预先准备好的真实统计数据。例如:

    bash datasets/download_real_stat.sh edges2shoes-r B
    bash datasets/download_real_stat.sh edges2shoes-r-unaligned subtrain_B
    

应用预训练模型

  • 下载预训练模型。

    python scripts/download_model.py --model gaugan --task cityscapes --stage full
    python scripts/download_model.py --model gaugan --task cityscapes --stage compressed
    
  • 测试原始完整模型。

    bash scripts/munit/edges2shoes-r_fast/test_full.sh
    
  • 测试压缩模型。

    bash scripts/munit/edges2shoes-r_fast/test_compressed.sh
    
  • 测量两个模型的延迟。

    bash scripts/munit/edges2shoes-r_fast/latency_full.sh
    bash scripts/munit/edges2shoes-r_fast/latency_compressed.sh
    

Cityscapes数据集

由于许可问题,我们无法提供Cityscapes数据集。请从https://cityscapes-dataset.com下载数据集,并使用脚本[prepare_cityscapes_dataset.py](datasets/prepare_cityscapes_dataset.py)进行预处理。您需要下载gtFine_trainvaltest.zipleftImg8bit_trainvaltest.zip,并将它们解压到同一文件夹中。例如,您可以将gtFineleftImg8bit放在database/cityscapes-origin中。您需要使用以下命令准备数据集:

python datasets/get_trainIds.py database/cityscapes-origin/gtFine/
python datasets/prepare_cityscapes_dataset.py \
--gtFine_dir database/cityscapes-origin/gtFine \
--leftImg8bit_dir database/cityscapes-origin/leftImg8bit \
--output_dir database/cityscapes \
--train_table_path datasets/train_table.txt \
--val_table_path datasets/val_table.txt

您将在database/cityscapes中获得预处理后的数据集,以及dataset/table.txt中的映射表(用于计算mIoU)。

为了支持mIoU计算,您需要从http://go.yf.io/drn-cityscapes-models下载预训练的DRN模型`drn-d-105_ms_cityscapes.pth`。默认情况下,我们将drn模型放在仓库的根目录中。然后,在下载我们的模型后,您就可以在cityscapes上测试我们的压缩模型了。

COCO-Stuff数据集

我们遵循与NVlabs/spade相同的COCO-Stuff数据集准备方法。具体来说,您需要从nightrome/cocostuff下载train2017.zipval2017.zipstuffthingmaps_trainval2017.zipannotations_trainval2017.zip。图像、标签和实例地图应按照datasets/coco_stuff中的相同目录结构排列。特别地,我们使用了一个结合了"物体实例地图"和"物质标签地图"边界的实例地图。为此,我们使用了一个简单的脚本datasets/coco_generate_instance_map.py

为了支持mIoU计算,您需要下载预训练的DeeplabV2模型deeplabv2_resnet101_msc-cocostuff164k-100000.pth,并将其也放在仓库的根目录中。

已发布模型的性能

以下我们展示了所有已发布模型的性能:

模型数据集方法参数数量MACs评估指标
FIDmIoU
CycleGAN马→斑马原始11.4M56.8G65.75--
GAN压缩(论文)0.342M2.67G65.33--
GAN压缩(重新训练)0.357M2.55G65.12--
快速GAN压缩0.355M2.64G65.19--
Pix2pix边缘→鞋子原始11.4M56.8G24.12--
GAN压缩(论文)0.700M4.81G26.60--
GAN压缩(重新训练)0.822M4.99G26.70--
快速GAN压缩0.703M4.83G25.76--
城市景观原始11.4M56.8G--42.06
GAN压缩(论文)0.707M5.66G--40.77
GAN压缩(重新训练)0.781M5.59G--38.63
快速GAN压缩0.867M5.61G--41.71
地图→航拍照片
原始11.4M56.8G47.91--
GAN压缩0.746M4.68G48.02--
快速GAN压缩0.708M4.53G48.67--
GauGAN城市景观原始93.0M281G57.6061.04
GAN压缩(论文)20.4M31.7G55.1961.22
GAN压缩(重新训练)21.0M31.2G56.4360.29
快速GAN压缩20.2M31.3G56.2561.17
COCO-Stuff原始97.5M191G21.3838.78
快速GAN压缩26.0M35.5G25.0635.05
MUNIT边缘→鞋子原始15.0M77.3G30.13--
快速GAN压缩1.10M2.63G30.53--

训练

请参考快速GAN压缩GAN压缩的教程,了解如何在我们的数据集和您自己的数据集上训练模型。

FID计算

要计算FID分数,您需要从数据集的真实图像中获取一些统计信息。我们提供了一个脚本get_real_stat.py来提取统计信息。例如,对于edges2shoes数据集,您可以运行以下命令:

python get_real_stat.py \
--dataroot database/edges2shoes-r \
--output_path real_stat/edges2shoes-r_B.npz \
--direction AtoB

对于成对的图像到图像转换(pix2pix和GauGAN),我们计算生成的测试图像与真实测试图像之间的FID。对于非成对的图像到图像转换(CycleGAN),我们计算生成的测试图像与真实训练+测试图像之间的FID。这允许我们使用更多的图像进行稳定的FID评估,就像之前的无条件GAN研究中所做的那样。这两种协议的差异很小。当使用真实测试图像而不是真实训练+测试图像时,我们压缩的CycleGAN模型的FID增加了4。

代码结构

为了帮助用户更好地理解和使用我们的代码,我们简要概述了每个包和每个模块的功能和实现。

引用

如果您在研究中使用了这份代码,请引用我们的论文

@inproceedings{li2020gan,
  title={GAN Compression: Efficient Architectures for Interactive Conditional GANs},
  author={Li, Muyang and Lin, Ji and Ding, Yaoyao and Liu, Zhijian and Zhu, Jun-Yan and Han, Song},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2020}
}

致谢

我们的代码是基于pytorch-CycleGAN-and-pix2pixSPADEMUNIT开发的。

我们还要感谢pytorch-fid用于FID计算,drn用于城市景观mIoU计算,以及deeplabv2用于Coco-Stuff mIoU计算。

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号