open-muse
一个开放再现项目,用于复现基于Transformer的MUSE模型,以快速生成从文本到图像。
演示
👉 https://huggingface.co/spaces/openMUSE/MUSE
目标
本仓库用于再现MUSE模型。目标是创建一个简单且可扩展的仓库,以再现MUSE并在大规模上构建关于VQ + Transformer的知识。我们将使用去重的LAION-2B + COYO-700M数据集进行训练。
项目阶段:
- 设置代码库并在ImageNet上训练一个类条件模型。
- 在CC12M上进行文本到图像的实验。
- 训练改进的VQGAN模型。
- 在LAION + COYO上训练完整的(base-256)模型。
- 在LAION + COYO上训练完整的(base-512)模型。
该项目的所有成果都将上传到huggingface hub上的openMUSE组织。
使用方法
安装
首先创建一个虚拟环境并使用以下命令安装仓库:
git clone https://github.com/huggingface/muse
cd muse
pip install -e ".[extra]"
您需要手动安装PyTorch
和torchvision
。我们使用torch==1.13.1
和CUDA11.7
进行训练。
对于分布式数据并行训练,我们使用accelerate
库,但未来可能会有变化。对于数据集加载,我们使用webdataset
库。因此,数据集应为webdataset
格式。
模型
目前我们支持以下模型:
MaskGitTransformer
- 论文中主要的Transformer模型。MaskGitVQGAN
- 来源于maskgit仓库的VQGAN模型。VQGANModel
- 来源于taming transformers仓库的VQGAN模型。
这些模型在muse
目录下实现。所有模型都实现了熟悉的transformers
API。因此,您可以使用from_pretrained
和save_pretrained
方法来加载和保存模型。模型可以保存并从huggingface hub加载。
VQGAN 示例:
import torch
from torchvision import transforms
from PIL import Image
from muse import MaskGitVQGAN
# 从hub加载预训练的vq模型
vq_model = MaskGitVQGAN.from_pretrained("openMUSE/maskgit-vqgan-imagenet-f16-256")
# 编码和解码图像
encode_transform = transforms.Compose(
[
transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(256),
transforms.ToTensor(),
]
)
image = Image.open("...") #
pixel_values = encode_transform(image).unsqueeze(0)
image_tokens = vq_model.encode(pixel_values)
rec_image = vq_model.decode(image_tokens)
# 转换为PIL图像
rec_image = 2.0 * rec_image - 1.0
rec_image = torch.clamp(rec_image, -1.0, 1.0)
rec_image = (rec_image + 1.0) / 2.0
rec_image *= 255.0
rec_image = rec_image.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
pil_images = [Image.fromarray(image) for image in rec_image]
用于类条件生成的MaskGitTransformer示例:
import torch
from muse import MaskGitTransformer, MaskGitVQGAN
from muse.sampling import cosine_schedule
# 从hub加载预训练的vq模型
vq_model = MaskGitVQGAN.from_pretrained("openMUSE/maskgit-vqgan-imagenet-f16-256")
# 初始化MaskGitTransformer模型
maskgit_model = MaskGitTransformer(
vocab_size=2025, #(1024 + 1000 + 1 = 2025 -> Vq_tokens + ImageNet类ID + <mask>)
max_position_embeddings=257, # 256 + 1表示类标记
hidden_size=512,
num_hidden_layers=8,
num_attention_heads=8,
intermediate_size=2048,
codebook_size=1024,
num_vq_tokens=256,
num_classes=1000,
)
# 准备输入批次
images = torch.randn(4, 3, 256, 256)
class_ids = torch.randint(0, 1000, (4,)) # 随机类ID
# 编码图像
image_tokens = vq_model.encode(images)
batch_size, seq_len = image_tokens.shape
# 为每张图像随机采样一个时间步
timesteps = torch.rand(batch_size, device=image_tokens.device)
# 使用时间步和余弦调度为每张图像随机采样一个掩码概率
mask_prob = cosine_schedule(timesteps)
mask_prob = mask_prob.clip(min_masking_rate)
# 为每张图像创建一个随机掩码
num_token_masked = (seq_len * mask_prob).round().clamp(min=1)
batch_randperm = torch.rand(batch_size, seq_len, device=image_tokens.device).argsort(dim=-1)
mask = batch_randperm < num_token_masked.unsqueeze(-1)
# 掩码图像并创建输入和标签
input_ids = torch.where(mask, mask_id, image_tokens)
labels = torch.where(mask, image_tokens, -100)
# 将类ID按代码本大小偏移
class_ids = class_ids + vq_model.num_embeddings
# 将类ID预置于图像标记前
input_ids = torch.cat([class_ids.unsqueeze(-1), input_ids], dim=-1)
# 在标签前预置-100,因为我们不想预测类ID
labels = torch.cat([-100 * torch.ones_like(class_ids).unsqueeze(-1), labels], dim=-1)
# 前向传播
logits, loss = maskgit_model(input_ids, labels=labels)
loss.backward()
# 生成图像
class_ids = torch.randint(0, 1000, (4,)) # 随机类ID
generated_tokens = maskgit_model.generate(class_ids=class_ids)
rec_images = vq_model.decode(generated_tokens)
注意:
- vq模型和Transformer模型是分开的,以便独立扩展Transformer模型。而且我们可能会预先编码图像以加快训练速度。
- 掩码操作在模型之外进行,以便使用不同的掩码策略而不影响建模代码。
MaskGit生成过程的基本解释
-
MaskGit是一个Transformer,当给定一系列vq标记和类条件标签标记时输出logits。
-
去噪过程的方式是使用掩码标记ID并逐渐去噪。
-
在原始实现中,这是通过对最后一个维度使用softmax并随机抽样作为分类分布来完成的。这将给我们每个掩码ID的预测标记。然后我们得到这些标记被选择的概率。最后,当加入Gumbel*temp时,我们得到最高置信度的概率值。Gumbel分布类似于向0偏移的正态分布,用于模拟极端事件。因此在极端情况下,我们会喜欢选择与默认不同的标记。
-
在Lucidrian实现中,它首先通过给定的掩码率掩码最低得分(最高概率)的标记。然后除了我们得到的最高10%的logits外,我们设为-无穷,因此当我们对其执行Gumbel分布时,这些将被忽略。然后更新输入ID和分数,其中分数只是由logits在预测ID下的softmax给予的概率的1。
训练
对于类条件ImageNet,我们使用accelerate
进行DDP训练,使用webdataset
进行数据加载。训练脚本位于training/train_maskgit_imagenet.py
。
配置管理使用OmegaConf。请参阅configs/template_config.yaml
获取配置模板。下面我们解释配置参数。
wandb:
entity: ???
experiment:
name: ???
project: ???
output_dir: ???
max_train_examples: ???
save_every: 1000
eval_every: 500
generate_every: 1000
log_every: 50
log_grad_norm_every: 100
resume_from_checkpoint: latest
model:
vq_model:
pretrained: "openMUSE/maskgit-vqgan-imagenet-f16-256"
transformer:
vocab_size: 2048 # (1024 + 1000 + 1 = 2025 -> Vq + Imagenet + <mask>, 使用2048便于被8整除)
max_position_embeddings: 264 # (256 + 1表示类ID,使用264便于被8整除)
hidden_size: 768
num_hidden_layers: 12
num_attention_heads: 12
intermediate_size: 3072
codebook_size: 1024
num_vq_tokens: 256
num_classes: 1000
initializer_range: 0.02
layer_norm_eps: 1e-6
use_bias: False
use_normformer: True
use_encoder_layernorm: True
hidden_dropout: 0.0
attention_dropout: 0.0
gradient_checkpointing: True
enable_xformers_memory_efficient_attention: False
dataset:
params:
train_shards_path_or_url: ???
eval_shards_path_or_url: ???
batch_size: ${training.batch_size}
shuffle_buffer_size: ???
num_workers: ???
resolution: 256
pin_memory: True
persistent_workers: True
preprocessing:
resolution: 256
center_crop: True
random_flip: False
optimizer:
name: adamw # 可以是adamw或lion或fused_adamw。安装apex用于fused_adamw
params: # 默认adamw参数
learning_rate: ???
scale_lr: False # 按总批次大小缩放学习率
beta1: 0.9
beta2: 0.999
weight_decay: 0.01
epsilon: 1e-8
lr_scheduler:
scheduler: "constant_with_warmup"
params:
learning_rate: ${optimizer.params.learning_rate}
warmup_steps: 500
training:
gradient_accumulation_steps: 1
batch_size: 128
mixed_precision: "no"
enable_tf32: True
use_ema: False
seed: 42
max_train_steps: ???
overfit_one_batch: False
min_masking_rate: 0.0
label_smoothing: 0.0
max_grad_norm: null
带有 ??? 的参数是必填的。
wandb:
wandb.entity
:用于记录的wandb实体。
experiment:
experiment.name
:实验名称。experiment.project
:用于记录的wandb项目。experiment.output_dir
:保存检查点的目录。experiment.max_train_examples
:使用的最大训练示例数。experiment.save_every
:每save_every
步保存一次检查点。experiment.eval_every
:每eval_every
步评估一次模型。experiment.generate_every
:每generate_every
步生成一次图像。experiment.log_every
:每log_every
步记录一次训练指标。log_grad_norm_every
:每log_grad_norm_every
步记录一次梯度范数。experiment.resume_from_checkpoint
:从哪个检查点继续训练。可以是“latest”表示从最新的检查点继续或保存检查点的路径。如果为None
或路径不存在,则从头开始训练。
model:
model.vq_model.pretrained
:要使用的预训练vq模型。可以为保存的检查点路径或huggingface模型名称。model.transformer
:Transformer模型配置。model.gradient_checkpointing
:启用Transformer模型的梯度检查点。enable_xformers_memory_efficient_attention
:启用Transformer模型的内存高效注意力或Flash注意力。对于Flash注意力,我们需要使用fp16
或bf16
。xformers 需要安装才能工作。
dataset:
dataset.params.train_shards_path_or_url
:webdataset
训练分片的路径或URL。dataset.params.eval_shards_path_or_url
:webdataset
评估分片的路径或URL。dataset.params.batch_size
:用于训练的批次大小。dataset.params.shuffle_buffer_size
:用于训练的混洗缓冲区大小。dataset.params.num_workers
:用于数据加载的工作线程数量。dataset.params.resolution
:用于训练的图像分辨率。dataset.params.pin_memory
:针对此数据加载的内存。dataset.params.persistent_workers
:使用持久化的工作线程进行数据加载。dataset.preprocessing.resolution
:用于预处理的图像分辨率。dataset.preprocessing.center_crop
:是否中心裁剪图像。如果为False,则图像被随机裁剪到resolution
。dataset.preprocessing.random_flip
:是否随机翻转图像。如果为False,则图像不翻转。 优化器:optimizer.name
:用于训练的优化器。optimizer.params
:优化器参数。
学习率调度器:
lr_scheduler.scheduler
:用于训练的学习率调度器。lr_scheduler.params
:学习率调度器参数。
训练:
training.gradient_accumulation_steps
:用于训练的梯度累积步骤数。training.batch_size
:用于训练的批量大小。training.mixed_precision
:用于训练的混合精度模式。可以是no
,fp16
或bf16
。training.enable_tf32
:在 Ampere GPU 上启用 TF32 训练。training.use_ema
:启用 EMA 训练。目前不支持。training.seed
:用于训练的种子。training.max_train_steps
:最大训练步骤数。training.overfit_one_batch
:是否为调试过拟合一个批次。training.min_masking_rate
:用于训练的最小掩码率。training.label_smoothing
:用于训练的标签平滑值。max_grad_norm
:最大梯度范数。
关于训练和数据集的说明:
每当我们恢复/开始训练运行时,我们会随机重新采样数据片(带有替换)并缓冲区中的示例进行训练。这意味着我们的数据加载不是确定性的。我们也不进行基于 epoch 的训练,仅用于记录和能够与其他数据集/加载器重复使用相同的训练循环。
运行实验:
目前,我们正在单节点上运行实验。要在单节点上启动训练运行,请执行以下步骤:
- 准备
webdataset
格式的数据集。您可以使用scripts/convert_imagenet_to_wds.py
脚本将 imagenet 数据集转换为webdataset
格式。 - 首先使用
accelerate config
配置您的训练环境。 - 为您的实验创建一个
config.yaml
文件。 - 使用
accelerate launch
启动训练运行。
accelerate launch python -u training/train_maskgit_imagenet.py config=path/to/yaml/config
使用 OmegaConf 时,命令行覆盖以点符号格式完成。例如,如果您想覆盖数据集路径,您可以使用以下命令python -u train.py config=path/to/config dataset.params.path=path/to/dataset
。
同样的命令可以用于在本地启动训练。
步骤
设置代码库并在 imagenet 上训练类条件模型。
- 设置仓库结构
- 添加 transformers 和 VQGAN 模型。
- 为模型添加生成支持。
- 从 maskgit 仓库移植 VQGAN 进行 imagenet 实验。
- 完成并验证掩码工具。
- 从 MUSE 中添加掩码弧度调度函数。
- 添加 EMA。
- 支持使用 OmegaConf 配置训练。
- 添加 W&B 日志工具。
- 添加 WebDataset 支持。对于 imagenet 实验并不真正需要,但可以并行工作。(LAION 已经是这种格式,所以更容易使用它)。
- 添加使用 imagenet 进行类条件生成的训练脚本。
- 为集群训练准备代码库。添加 SLURM 脚本。
在 CC12M 上进行文本生成图像实验。
- 完成数据加载和预处理工具。
- 添加 CLIP 和 T5 支持。
- 添加文本生成图像的训练脚本。
- 添加评估脚本(FiD、CLIP 分数)。
- 在 CC12M 上进行训练。我们可以进行不同的实验:
- 使用 T5 条件在 CC12M 上训练。
- 使用 CLIP 条件在 CC12M 上训练。
- 使用 CLIP + T5 条件在 CC12M 上训练(可能在训练和实验期间成本较高)。
- 自 Bit Diffusion 论文中的自我条件。
- 收集不同的中间评估提示(可以重复使用 dalle-mini,parti-prompts 的提示)。
- 设置一个空间,让人们可以玩模型并提供反馈,与其他模型进行比较等。
训练改进的 VQGAN 模型。
- 为 VQGAN 添加训练组件模型(EMA,鉴别器,LPIPS 等)。
- VQGAN 训练脚本。
杂项任务
- 创建一个空间以可视化探索数据集
- 创建一个空间,供人们尝试找到自己的图像并可以选择退出数据集。
仓库结构(在进行中)
├── README.md
├── configs -> 所有训练配置文件。
│ └── template_config.yaml
├── muse
│ ├── __init__.py
│ ├── data.py -> 所有数据相关工具。需要时可以创建数据文件夹。
│ ├── logging.py -> 各种日志工具。
| ├── lr_schedulers.py -> 所有学习率调度器相关工具。
│ ├── modeling_maskgit_vqgan.py -> 来自 maskgit 仓库的 VQGAN 模型。
│ ├── modeling_taming_vqgan.py -> 来自 taming 仓库的 VQGAN 模型。
│ └── modeling_transformer.py -> 主要的 Transformer 模型。
│ ├── modeling_utils.py -> 所有模型相关工具,例如 save_pretrained, from_pretrained from hub 等。
│ ├── sampling.py -> 采样/生成工具。
│ ├── training_utils.py -> 常见训练工具。
├── pyproject.toml
├── setup.cfg
├── setup.py
├── test.py
└── training -> 所有训练脚本。
├── __init__.py
├── data.py -> 所有数据相关工具。需要时可以创建数据文件夹。
├── optimizer.py -> 所有优化器相关工具以及任何 PT 中没有的新优化器。
├── train_maskgit_imagenet.py
├── train_muse.py
└── train_vqgan.py
鸣谢
该项目主要基于以下开源仓库。感谢所有作者的精彩工作。
- muse-maskgit-pytorch . 特别感谢 @lucidrains 的精彩工作 ❤️
- maskgit
- taming-transformers
- open-clip
- open-diffusion
- dalle-mini: ❤️
- transformers
- accelerate
- diffusers
- webdataset
当然,也感谢 PyTorch 团队提供了这个惊人的框架 ❤️