项目简介
open-muse是一个致力于重现基于Transformer的MUSE模型的开源项目,旨在快速生成文本到图像的转换。这个项目的目标是建立一个简单且可扩展的代码库,以重现MUSE模型,并在大规模的VQ加Transformer框架下积累相关知识。科研人员将在LAION-2B和COYO-700M数据集上进行训练。
项目目标和阶段
open-muse项目的目标是重现MUSE模型,并在过程中构建对视觉量化(VQ)和变压器(Transformer)在大规模应用下的理解。项目按照以下几个阶段进行:
- 建立代码基础,并在ImageNet上训练一个类条件模型。
- 在CC12M数据集上进行文本到图像的实验。
- 训练改进的VQGAN模型。
- 在LAION和COYO数据集上训练完整的(base-256)模型。
- 在LAION和COYO数据集上训练完整的(base-512)模型。
所有项目的产出将上传至huggingface平台的openMUSE组织。
如何使用
安装
首先创建一个虚拟环境,然后使用以下命令安装项目:
git clone https://github.com/huggingface/muse
cd muse
pip install -e ".[extra]"
需要手动安装PyTorch
和torchvision
,在训练中使用torch
版本为1.13.1,并搭配CUDA11.7
。项目使用accelerate
库进行分布式数据并行训练,数据集加载使用webdataset
库,因此数据集应为webdataset
格式。
模型支持
项目目前支持以下模型:
MaskGitTransformer
:项目论文中的主要Transformer模型。MaskGitVQGAN
:来自maskgit代码库的VQGAN模型。VQGANModel
:来自taming transformers代码库的VQGAN模型。
所有模型实现了熟悉的transformers
API。用户可以使用from_pretrained
和save_pretrained
方法来加载和保存模型。
基本工作原理
MaskGit是一种基于Transformer的模型,输出给定序列的logits,包括VQ和类条件标签。其去噪过程通过掩蔽标记ID,并逐步去噪完成。项目原始实现使用软最大化(softmax)来采样分类分布,这将提供每个maskid的预测标记。然后获取这些标记被选中的概率,并加入偏移的Gumbel分布以处理极端事件。
训练过程
对于类条件的ImageNet训练,使用accelerate
进行分布式数据并行(DDP)训练,数据加载使用webdataset
。项目使用OmegaConf进行配置管理,具体配置位于configs/template_config.yaml
文件中。
运行实验
当前实验在单节点上进行。执行训练步骤:
- 准备
webdataset
格式的数据集,可使用scripts/convert_imagenet_to_wds.py
脚本转换ImageNet数据集。 - 使用
accelerate config
配置训练环境。 - 为实验创建
config.yaml
文件。 - 使用
accelerate launch
启动训练。
accelerate launch python -u training/train_maskgit_imagenet.py config=path/to/yaml/config
通过此项目,研究人员期待在图像生成领域以简单易用的方式分享并进一步推动这一新兴技术的发展。