无监督对比翻译 (CUT)
视频 (1分钟) | 视频 (10分钟) | 网站 | 论文
我们提供了基于Patchwise对比学习和对抗学习的无监督图像到图像翻译的PyTorch实现。不使用手工设计的损失和逆网络。与CycleGAN相比,我们的模型训练更快,占用的内存更少。此外,我们的方法可以扩展到单图像训练,其中每个“域”仅为一个单独的图像。
对比学习用于无监督图像到图像翻译
Taesung Park, Alexei A. Efros, Richard Zhang, Jun-Yan Zhu
伯克利大学和Adobe研究所
在ECCV 2020
伪代码
import torch
cross_entropy_loss = torch.nn.CrossEntropyLoss()
# 输入: f_q (BxCxS) 和 从 H(G_enc(x)) 中采样的特征
# 输入: f_k (BxCxS) 是从 H(G_enc(G(x)) 中采样的特征
# 输入: tau 是PatchNCE损失中使用的温度
# 输出: PatchNCE损失
def PatchNCELoss(f_q, f_k, tau=0.07):
# 批量大小、通道大小和采样位置的数量
B, C, S = f_q.shape
# 计算 v * v+: BxSx1
l_pos = (f_k * f_q).sum(dim=1)[:, :, None]
# 计算 v * v-: BxSxS
l_neg = torch.bmm(f_q.transpose(1, 2), f_k)
# 对角条目不为负。去掉它们
identity_matrix = torch.eye(S)[None, :, :]
l_neg.masked_fill_(identity_matrix, -float('inf'))
# 计算logits: (B)x(S)x(S+1)
logits = torch.cat((l_pos, l_neg), dim=2) / tau
# 返回PatchNCE损失
predictions = logits.flatten(0, 1)
targets = torch.zeros(B * S, dtype=torch.long)
return cross_entropy_loss(predictions, targets)
示例结果
无监督图像到图像翻译
单图像无监督翻译
从俄蓝猫到愤怒猫
从巴黎街道到布拉诺彩色房屋
前提条件
- Linux 或 macOS
- Python 3
- CPU 或 NVIDIA GPU + CUDA CuDNN
更新日志
2020年9月12日: 添加了单图像翻译。
入门
- 克隆这个代码仓库:
git clone https://github.com/taesungp/contrastive-unpaired-translation CUT
cd CUT
-
安装 PyTorch 1.1 和其他依赖项 (例如 torchvision, visdom, dominate, gputil)。
对于 pip 用户,请输入命令
pip install -r requirements.txt
。对于 Conda 用户,你可以使用
conda env create -f environment.yml
创建一个新的 Conda 环境。
CUT 和 FastCUT 训练和测试
- 下载
grumpifycat
数据集 (论文图8. 俄罗斯蓝猫 -> 愤怒猫)
bash ./datasets/download_cut_dataset.sh grumpifycat
数据集会下载并解压到 ./datasets/grumpifycat/
。
-
要查看训练结果和损失图,请运行
python -m visdom.server
并点击网址 http://localhost:8097。 -
训练 CUT 模型:
python train.py --dataroot ./datasets/grumpifycat --name grumpycat_CUT --CUT_mode CUT
或训练 FastCUT 模型
python train.py --dataroot ./datasets/grumpifycat --name grumpycat_FastCUT --CUT_mode FastCUT
模型检查点会存储在 ./checkpoints/grumpycat_*/web
。
- 测试 CUT 模型:
python test.py --dataroot ./datasets/grumpifycat --name grumpycat_CUT --CUT_mode CUT --phase train
测试结果会保存到这里的html文件:./results/grumpifycat/latest_train/index.html
。
CUT, FastCUT 和 CycleGAN
CUT 训练时使用了身份保持损失且 lambda_NCE=1
, 而 FastCUT 在训练时没有身份损失但有较高的 lambda_NCE=10.0
。与 CycleGAN 相比,CUT 学习了更强大的分布匹配,而 FastCUT 作为一种更轻量(仅需一半GPU内存,可适配更大图像)、更快(训练速度是CycleGAN的两倍)的替代品被设计出来。有关更多细节,请参考论文。
在上图中,我们使用预训练的语义分割模型测量属于马/斑马身体的像素比例。我们发现马和斑马图像大小之间存在分布不匹配——斑马通常显得更大(36.8% vs. 17.9%)。我们的完整方法CUT有灵活性来放大马匹,以更好匹配训练统计数据,而FastCUT的行为更保守,类似于CycleGAN。
使用我们的启动脚本进行训练
请参见 experiments/grumpifycat_launcher.py
,该脚本生成上述命令行参数。启动脚本对于配置相当复杂的训练和测试命令行参数非常有用。
使用启动脚本,下面的命令生成CUT和FastCUT的训练命令。
python -m experiments grumpifycat train 0 # CUT
python -m experiments grumpifycat train 1 # FastCUT
使用启动脚本进行测试,
python -m experiments grumpifycat test 0 # CUT
python -m experiments grumpifycat test 1 # FastCUT
可能的命令有 run, run_test, launch, close 等等。请参见 experiments/__main__.py
获取所有命令。启动脚本易于定义和使用。例如,grumpifycat 启动器仅能用几行来定义:
from .tmux_launcher import Options, TmuxLauncher
class Launcher(TmuxLauncher):
def common_options(self):
return [
Options( # Command 0
dataroot="./datasets/grumpifycat",
name="grumpifycat_CUT",
CUT_mode="CUT"
),
Options( # Command 1
dataroot="./datasets/grumpifycat",
name="grumpifycat_FastCUT",
CUT_mode="FastCUT",
)
]
def commands(self):
return ["python train.py " + str(opt) for opt in self.common_options()]
def test_commands(self):
# 俄罗斯蓝猫->愤怒猫数据集没有测试分割。
# 因此,让我们将测试分割设置为"train"集。
return ["python test.py " + str(opt.set(phase='train')) for opt in self.common_options()]
应用预训练的CUT模型并评估FID
要运行预训练模型,请运行以下命令。
# 下载并解压预训练模型。权重应位于
# checkpoints/horse2zebra_cut_pretrained/latest_net_G.pth,例如。
wget http://efrosgans.eecs.berkeley.edu/CUT/pretrained_models.tar
tar -xf pretrained_models.tar
# 生成输出。可能需要调整数据集路径。
# 为此,请修改experiments/pretrained_launcher.py的行
# [id] 对应于 pretrained_launcher.py 中定义的相应命令
# 0 - CUT 于 Cityscapes
# 1 - FastCUT 于 Cityscapes
# 2 - CUT 于 Horse2Zebra
# 3 - FastCUT 于 Horse2Zebra
# 4 - CUT 于 Cat2Dog
# 5 - FastCUT 于 Cat2Dog
python -m experiments pretrained run_test [id]
评估FID。要执行此操作,请首先安装pytorch-fid,网址:https://github.com/mseitzer/pytorch-fid
pip install pytorch-fid
例如,要评估CUT的horse2zebra FID,
python -m pytorch_fid ./datasets/horse2zebra/testB/ results/horse2zebra_cut_pretrained/test_latest/images/fake_B/
要评估FastCUT的Cityscapes FID,
python -m pytorch_fid ./datasets/cityscapes/valA/ ~/projects/contrastive-unpaired-translation/results/cityscapes_fastcut_pretrained/test_latest/images/fake_B/
请注意,Cityscapes模型需要使用特殊的数据集。请阅读下文。
python -m pytorch_fid [真实测试图像路径] [生成图像路径]
注意:预训练的Cityscapes模型是在原始Cityscapes数据集的一个调整大小且JPEG压缩版本上训练和评估的。要进行评估,请下载[此](http://efrosgans.eecs.berkeley.edu/CUT/datasets/cityscapes_val_for_CUT.tar)验证集并执行评估。
### SinCUT单图像无配对训练
要训练SinCUT(单图像翻译,如图9、13和14所示),您需要:
1. 设置`--model`选项为`--model sincut`,这会调用`./models/sincut_model.py`中的配置和代码,并且
2. 指定每个领域中一个图像的数据集目录,例如此存储库中包含的示例数据集位于`./datasets/single_image_monet_etretat/`。
例如,要训练一个用于[Etretat悬崖(图13的第一张图片)](https://github.com/taesungp/contrastive-unpaired-translation/blob/master/imgs/singleimage.gif)的模型,请使用以下命令。
```bash
python train.py --model sincut --name singleimage_monet_etretat --dataroot ./datasets/single_image_monet_etretat
或者使用实验启动脚本,
python -m experiments singleimage run 0
对于单图像翻译,我们采用了StyleGAN2的网络架构组件,以及DTN和CycleGAN中使用的像素身份保存损失。具体来说,我们采用了rosinality的代码,存放于models/stylegan_networks.py
。
训练需要几个小时。要使用检查点生成最终图像,
python test.py --model sincut --name singleimage_monet_etretat --dataroot ./datasets/single_image_monet_etretat
或者简单
python -m experiments singleimage run_test 0
数据集
下载CUT/CycleGAN/pix2pix数据集。例如,
bash datasets/download_cut_datasets.sh horse2zebra
Cat2Dog数据集是从AFHQ数据集准备的。请访问https://github.com/clovaai/stargan-v2并通过此github存储库中的`bash download.sh afhq-dataset`下载AFHQ数据集。然后按以下方式重组目录。
mkdir datasets/cat2dog
ln -s datasets/cat2dog/trainA [afhq路径]/train/cat
ln -s datasets/cat2dog/trainB [afhq路径]/train/dog
ln -s datasets/cat2dog/testA [afhq路径]/test/cat
ln -s datasets/cat2dog/testB [afhq路径]/test/dog
可以从https://cityscapes-dataset.com下载Cityscapes数据集。
之后,使用脚本./datasets/prepare_cityscapes_dataset.py
准备数据集。
输入图像的预处理
输入图像的预处理(如调整大小或随机裁剪)由--preprocess
、--load_size
和--crop_size
选项控制。用法遵循CycleGAN/pix2pix存储库。
例如,默认设置--preprocess resize_and_crop --load_size 286 --crop_size 256
将输入图像调整为286x286
,然后随机裁剪大小为256x256
,作为一种数据增强方式。还可以指定其他预处理选项,详见base_dataset.py。以下是一些示例选项。
--preprocess none
: 不执行任何预处理。请注意,图像大小仍将缩放为最接近的4的倍数,因为卷积生成器无法保持相同的图像大小。--preprocess scale_width --load_size 768
: 将图像的宽度缩放为768。--preprocess scale_shortside_and_crop
: 在保持纵横比的前提下,将图像缩放到短边为load_size
,然后随机裁剪窗口大小为crop_size
。
可以通过修改get_transform()
的base_dataset.py
添加更多预处理选项。
引用
如果您将此代码用于您的研究,请引用我们的论文。
@inproceedings{park2020cut,
title={Contrastive Learning for Unpaired Image-to-Image Translation},
author={Taesung Park and Alexei A. Efros and Richard Zhang and Jun-Yan Zhu},
booktitle={European Conference on Computer Vision},
year={2020}
}
如果您使用了此存储库中包含的原始pix2pix和CycleGAN模型,请引用以下论文
@inproceedings{CycleGAN2017,
title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks},
author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A},
booktitle={IEEE International Conference on Computer Vision (ICCV)},
year={2017}
}
@inproceedings{isola2017image,
title={Image-to-Image Translation with Conditional Adversarial Networks},
author={Isola, Phillip and Zhu, Jun-Yan and Zhou, Tinghui and Efros, Alexei A},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2017}
}
感谢
感谢Allan Jabri和Phillip Isola提供有益的讨论和反馈。我们的代码基于pytorch-CycleGAN-and-pix2pix开发。感谢pytorch-fid用于FID计算,感谢drn用于mIoU计算,感谢stylegan2-pytorch用于我们单图像翻译设置中使用的StyleGAN2的PyTorch实现。