贝叶斯流网络
这是 Alex Graves、Rupesh Kumar Srivastava、Timothy Atkinson 和 Faustino Gomez 发表的贝叶斯流网络论文的官方代码发布。
阅读指南
model.py
包含了论文的所有主要贡献。其中包括连续和离散数据的贝叶斯流定义,以及连续时间和离散时间的损失函数。详细信息请参见该文件中基类的注释。probability.py
定义了模型使用的概率分布。train.py
、test.py
和sample.py
是用于训练、测试和采样的脚本(使用说明见下文)。data.py
包含与数据加载和处理相关的实用工具。networks/
包含模型使用的网络架构实现。
环境设置
# 创建一个包含所有依赖项(包括pytorch和CUDA)的新conda环境
conda env create -f env.yml
conda activate bfn
# 或者,在现有的pytorch环境中安装额外的依赖项
pip install accelerate==0.19.0 matplotlib omegaconf rich
# 可选,如果你想启用neptune.ai日志记录
pip install neptune
训练
论文中的模型可以使用configs
目录中提供的配置文件进行训练,方法如下:
# 在1个GPU上进行mnist实验
accelerate launch train.py config_file=configs/mnist_discrete.yaml
# 在1个GPU(A100)上进行cifar10实验
accelerate launch train.py config_file=configs/cifar10_discretized_256bins.yaml
# 在8个GPU(A100)上进行text8实验
accelerate launch --multi_gpu --num_processes=8 --num_machines=1 --dynamo_backend=no --mixed_precision=fp16 train.py config_file=configs/text8_discrete.yaml
测试
[!注意] 根据你的GPU,你可能需要调整
test.py
中用于测试的批量大小。
# 可选:下载预训练的检查点(确保你已安装git-lfs:https://git-lfs.com/)
git clone git@hf.co:rupspace/pretrained-BFNs
# 计算MNIST的784步损失
python test.py seed=1 config_file=./configs/mnist_discrete.yaml load_model=./pretrained-BFNs/mnist_ema.pt n_steps=784 n_repeats=2000
# 计算CIFAR-10的10步损失
python test.py seed=1 config_file=./configs/cifar10_discretized_256bins.yaml load_model=./pretrained-BFNs/cifar10_256d_ema.pt n_steps=10 n_repeats=100
# 计算text8的连续时间损失
python test.py seed=1 config_file=./configs/text8_discrete.yaml load_model=./pretrained-BFNs/text8_ema.pt n_steps=0 n_repeats=1
[!重要] 所有计算结果将以每数据维度的nats为单位。要转换为比特,请除以ln(2)。
采样
你可以按如下方式从预训练模型中进行采样(根据需要更改选项):
# 使用100步采样4张二值化MNIST图像
python sample.py seed=1 config_file=./configs/mnist_discrete.yaml load_model=./pretrained-BFNs/mnist_ema.pt samples_shape="[4, 28, 28, 1]" n_steps=100 save_file=./samples_mnist.pt
# 使用1000步采样4张CIFAR-10 16位图像,作为离散化数据建模
python sample.py seed=1 config_file=./configs/cifar10_discretized_16bins.yaml load_model=./pretrained-BFNs/cifar10_16d_ema.pt samples_shape="[4, 32, 32, 3]" n_steps=1000 save_file=./samples_cifar.pt
# 使用100步采样2个长度为256的text8序列
python sample.py seed=1 config_file=./configs/text8_discrete.yaml load_model=./pretrained-BFNs/text8_ema.pt samples_shape="[2, 256]" n_steps=100 save_file=./samples_text8.pt
采样结果作为PyTorch张量存储在save_file
中,可以通过加载它们并使用data.py
中的batch_to_images
和batch_to_str
实用工具进行可视化。
例如:
# batch_to_images返回一个matplotlib Figure对象
python -c "import torch; from data import batch_to_images; batch_to_images(torch.load('./samples_mnist.pt')).savefig('mnist.png')"
python -c "import torch; from data import batch_to_images; batch_to_images(torch.load('./samples_cifar.pt')).savefig('cifar.png')"
# batch_to_str返回一个str列表
python -c "import torch; from data import batch_to_str; print(batch_to_str(torch.load('./samples_text8.pt')))"
可重复性
如果需要高度的可重复性(例如在采样过程中),请设置以下内容:
torch.set_float32_matmul_precision("highest")
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
致谢
我们感谢@Higgcz对实验基础设施和代码发布的慷慨支持。