EdgeConnect: 基于对抗性边缘学习的生成图像修补
介绍:
我们开发了一种新的图像修补方法,通过模拟艺术家的工作方式(先画线,再填色)更好地再现了充满细节的填充区域。我们提出一个两阶段的对抗模型EdgeConnect,由边缘生成器和图像完成网络组成。边缘生成器生成图像缺失区域(包括规则和不规则)的边缘,图像完成网络则使用这些生成的边缘作为先验来填补缺失区域。系统的详细描述请参考我们的论文。
(a) 输入图像,显示缺失区域。缺失区域用白色表示。(b) 计算出的边缘掩码。黑色边缘是使用Canny边缘检测器计算出来的(对于可用区域),而蓝色边缘是由边缘生成器网络生成的。(c) 使用所提方法的图像修补结果。
前置条件
- Python 3
- PyTorch 1.0
- NVIDIA GPU + CUDA cuDNN
安装
- 克隆此仓库:
git clone https://github.com/knazeri/edge-connect.git
cd edge-connect
- 从 http://pytorch.org 安装 PyTorch 及其依赖项。
- 安装 Python 需求:
pip install -r requirements.txt
数据集
1) 图像
我们使用Places2,CelebA 和Paris Street-View 数据集。要在完整数据集上训练模型,请从官方网站下载数据集。
下载完毕后,运行scripts/flist.py
生成训练、测试和验证集文件列表。例如,要在Places2数据集上生成训练集文件列表,运行:
mkdir datasets
python ./scripts/flist.py --path path_to_places2_train_set --output ./datasets/places_train.flist
2) 不规则掩码
我们的模型在Liu等提供的不规则掩码数据集上进行训练。你可以从他们的网站下载公开可用的不规则掩码数据集。
或者,你可以下载Karim Iskakov提供的Quick Draw不规则掩码数据集,该数据集由人手绘制的五千万笔画组成。
请使用scripts/flist.py
来生成训练、测试和验证掩码文件列表,如上所述。
入门指南
使用以下链接下载预训练模型,并将它们复制到./checkpoints
目录下。
Places2 | CelebA | Paris-StreetView
或者,你可以运行以下脚本来自动下载预训练模型:
bash ./scripts/download_model.sh
1) 训练
要训练模型,请创建一个类似于示例配置文件的config.yaml
文件,并将其复制到你的检查点目录下。阅读配置指南了解更多关于模型配置的信息。
EdgeConnect的训练分为三个阶段:1) 训练边缘模型,2) 训练修补模型,3) 训练联合模型。要训练模型:
python train.py --model [阶段] --checkpoints [检查点路径]
例如,要在Places2数据集上训练边缘模型,将其保存在./checkpoints/places2
目录下:
python train.py --model 1 --checkpoints ./checkpoints/places2
不同数据集的模型收敛时间不同。例如,Places2数据集在一到两个周期内收敛,而较小的数据集如CelebA则需要大约40个周期才能收敛。你可以通过更改配置文件中的MAX_ITERS
值来设置训练迭代次数。
2) 测试
要测试模型,请创建一个类似于示例配置文件的config.yaml
文件,并将其复制到你的检查点目录下。阅读配置指南了解更多关于模型配置的信息。
你可以在所有三个阶段测试模型:1) 边缘模型,2) 修补模型,3) 联合模型。在每种情况下,你需要提供一个输入图像(带掩码的图像)和一个灰度掩码文件。请确保掩码文件覆盖输入图像中的整个掩码区域。要测试模型:
python test.py \
--model [阶段] \
--checkpoints [检查点路径] \
--input [输入目录或文件路径] \
--mask [掩码目录或掩码文件路径] \
--output [输出目录路径]
我们在./examples
目录下提供了一些测试示例。请下载预训练模型并运行:
python test.py \
--checkpoints ./checkpoints/places2
--input ./examples/places2/images
--mask ./examples/places2/masks
--output ./checkpoints/results
此脚本将使用./examples/places2/mask
目录中的相应掩码修补./examples/places2/images
中的所有图像,并将结果保存到./checkpoints/results
目录中。默认情况下,test.py
脚本在阶段3运行(--model=3
)。
3) 评估
要评估模型,你首先需要在测试模式下运行模型针对你的验证集,并将结果保存在磁盘。我们提供一个工具./scripts/metrics.py
使用PSNR、SSIM和平均绝对误差评估模型:
python ./scripts/metrics.py --data-path [验证集路径] --output-path [模型输出路径]
要测量Fréchet Inception Distance(FID评分),运行./scripts/fid_score.py
。我们使用了来自这里的PyTorch实现,它使用了PyTorch的Inception模型的预训练权重。
python ./scripts/fid_score.py --path [验证集路径, 模型输出路径] --gpu [使用的GPU id]
可选的边缘检测
默认情况下,我们使用Canny边缘检测器从输入图像中提取边缘信息。如果您想用外部边缘检测(例如Holistically-Nested Edge Detection)训练模型,您需要为整个训练/测试集生成边缘图作为预处理,并使用scripts/flist.py
生成相应的文件列表,如上所述。请确保文件名和目录结构与您的训练/测试集匹配。您可以通过在配置文件中指定EDGE=2
切换到外部边缘检测。
模型配置
模型配置存储在检查点目录下的config.yaml
文件中。以下表格提供了配置文件中所有可用选项的文档:
一般模型配置
选项 | 描述 |
---|---|
MODE | 1: 训练, 2: 测试, 3: 评估 |
MODEL | 1: 边缘模型, 2: 修复模型, 3: 边缘修复模型, 4: 联合模型 |
MASK | 1: 随机块, 2: 一半, 3: 外部, 4: 外部 + 随机块, 5: 外部 + 随机块 + 一半 |
EDGE | 1: canny, 2: 外部 |
NMS | 0: 无非最大抑制, 1: 对外部边缘进行非最大抑制 |
SEED | 随机数生成器种子 |
GPU | GPU id 列表,用逗号分隔,如 [0,1] |
DEBUG | 0: 无调试, 1: 调试模式 |
VERBOSE | 0: 无详细信息, 1: 在输出控制台输出详细统计信息 |
训练、测试和验证集加载配置
选项 | 描述 |
---|---|
TRAIN_FLIST | 包含训练集文件列表的文本文件 |
VAL_FLIST | 包含验证集文件列表的文本文件 |
TEST_FLIST | 包含测试集文件列表的文本文件 |
TRAIN_EDGE_FLIST | 包含训练集外部边缘文件列表的文本文件(仅在EDGE=2时) |
VAL_EDGE_FLIST | 包含验证集外部边缘文件列表的文本文件(仅在EDGE=2时) |
TEST_EDGE_FLIST | 包含测试集外部边缘文件列表的文本文件(仅在EDGE=2时) |
TRAIN_MASK_FLIST | 包含训练集掩码文件列表的文本文件(仅在MASK=3, 4, 5时) |
VAL_MASK_FLIST | 包含验证集掩码文件列表的文本文件(仅在MASK=3, 4, 5时) |
TEST_MASK_FLIST | 包含测试集掩码文件列表的文本文件(仅在MASK=3, 4, 5时) |
训练模式配置
选项 | 默认值 | 描述 |
---|---|---|
LR | 0.0001 | 学习率 |
D2G_LR | 0.1 | 判别器/生成器学习率比率 |
BETA1 | 0.0 | adam优化器beta1 |
BETA2 | 0.9 | adam优化器beta2 |
BATCH_SIZE | 8 | 输入批量大小 |
INPUT_SIZE | 256 | 训练的输入图像大小(0为原始大小) |
SIGMA | 2 | Canny边缘检测器中使用的高斯滤波器的标准差 (0: 随机, -1: 无边缘) |
MAX_ITERS | 2e6 | 最大训练迭代次数 |
EDGE_THRESHOLD | 0.5 | 边缘检测阈值(0-1) |
L1_LOSS_WEIGHT | 1 | l1损失权重 |
FM_LOSS_WEIGHT | 10 | 特征匹配损失权重 |
STYLE_LOSS_WEIGHT | 1 | 风格损失权重 |
CONTENT_LOSS_WEIGHT | 1 | 感知损失权重 |
INPAINT_ADV_LOSS_WEIGHT | 0.01 | 对抗损失权重 |
GAN_LOSS | nsgan | nsgan: 非饱和性GAN, lsgan: 最小二乘GAN, hinge: hinge loss GAN |
GAN_POOL_SIZE | 0 | 假图像池大小 |
SAVE_INTERVAL | 1000 | 保存模型前需等待的迭代次数(0: 不保存) |
EVAL_INTERVAL | 0 | 评估模型前需等待的迭代次数(0: 不评估) |
LOG_INTERVAL | 10 | 记录训练损失前需等待的迭代次数(0: 不记录) |
SAMPLE_INTERVAL | 1000 | 保存样本前需等待的迭代次数(0: 不保存) |
SAMPLE_SIZE | 12 | 每次采样间隔的样本数量 |
许可证
根据知识共享署名-非商业性使用 4.0 国际许可证授权。
除非另有说明,否则此内容根据CC BY-NC许可证发布,这意味着您可以复制、混合、转换和基于该内容进行创作,前提是不用于商业目的并给予适当的署名和提供许可证链接。
引用
如果您在研究中使用此代码,请引用我们的论文EdgeConnect: Generative Image Inpainting with Adversarial Edge Learning或EdgeConnect: Structure Guided Image Inpainting using Edge Prediction:
@inproceedings{nazeri2019edgeconnect,
title={EdgeConnect: Generative Image Inpainting with Adversarial Edge Learning},
author={Nazeri, Kamyar and Ng, Eric and Joseph, Tony and Qureshi, Faisal and Ebrahimi, Mehran},
journal={arXiv preprint},
year={2019},
}
@InProceedings{Nazeri_2019_ICCV,
title = {EdgeConnect: Structure Guided Image Inpainting using Edge Prediction},
author = {Nazeri, Kamyar and Ng, Eric and Joseph, Tony and Qureshi, Faisal and Ebrahimi, Mehran},
booktitle = {The IEEE International Conference on Computer Vision (ICCV) Workshops},
month = {Oct},
year = {2019}
}