TensorFlow和PyTorch中的变分自编码器
TensorFlow和PyTorch中变分自编码器的参考实现。
我推荐使用PyTorch版本。它包含了一个更具表现力的变分族的示例,即反自回归流。
变分推断用于将模型拟合到二值化的MNIST手写数字图像。使用推断网络(编码器)来分摊推断并在数据点之间共享参数。似然由生成网络(解码器)参数化。
博客文章:https://jaan.io/what-is-variational-autoencoder-vae-tutorial/
PyTorch实现
(anaconda环境在environment-jax.yml
中)
重要性采样用于估计Hugo Larochelle的二值MNIST数据集上的边际似然。测试集上的最终边际似然为-97.10
奈特,与已发表的数字相当。
$ python train_variational_autoencoder_pytorch.py --variational mean-field --use_gpu --data_dir $DAT --max_iterations 30000 --log_interval 10000
第0步 训练ELBO估计:-558.027 验证ELBO估计:-384.432 验证log p(x)估计:-355.430 速度:2.72e+06 样本/秒
第10000步 训练ELBO估计:-111.323 验证ELBO估计:-109.048 验证log p(x)估计:-103.746 速度:2.64e+04 样本/秒
第20000步 训练ELBO估计:-103.013 验证ELBO估计:-107.655 验证log p(x)估计:-101.275 速度:2.63e+04 样本/秒
第29999步 测试ELBO估计:-106.642 测试log p(x)估计:-100.309
总时间:2.49分钟
使用非平均场、更具表现力的变分后验近似(反自回归流,https://arxiv.org/abs/1606.04934),测试边际对数似然改善到`-95.33`奈特:
$ python train_variational_autoencoder_pytorch.py --variational flow
步骤:0 训练elbo:-578.35
步骤:0 验证elbo:-407.06 验证log p(x):-367.88
步骤:10000 训练elbo:-106.63
步骤:10000 验证elbo:-110.12 验证log p(x):-104.00
步骤:20000 训练elbo:-101.51
步骤:20000 验证elbo:-105.02 验证log p(x):-99.11
步骤:30000 训练elbo:-98.70
步骤:30000 验证elbo:-103.76 验证log p(x):-97.71
jax实现
使用jax(anaconda环境在environment-jax.yml
中),比pytorch快3倍:
$ python train_variational_autoencoder_jax.py --variational mean-field
第0步 训练ELBO估计:-566.059 验证ELBO估计:-565.755 验证log p(x)估计:-557.914 速度:2.56e+11 样本/秒
第10000步 训练ELBO估计:-98.560 验证ELBO估计:-105.725 验证log p(x)估计:-98.973 速度:7.03e+04 样本/秒
第20000步 训练ELBO估计:-109.794 验证ELBO估计:-105.756 验证log p(x)估计:-97.914 速度:4.26e+04 样本/秒
第29999步 测试ELBO估计:-104.867 测试log p(x)估计:-96.716
总时间:0.810分钟
jax中的反自回归流:
$ python train_variational_autoencoder_jax.py --variational flow
第0步 训练ELBO估计:-727.404 验证ELBO估计:-726.977 验证log p(x)估计:-713.389 速度:2.56e+11 样本/秒
第10000步 训练ELBO估计:-100.093 验证ELBO估计:-106.985 验证log p(x)估计:-99.565 速度:2.57e+04 样本/秒
第20000步 训练ELBO估计:-113.073 验证ELBO估计:-108.057 验证log p(x)估计:-98.841 速度:3.37e+04 样本/秒
第29999步 测试ELBO估计:-106.803 测试log p(x)估计:-97.620
总时间:2.350分钟
(平均场和反自回归流之间的差异可能由几个因素造成,主要是实现中缺少卷积。https://arxiv.org/pdf/1606.04934.pdf 中使用了残差块,使ELBO更接近-80奈特。)
生成GIF
- 运行
python train_variational_autoencoder_tensorflow.py
- 安装imagemagick(Mac用homebrew:https://formulae.brew.sh/formula/imagemagick 或Windows用Chocolatey:https://community.chocolatey.org/packages/imagemagick.app)
- 进入保存jpg文件的目录,运行imagemagick命令生成.gif:
convert -delay 20 -loop 0 *.jpg latent-space.gif
待办事项(需要帮助 - 欢迎发送PR!)
- 添加多GPU / TPU选项
- 为PyTorch和Jax实现添加jaxtyping支持:)以进行运行时静态类型检查(使用@beartype装饰器)