描述
条件流匹配(CFM)是一种快速训练连续标准化流(CNF)模型的方法。CFM是一种无需模拟的连续标准化流训练目标,它允许条件生成建模并加速训练和推理。CFM的性能缩小了CNF和扩散模型之间的差距。为了在机器学习社区中推广其使用,我们构建了一个专注于流匹配方法的库:TorchCFM。TorchCFM是一个展示如何训练和使用流匹配方法来处理图像生成、单细胞动力学、表格数据以及即将推出的SO(3)数据的库。
无模拟CNF训练方案的密度、向量场和轨迹:将8个高斯分布映射到两个月牙形(上图)和将单个高斯分布映射到两个月牙形(下图)。使用相同架构(3x64 MLP with SeLU激活)的动作匹配在ReLU、SiLU和SiLU激活下欠拟合,如示例代码所示,但在我们的训练设置下(Action-Matching (Swish))似乎拟合得更好。
生成GIF的模型存储在examples/models
中,可以使用此笔记本进行可视化:。
我们还在examples/notebooks/mnist_example.ipynb
中包含了一个无条件MNIST生成的示例,包括确定性和随机生成。。
torchcfm包
在我们的版本1更新中,我们将相关流匹配变体的实现提取到了torchcfm
包中。这允许抽象条件分布q(z)
的选择。torchcfm
提供以下损失函数:
ConditionalFlowMatcher
:$z = (x_0, x_1)$,$q(z) = q(x_0) q(x_1)$ExactOptimalTransportConditionalFlowMatcher
:$z = (x_0, x_1)$,$q(z) = \pi(x_0, x_1)$,其中$\pi$是精确的最优传输联合分布。这在[Tong et al. 2023a]和[Poolidan et al. 2023]中被称为"OT-CFM"和"带批量OT的多样本FM"。TargetConditionalFlowMatcher
:$z = x_1$,$q(z) = q(x_1)$,如Lipman et al. 2023中定义,学习从标准正态高斯分布到数据的流,使用条件流将高斯分布最优传输到数据点(注意这不会导致边际流成为最优传输)。SchrodingerBridgeConditionalFlowMatcher
:$z = (x_0, x_1)$,$q(z) = \pi_\epsilon(x_0, x_1)$,其中$\pi_\epsilon$是熵正则化的OT计划,尽管在实践中这通常通过小批量OT计划来近似(参见Tong et al. 2023b)。边际等价于薛定谔桥边际的流匹配变体称为SB-CFM
[Tong et al. 2023a]。当分数也已知且桥是随机的时称为[SF]2M [Tong et al. 2023b]。VariancePreservingConditionalFlowMatcher
:$z = (x_0, x_1)$ $q(z) = q(x_0) q(x_1)$,但使用条件高斯概率路径,通过三角插值在时间上保持方差,如[Albergo et al. 2023a]中所述。
如何引用
此存储库包含重现两个预印本的主要实验和图示的代码:
- 通过小批量最优传输改进和推广基于流的生成模型。我们引入了最优传输条件流匹配(OT-CFM),这是一种近似最优传输(OT)动力学公式的CFM变体。基于OT理论,OT-CFM利用静态最优传输计划以及最优概率路径和向量场来近似动态OT。
- 通过分数和流匹配的无模拟薛定谔桥。我们提出了无模拟分数和流匹配([SF]2M)。[SF]2M利用OT-CFM以及基于分数的方法来近似薛定谔桥,这是最优传输的一种随机版本。
如果您在研究中发现此代码有用,请引用以下论文(展开获取BibTeX):
A. Tong, N. Malkin, G. Huguet, Y. Zhang, J. Rector-Brooks, K. Fatras, G. Wolf, Y. Bengio. 通过小批量最优传输改进和推广基于流的生成模型, 2023.
@article{tong2024improving,
title={Improving and generalizing flow-based generative models with minibatch optimal transport},
author={Alexander Tong and Kilian FATRAS and Nikolay Malkin and Guillaume Huguet and Yanlei Zhang and Jarrid Rector-Brooks and Guy Wolf and Yoshua Bengio},
journal={Transactions on Machine Learning Research},
issn={2835-8856},
year={2024},
url={https://openreview.net/forum?id=CD9Snc73AW},
note={Expert Certification}
}
A. Tong, N. Malkin, K. Fatras, L. Atanackovic, Y. Zhang, G. Huguet, G. Wolf, Y. Bengio. 通过分数和流匹配的无模拟薛定谔桥, 2023.
@article{tong2023simulation,
title={Simulation-Free Schr{\"o}dinger Bridges via Score and Flow Matching},
author={Tong, Alexander and Malkin, Nikolay and Fatras, Kilian and Atanackovic, Lazar and Zhang, Yanlei and Huguet, Guillaume and Wolf, Guy and Bengio, Yoshua},
year={2023},
journal={arXiv preprint 2307.03672}
}
V0 -> V1
主要变更:
- 添加了FID为3.5的cifar10示例
- 添加了新的无模拟分数和流匹配(SF)2M预印本的代码
- 创建了
torchcfm
pip可安装包 - 将
pytorch-lightning
实现和实验移至runner
目录 - 将
notebooks
->examples
- 在lightning和
examples
中的笔记本中添加了图像生成实现
已实现的论文
已实现论文列表:
- 用于生成建模的流匹配(Lipman等,2023年)论文
- 流动直接且快速:学习生成和传输数据的校正流(Liu等,2023年)论文 代码
- 利用随机插值构建归一化流(Albergo等,2023年a)论文
- 行动匹配:从样本中学习随机动力学(Neklyudov等,2022年)论文 代码
- 与我们的OT-CFM方法同期的工作:多样本流匹配:使用小批量耦合拉直流(Pooladian等,2023年)论文
- 通过扩散和基于流的梯度提升树生成和填补表格数据(Jolicoeur-Martineau等)论文 代码
- 即将发布:用于蛋白质骨架生成的SE(3)-随机流匹配(Bose等)论文
如何运行
在这里运行一个简单的最小示例 。或者按照以下步骤在本地安装更高效的代码。
TorchCFM现已在pypi上发布!你可以通过以下命令安装:
pip install torchcfm
要使用完整的库及不同示例,你可以安装依赖项:
# 克隆项目
git clone https://github.com/atong01/conditional-flow-matching.git
cd conditional-flow-matching
# [可选] 创建conda环境
conda create -n torchcfm python=3.10
conda activate torchcfm
# 按照说明安装pytorch
# https://pytorch.org/get-started/
# 安装requirements
pip install -r requirements.txt
# 安装torchcfm
pip install -e .
安装我们的包后,使用以下命令运行我们的jupyter notebooks。
# 安装ipykernel
conda install -c anaconda ipykernel
# 在jupyter notebook中安装conda环境
python -m ipykernel install --user --name=torchcfm
# 使用torchcfm内核启动我们的notebooks
项目结构
目录结构如下:
│
├── examples <- Jupyter notebooks
| ├── cifar10 <- Cifar10实验
│ ├── notebooks <- 各种示例notebooks
│
│── runner <- 与库原始版本(V0)相关的所有内容
│
|── torchcfm <- 我们的流匹配方法的代码库
| ├── conditional_flow_matching.py <- CFM类
│ ├── models <- 模型架构
│ │ ├── models <- 2D示例的模型
│ │ ├── Unet <- 图像示例的Unet模型
|
├── .gitignore <- git忽略的文件列表
├── .pre-commit-config.yaml <- 代码格式化的预提交钩子配置
├── pyproject.toml <- 测试和代码检查的配置选项
├── requirements.txt <- 安装python依赖项的文件
├── setup.py <- 将项目作为包安装的文件
└── README.md
❤️ 代码贡献
这个工具箱由以下人员创建和维护:
它最初源自一个更大的私有代码库,因此失去了原始提交历史,其中包含论文其他作者的工作。
在提出问题之前,请确认:
- 该问题在当前的
main
分支上仍然存在。 - 你的python依赖项已更新到最新版本。
我们随时欢迎改进建议!
赞助商
TorchCFM的开发和维护得到了以下机构的财务支持:
许可证
Conditional-Flow-Matching采用MIT许可证。
MIT许可证
版权所有 (c) 2023 Alexander Tong
特此免费授予任何获得本软件副本和相关文档文件("软件")的人不受限制地处理本软件的权利,包括但不限于使用、复制、修改、合并、发布、分发、再许可和/或出售本软件副本的权利,以及允许向其提供本软件的人这样做,但须符合以下条件:
上述版权声明和本许可声明应包含在本软件的所有副本或实质性部分中。
本软件按"原样"提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途适用性和非侵权性的保证。在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,均不得超出与本软件或本软件的使用或其他交易有关的范围。