可扩展整流流变换器的最小实现
左图为简单RF,右图为对数正态时间采样RF。两者均在MNIST上训练。
本仓库包含整流流模型的最小实现。我采用了SD3的训练方法和LLaMA-DiT架构。与我之前的仓库不同,这次我决定将文件分为两部分:模型实现和实际代码,但你不必查看模型代码。
所有内容仍然是自包含、最小化的,希望易于修改。如果你理解了数学原理,就不会有任何复杂的内容。
1. 适合初学者的简单整流流
安装torch、pil、torchvision
pip install torch torchvision pillow
运行
python rf.py
以从头开始在MNIST上训练模型。
如果你想挑战一下,也可以在CIFAR上训练。
python rf.py --cifar
在第63个epoch,你的输出应该类似于:
2. 支持muP的大规模整流流
这是为想要在ImageNet上训练的高手准备的。别担心!在我看来,ImageNet是新的MNIST,我们将使用我的imagenet.int8数据集。
首先进入advanced目录,下载数据集。
cd advanced
pip install hf_transfer # 只需安装这个
bash download.sh
如果你的网络不错,这应该不会超过5分钟。
运行
bash run.sh
来训练模型。这将从头开始在ImageNet上训练模型,进行muP网格搜索以找到损失函数的对齐盆地,你将解锁整流流模型的零样本LR迁移!
这里使用了我过去一年开发的多种技术和代码库。它是min-max-IN-dit、min-max-gpt和ez-muP的自然结合。
引用
如果你使用了这些材料,请使用以下方式引用本仓库:
@misc{ryu2024minrf,
author = {Simo Ryu},
title = {minRF: Minimal Implementation of Scalable Rectified Flow Transformers},
year = 2024,
publisher = {Github},
url = {https://github.com/cloneofsimo/minRF},
}