[ICLR'24] 一致性轨迹模型 (CTM)
本仓库包含了论文"一致性轨迹模型:学习扩散的概率流ODE轨迹"在ImageNet 64x64数据集上的官方PyTorch实现,该论文将在ICLR 2024会议上发表。
联系方式:
- Dongjun KIM: dongjun@stanford.edu
- Chieh-Hsin (Jesse) LAI: chieh-hsin.lai@sony.com
简介
对于单步扩散模型采样,我们的新模型一致性轨迹模型(CTM)在CIFAR-10(FID 1.73)和ImageNet 64x64(FID 1.92)上达到了最先进水平。CTM提供多样化的采样选项,并能有效平衡计算预算和样本保真度。
检查点
前提条件
-
下载(或获取)以下文件
- 预训练扩散模型:请将其放在
args.teacher_model_path
- 数据:请将其放在
args.data_dir
(注意我们使用的数据不是下采样的图像数据。它是ILSVRC2012数据。这两个数据集之间存在巨大的性能差异。) - 参考统计数据:用于计算FID、sFID、IS、精确度、召回率的统计数据。请将它们放在
args.ref_path
- 预训练扩散模型:请将其放在
-
在您的服务器上安装docker
2-1. 输入
docker pull dongjun57/ctm-docker:latest
从docker hub下载docker镜像。2-2. 通过输入以下命令创建容器:
docker run --gpus=all -itd -v /etc/localtime:/etc/localtime:ro -v /dev/shm:/dev/shm -v [指定目录]:[指定目录] -v /hdd/imagenet/imagenet_dir/train:/hdd/imagenet/imagenet_dir/train -v [指定数据目录]:[指定数据目录] --name ctm-docker 8caa2682d007
命令可能会因您的服务器环境而有所不同。2-3. 通过
docker exec -it ctm-docker bash
进入容器。2-4. 通过
conda activate ctm
进入虚拟环境。 -
确保依赖项与以下内容一致。
apt install git apt install libopenmpi-dev python -m pip install tensorflow[and-cuda] python -m pip install torch torchvision torchaudio python -m pip install blobfile tqdm numpy scipy pandas Cython piq==0.7.0 python -m pip install joblib==0.14.0 albumentations==0.4.3 lmdb clip@git+https://github.com/openai/CLIP.git pillow python -m pip install flash-attn --no-build-isolation python -m pip install xformers python -m pip install mpi4py python -m pip install nvidia-ml-py3 timm==0.4.12 legacy dill nvidia-ml-py3
训练
-
对于CTM+DSM训练,运行
bash commands/CTM+DSM_command.sh
建议:至少运行CTM+DSM 10~50k次迭代
-
对于CTM+DSM+GAN训练,运行
bash commands/CTM+DSM+GAN_command.sh
建议:至少运行CTM+DSM+GAN >=30k次迭代
采样
请查看commands/sampling_commands.sh
获取详细的采样命令。
评估
运行python3.8 evaluations/evaluator.py [统计数据位置] [样本位置]
第一个参数是参考路径,第二个参数是您的样本文件夹(>=50k个样本以进行正确评估)。
请参考ADM (Prafulla Dhariwal, Alex Nichol)的统计数据。
自定义数据集
用户需要手动将data_name替换为您的数据名称:在cm_train.py
或image_sample.py
中手动修改data_name
引用
@article{kim2023consistency,
title={Consistency Trajectory Models: Learning Probability Flow ODE Trajectory of Diffusion},
author={Kim, Dongjun and Lai, Chieh-Hsin and Liao, Wei-Hsiang and Murata, Naoki and Takida, Yuhta and Uesaka, Toshimitsu and He, Yutong and Mitsufuji, Yuki and Ermon, Stefano},
journal={arXiv preprint arXiv:2310.02279},
year={2023}