OmniTokenizer:用于视觉生成的联合图像-视频分词器
以下论文的官方PyTorch实现:
OmniTokenizer:用于视觉生成的联合图像-视频分词器。
王俊科1,2,蒋毅3,袁泽欢3,彭彬月3,吴祖煊1,2,姜育刚1,2
1复旦大学计算机科学学院,上海市智能信息处理重点实验室
2上海智能视觉计算协同创新中心,3字节跳动公司
我们提出了OmniTokenizer,一个联合图像-视频分词器,具有以下特点:
- 🚀 一个模型和一个权重用于联合图像和视频分词;
- 🥇 在图像和视频数据集上都达到最先进的重建性能;
- ⚡ 对高分辨率和长视频输入具有高适应性;
- 🔥 配备它后,语言模型和扩散模型都能够实现具有竞争力的视觉生成结果。
请访问我们的项目页面查看OmniTokenizer的重建和生成结果。
环境配置
请使用以下命令设置环境:
pip3 install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu118
pip3 install -r requirements.txt
然后从官方网站下载数据集。你可以下载我们处理好的annotation.zip,并将其放在./annotations
目录下。
VQVAE和VAE的模型库
我们发布了OmniTokenizer的VQVAE和VAE版本,它们在多种图像和视频数据集上进行了预训练:
类型 | 训练数据 | FID | FVD | 检查点 |
---|---|---|---|---|
VQVAE | ImageNet | 1.28[^1] | - | imagenet_only.ckpt |
VQVAE | CelebAHQ | 1.85 | - | celebahq.ckpt |
VQVAE | FFHQ | 2.58 | - | ffhq.ckpt |
VQVAE | ImageNet + UCF | 1.11 | 42.35 | imagenet_ucf.ckpt |
VQVAE | ImageNet + K600 | 1.23 | 25.97 | imagenet_k600.ckpt |
VQVAE | ImageNet + MiT | 1.26 | 19.87 | imagenet_mit.ckpt |
VQVAE | ImageNet + Sthv2 | 1.21 | 20.30 | imagenet_sthv2.ckpt |
VQVAE | CelebAHQ + UCF | 1.93 | 45.59 | celebahq_ucf.ckpt |
VQVAE | CelebAHQ + K600 | 1.82 | 89.13 | celebahq_k600.ckpt |
VQVAE | FFHQ + UCF | 1.91 | 57.93 | ffhq_ucf.ckpt |
VQVAE | FFHQ + K600 | 2.69 | 87.58 | ffhq_k600.ckpt |
VAE | ImageNet + UCF | 0.69 | 23.44 | imagenet_ucf_vae.ckpt |
VAE | ImageNet + K600 | 0.78 | 13.02 | imagenet_k600_vae.ckpt |
[^1] 我们在训练这个模型时没有使用 scaled_dot_product_attention,请注释掉 OmniTokenizer/modules/attention.py
中的第446-460行以重现这个结果。
我们推荐您尝试 imagenet_k600.ckpt,因为它是在大规模图像和视频数据上训练的。
您可以轻松地将 OmniTokenizer 整合到您的语言模型或扩散模型中,如下所示:
from OmniTokenizer import OmniTokenizer_VQGAN
vqgan = OmniTokenizer_VQGAN.load_from_checkpoint(vqgan_ckpt, strict=False)
# tokens = vqgan.encode(img)
# recons = vqgan.decode(tokens)
分词器(VQVAE 和 VAE)
VQVAE 的训练包括两个阶段:在固定分辨率上进行仅图像训练,以及在多个分辨率上进行图像-视频联合训练。之后,使用 KL 损失微调 VQVAE 模型以获得 VAE 模型。
请参考 scripts/recons/train.sh
以了解 omnitokenizer 的训练过程。根据不同设置需要更改的标志说明:
- patch_size 和 temporal_patch_size:补丁嵌入层中补丁的形状,同时决定下采样比率
- enc_block:编码器块的类型,'t'表示普通注意力,'w'表示窗口注意力
- n_codes:码本大小
- spatial_pos:空间位置编码的类型
- use_vae:以VAE模式或VQVAE模式训练
- resolution 和 sequence_length:训练的空间和时间分辨率
- resolution_scale:用于多分辨率训练,指定分辨率的比例
关于OmniTokenizer的评估,请参考scripts/recons/eval_image_inet.sh
、scripts/recons/eval_image_face.sh
和scripts/recons/eval_video.sh
。
基于语言模型的视觉合成
有关语言模型的训练和评估,请参考scripts/lm_train
和scripts/lm_gen
。我们提供了ImageNet[imagenet_class_lm.ckpt]、UCF [ucf_class_lm.ckpt]和Kinetics-600 [k600_fp_lm.ckpt]的检查点。
基于扩散的视觉合成
我们采用DiT和Latte进行基于扩散的视觉生成。有关训练和评估说明,请参考diffusion.md。
评估
有关如何评估重建或生成结果,请参考evaluation.md。
致谢
我们的代码部分基于VQGAN和TATS构建。我们也感谢pytorch-fid和common_metrics_on_video_quality提供的优秀工具。
许可证
本项目采用MIT许可证,详见LICENSE文件。