Unified-IO 2
本仓库包含 Unified-IO 2 的代码,包括运行演示、训练和推理的代码。此代码库是基于 T5X 修改而来。
新闻:
-
[2024年2月15日] 我们发布了 Unified-IO 2 的 Pytorch 代码。详情可在此处查看。
-
[2024年1月5日] 我们发布了用于训练音频分词器的 VIT-VQGAN 的 JAX 源代码。详情可在此处查看。
安装
使用 pip 安装依赖项
- 注意: 由于这个项目开发时间较长,我们使用的一些包是旧版本。我们最近发现,在使用 Python 3.9 时,导入
orbax.checkpoint
可能会导致 JAX 中dtype="bfloat16"
的冲突,但在 Python 3.8 (例如 3.8.10,这是 TPU VM 的默认版本)中仍然可以正常工作。这个问题可能是由于 orbax.checkpoint 和 pip 的内部变化导致的。
对于 TPU:
python3 -m pip install -e '.[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -f https://storage.googleapis.com/jax-releases/jax_releases.html
对于 GPU/CPU(注意我们一直使用 TPU,所以 GPU 设置未经充分测试):
python3 -m pip install -e '.' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
运行演示需要额外的依赖项,使用以下命令安装:
python3 -m pip install -e '.[demo]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -f https://storage.googleapis.com/jax-releases/jax_releases.html
还需要安装 LLaMa 分词器,从 https://github.com/facebookresearch/llama/tree/main?tab=readme-ov-file 下载 .model
文件,然后更新 t5x/examples/unified_io/config.py
,使 LLAMA_TOKENIZER_PATH
指向下载位置。
检查点
我们在 S3 上提供 T5X 格式的检查点:
- XXL: s3://ai2-prior-uio/public/uio2-checkpoints/xxl-3m
- XL: s3://ai2-prior-uio/public/uio2-checkpoints/xl-3m
- Large: s3://ai2-prior-uio/public/uio2-checkpoints/large-3m
要下载,请递归复制目录。例如:
aws s3 --no-sign-request cp --recursive s3://ai2-prior-uio/public/uio2-checkpoints/large-3m large-3m --exclude "state*"
它们应该被复制到本地磁盘或 Google 文件存储。这里,--exclude "state*"
标志排除了优化器状态的下载,如果你想从当前优化器状态继续训练检查点,可以移除此标志。
演示
要交互式运行模型,可以运行演示笔记本。 确保已安装演示依赖项。
然后运行演示笔记本:
jupyter notebook demo.ipynb
在第二个单元格中设置 FULL_CKPT_PATH
和 MODEL_TYPE
为你的检查点和正确的模型大小。然后可以使用笔记本启动演示。
演示展示了如何加载模型、参数和进行推理。
首次使用时,演示会很慢,因为需要编译推理函数,之后使用类似输入/输出的调用会快得多。
数据
要在整个数据集上进行训练和评估,需要在 seqio.TaskRegistry
中注册数据集。参见 t5x/examples/unifiedio/data/tasks.py
中的示例。有关 seqio 如何管理数据集的更多详细信息,请参阅 seqio。
某些数据集在使用前需要运行预处理脚本。
确保更新 config.MULTITASK_TFDS_DATA_DIR
以指向存储数据集的位置。
数据集
我们在 t5x/examples/unifiedio/data/tasks.py
中提供了一些初始数据集。
我们的数据集通常以以下三种方式之一构建:
- 构建为
tensorflow_dataset
然后上传到config.MULTITASK_TFDS_DATA_DIR
指定的位置 - 构建为一组 tfrecords 并上传到同一位置
- 直接使用 https://www.tensorflow.org/datasets/catalog/overview 中的数据集
以第一种或第二种方式构建的数据集在使用前需要运行构建脚本。create_data
包含所需的构建脚本。例如,运行:
python3 create_data/tfdatasets/coco_all/build.py ~/data/tfds ~/data/vqa ~/data/coco_annotations
将上传 COCO 数据的 tfdataset,这允许使用诸如 image_generation_coco_2017
和 image_caption_coco_2017
等任务。一些数据集,如使用公共 tensorflow 目录的 refexp 数据集,可能还有自己的手动预处理步骤,这些步骤将在其网页上指定。
UnifiedIO 2 包含大量任务,对于这个初始版本,我们只包含了一部分,但我们将在测试和验证更多任务后添加更多。
预处理
UIO2 中的预处理分三个阶段进行:
-
任务特定的预处理构建提示并在支持的模态中构建输入和输出。这个阶段需要调整图像大小并填充到正确的尺寸,并提供遮罩来显示图像的哪些部分是填充(通常使用
unified_io.data.data_utils.resize_and_pad
)。 音频片段需要转换为梅尔频谱图,如果处理噪声数据,也可以进行遮罩。这个阶段由unified_io.data.preprocessing
中的各种预处理函数实现。 演示展示了如何处理原始输入。 为了允许这个阶段在训练和测试期间进行不同的预处理, 我们在sequence_length字典中传递一个is_training
字段,以指示 数据集是用于训练还是测试。 -
接下来运行
modality_processing.unified_io_preprocessor
。这个函数执行各种任务通用的预处理步骤, 例如对文本进行分词,并为缺失的模态添加空值,以便输出数据集具有一致的字段集。 -
最后应用
UnifiedIOFeatureConverter
,这可以在 多个数据集被合并成seqio.Mixture
之后进行。 这个函数将确保输出数据集具有一致的结构,并填充为 固定大小的张量,这是jax所需要的。这个数据集现在可以被批处理并直接传递到 UnifiedIO 2模型的损失或预测函数中。 填充由sequence_len字典决定。
要添加数据集,请使用seqio注册它,并确保最后的预处理器
是modality_processing.unified_io_preprocessor
。前面的
函数应确保数据集具有该函数所需的适当字段。
提示
我们在t5x/examples/unified_io/data/prompt_dict
中有完整的提示集,
在训练过程中我们随机选择这些提示。
可视化
我们包含了一个可视化脚本,用于显示后处理后的数据:
python3 t5x/examples/unified_io/scripts/dataset_visualize.py refcoco_unc viz --override
要获得更紧凑的视图:
python3 t5x/examples/unified_io/scripts/dataset_visualize.py refcoco_unc viz --override --gin.get_target_modalities.target_modality=[\"text\"] --gin.get_input_modalities.input_modality=[\"text\",\"image\"] --nomasks
训练
一旦下载了检查点并准备好数据集,就可以使用train.py进行训练。
我们的训练策略主要遵循T5X,通过gin进行配置。
按照https://github.com/google-research/t5x
的设置在TPU上训练。
例如,要在refexp上微调大型模型:
python3 t5x/train.py --gin_file=t5x/examples/unified_io/t5_1_1/large.gin --gin_file=t5x/examples/unified_io/t5_1_1/finetune/refexp.gin --gin.INITIAL_CHECKPOINT_PATH=\"/path/to/checkpoint\" --gin.MODEL_DIR=\"path/to/output_dir\" --gin.BATCH_SIZE=8
模态
UnifiedIO 2可以在支持的模态的子集上运行,这使训练更
高效。这可以通过get_input_modalities
和get_target_modalities
中的gin配置参数来设置。例如,refexp.gin
只开启了图像/文本输入和文本输出。
序列长度
由于jax的固定大小张量约束,我们默认将所有输入和目标填充到
模型支持的最大长度。当在这种做法过度的混合数据上训练时,
可以通过更改seqio
使用的sequence_lengths来调整
例如,refexp.gin减少了输入和输出序列长度,因为
refexp的文本较少。
Wandb
我们修改了train.py以使用wandb,只需确保设置了WANDB_API_KEY
环境变量。
应修改或通过gin配置可配置函数utils.init_wandb
以选择正确的名称/组/项目/实体。
打包
如果训练混合包含长短不一的样本,打包 可以提高效率。打包将最多将两个样本打包在一起 成为单个输入序列,可以通过以下标志开启:
--gin.PackingStrategy.pack_max_len=(864, 1280)
在训练期间,将尝试将两个样本打包在总 输入长度为864和目标长度为1280的序列中。一个启发式算法 将尝试在数据流向训练服务器时找到符合此标准的样本对,如果找不到,则只使用一个样本。 如果这种情况发生得太频繁,最好增加最大长度。 统计数据将记录到wandb以跟踪打包效率。
评估
评估脚本使用eval.py运行,例如:
python3 t5x/eval.py --gin_file=t5x/examples/unified_io/t5_1_1/large.gin --gin_file=t5x/examples/unified_io/t5_1_1/eval/vision_language.gin --gin.CHECKPOINT_PATH=\"large-3m\" --gin.MIXTURE_OR_TASK_NAME=\"refcoco_unc\" --gin.EVAL_OUTPUT_DIR=\"output\"
目标数据集必须在seqio中注册指标。评估脚本 同样可以通过只使用所需的模态和 适当选择序列长度来提高效率。请注意,我们的大多数官方结果 来自收集输出然后运行离线评估,这里的指标 主要用于验证分数。
引用
@article{lu2023uio2,
title = {Unified-IO 2: Scaling Autoregressive Multimodal Models with Vision, Language, Audio, and Action},
author = {Jiasen Lu and Christopher Clark and Sangho Lee and Zichen Zhang and Savya Khosla and Ryan Marten and Derek Hoiem and Aniruddha Kembhavi},
journal = {arXiv preprint arXiv:2312.17172},
year = {2023},
}