基于Transformer的可扩展扩散模型 (DiT)
改进的PyTorch实现
论文 | 项目页面 | 运行 DiT-XL/2
本仓库提供了论文基于Transformer的可扩展扩散模型的改进PyTorch实现。
它包含:
- 🪐 DiT的改进PyTorch实现和原始实现
- ⚡️ 在ImageNet上训练的预训练类条件DiT模型(512x512和256x256)
- 💥 用于运行预训练DiT-XL/2模型的独立Hugging Face Space和Colab笔记本
- 🛸 改进的DiT训练脚本和多个训练选项
设置
首先,下载并设置仓库:
git clone https://github.com/chuanyangjin/fast-DiT.git
cd DiT
我们提供了一个environment.yml
文件,可用于创建Conda环境。如果你只想在CPU上本地运行预训练模型,可以从文件中删除cudatoolkit
和pytorch-cuda
要求。
conda env create -f environment.yml
conda activate DiT
采样
预训练DiT检查点。 你可以使用sample.py
从我们的预训练DiT模型中采样。根据你使用的模型,预训练DiT模型的权重将自动下载。该脚本有多个参数,可以在256x256和512x512模型之间切换,调整采样步骤,更改无分类器引导尺度等。例如,要从我们的512x512 DiT-XL/2模型中采样,你可以使用:
python sample.py --image-size 512 --seed 1
为方便起见,我们的预训练DiT模型也可以直接在这里下载:
自定义 DiT 检查点。 如果你使用 train.py
训练了新的 DiT 模型(参见下文),你可以添加 --ckpt
参数来使用你自己的检查点。例如,要从自定义的 256x256 DiT-L/4 模型的 EMA 权重进行采样,运行:
python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt
训练
训练前准备
要在一个节点上使用 1
个 GPU 提取 ImageNet 特征:
torchrun --nnodes=1 --nproc_per_node=1 extract_features.py --model DiT-XL/2 --data-path /path/to/imagenet/train --features-path /path/to/store/features
训练 DiT
我们在 train.py
中提供了 DiT 的训练脚本。这个脚本可以用来训练类别条件的 DiT 模型,但也可以轻松修改以支持其他类型的条件。
要在一个节点上使用 1
个 GPU 启动 DiT-XL/2(256x256)训练:
accelerate launch --mixed_precision fp16 train.py --model DiT-XL/2 --features-path /path/to/store/features
要在一个节点上使用 N
个 GPU 启动 DiT-XL/2(256x256)训练:
accelerate launch --multi_gpu --num_processes N --mixed_precision fp16 train.py --model DiT-XL/2 --features-path /path/to/store/features
另外,你也可以选择提取和训练位于 training options 文件夹中的脚本。
PyTorch 训练结果
我们使用 PyTorch 训练脚本从头开始训练了 DiT-XL/2 和 DiT-B/4 模型,以验证它能够在数十万次训练迭代中重现原始 JAX 的结果。在我们的实验中,PyTorch 训练的模型与 JAX 训练的模型相比,在合理的随机变化范围内,给出了相似(有时甚至略好)的结果。以下是一些数据点:
DiT 模型 | 训练步数 | FID-50K (JAX 训练) | FID-50K (PyTorch 训练) | PyTorch 全局训练种子 |
---|---|---|---|---|
XL/2 | 400K | 19.5 | 18.1 | 42 |
B/4 | 400K | 68.4 | 68.9 | 42 |
B/4 | 400K | 68.4 | 68.3 | 100 |
这些模型在 256x256 分辨率下训练;我们使用 8 个 A100 GPU 训练 XL/2,4 个 A100 GPU 训练 B/4。注意,此处的 FID 是使用 250 步 DDPM 采样计算得出的,使用 mse
VAE 解码器,且没有引导(cfg-scale=1
)。
提升训练性能
与原始实现相比,我们实施了一系列训练速度加速和内存节省功能,包括梯度检查点、混合精度训练和预提取VAE特征,在DiT-XL/2上实现了95%的速度提升和60%的内存减少。以下是使用A100、全局批量大小为128的一些数据点:
梯度检查点 | 混合精度训练 | 特征预提取 | 训练速度 | 内存使用 |
---|---|---|---|---|
❌ | ❌ | ❌ | - | 内存不足 |
✔ | ❌ | ❌ | 0.43步/秒 | 44045 MB |
✔ | ✔ | ❌ | 0.56步/秒 | 40461 MB |
✔ | ✔ | ✔ | 0.84步/秒 | 27485 MB |
评估(FID、Inception Score等)
我们提供了一个sample_ddp.py
脚本,可以并行地从DiT模型中采样大量图像。这个脚本生成一个包含样本的文件夹以及一个.npz
文件,可以直接用于ADM的TensorFlow评估套件来计算FID、Inception Score和其他指标。例如,要使用N个GPU从我们预训练的DiT-XL/2模型中采样50K张图像,运行:
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model DiT-XL/2 --num-fid-samples 50000
还有其他几个选项;详情请参见sample_ddp.py
。