PyTorch VAE
2021/12/22更新: 增加了对PyTorch Lightning 1.5.6版本的支持,并清理了代码。
一个使用PyTorch实现的变分自编码器(VAE)集合,重点在于可重复性。此项目旨在为许多流行的VAE模型提供一个快速且简单的工作示例。所有模型都在 CelebA 数据集 上进行训练,以保持一致性和比较。所有模型的架构尽可能保持相似,使用相同的层,除非原始论文需要完全不同的架构(例如,VQ VAE使用残差层且不使用Batch-Norm,这与其他模型不同)。以下是每个模型的结果。
要求
- Python >= 3.5
- PyTorch >= 1.3
- Pytorch Lightning >= 0.6.0 (GitHub Repo)
- 支持CUDA的计算设备
安装
$ git clone https://github.com/AntixK/PyTorch-VAE
$ cd PyTorch-VAE
$ pip install -r requirements.txt
使用
$ cd PyTorch-VAE
$ python run.py -c configs/<config-file-name.yaml>
配置文件模板
model_params:
name: "<VAE模型名称>"
in_channels: 3
latent_dim:
. # 模型所需的其他参数
.
.
data_params:
data_path: "<celebA数据集的路径>"
train_batch_size: 64 # 最好是平方数
val_batch_size: 64
patch_size: 64 # 模型设计为在此尺寸下工作
num_workers: 4
exp_params:
manual_seed: 1265
LR: 0.005
weight_decay:
. # 训练所需的其他参数,例如调度器等。
.
.
trainer_params:
gpus: 1
max_epochs: 100
gradient_clip_val: 1.5
.
.
.
logging_params:
save_dir: "logs/"
name: "<实验名称>"
查看TensorBoard日志
$ cd logs/<experiment name>/version_<the version you want>
$ tensorboard --logdir .
注意: 默认数据集为CelebA。然而,由于谷歌云盘上的文件结构更改,下载该数据集时存在许多问题。因此,建议直接从谷歌云盘下载文件,并解压到您选择的路径中。配置文件中默认的路径为Data/celeba/img_align_celeba
。但您可以根据自己的喜好进行更改。
结果
模型 | 论文 | 重建效果 | 样本 |
---|
VAE ([代码][vae_code], [配置][vae_config]) | 链接 | | |
Conditional VAE ([代码][cvae_code], [配置][cvae_config]) | 链接 | | |
WAE - MMD (RBF Kernel) ([代码][wae_code], [配置][wae_rbf_config]) | 链接 | | |
WAE - MMD (IMQ Kernel) ([代码][wae_code], [配置][wae_imq_config]) | 链接 | | |
Beta-VAE ([代码][bvae_code], [配置][bbvae_config]) | 链接 | | |
Disentangled Beta-VAE ([代码][bvae_code], [配置][bhvae_config]) | 链接 | | |
Beta-TC-VAE ([代码][btcvae_code], [配置][btcvae_config]) | 链接 | | |
IWAE (K = 5) ([代码][iwae_code], [配置][iwae_config]) | 链接 | | |
MIWAE (K = 5, M = 3) ([代码][miwae_code], [配置][miwae_config]) | 链接 | | |
DFCVAE ([代码][dfcvae_code], [配置][dfcvae_config]) | 链接 | | |
MSSIM VAE ([代码][mssimvae_code], [配置][mssimvae_config]) | 链接 | | |
Categorical VAE ([代码][catvae_code], [配置][catvae_config]) | 链接 | | |
Joint VAE ([代码][jointvae_code], [配置][jointvae_config]) | 链接 | | |
Info VAE ([代码][infovae_code], [配置][infovae_config]) | 链接 | | |
LogCosh VAE ([代码][logcoshvae_code], [配置][logcoshvae_config]) | 链接 | | |
SWAE (200 Projections) ([代码][swae_code], [配置][swae_config]) | 链接 | | |
VQ-VAE (K = 512, D = 64) ([代码][vqvae_code], [配置][vqvae_config]) | 链接 | | **N/A |
[SOURCE_TEXT]: [vae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/vae.yaml | | | |
[cvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/cvae.yaml | | | |
[bbvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/bbvae.yaml | | | |
[bhvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/bhvae.yaml | | | |
[btcvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/betatc_vae.yaml | | | |
[wae_rbf_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/wae_mmd_rbf.yaml | | | |
[wae_imq_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/wae_mmd_imq.yaml | | | |
[iwae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/iwae.yaml | | | |
[miwae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/miwae.yaml | | | |
[swae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/swae.yaml | | | |
[jointvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/joint_vae.yaml | | | |
[dfcvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/dfc_vae.yaml | | | |
[mssimvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/mssim_vae.yaml | | | |
[logcoshvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/logcosh_vae.yaml | | | |
[catvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/cat_vae.yaml | | | |
[infovae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/infovae.yaml | | | |
[vqvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/vq_vae.yaml | | | |
[dipvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/dip_vae.yaml | | | |