multi_token
将任意模态(图像、音频、文档等)嵌入到大型语言模型中。
这个库旨在作为LLaVA的扩展,用于将✨任何东西✨(图像、声音、文档、视频、动作捕捉、屏幕截图、语音录音等)编码成可用于大型语言模型的格式。它的主要贡献在于能够将多个实例和模态嵌入到单个模型中,并提供了一个相对简单的框架来实现这一目标。
利用这个库,你可能可以向大型多模态模型(LMMs)提出以下问题:
-
阅读<文档>并给我一个摘要。
-
听<音频>并回答口头提出的问题。
-
比较和对比<图像>和<图像>
-
根据<屏幕截图>和<游戏状态>,我应该按哪个键?
想了解这是如何工作的吗?请查看这篇博客文章。
使用方法
git clone https://github.com/sshh12/multi_token \
&& cd multi_token \
&& pip install -r requirements.txt \
&& pip install -e .
pip install flash-attn --no-build-isolation
模型库
⚠️ 如果遇到缺少adapters.bin
的问题,请参见 https://github.com/sshh12/multi_token/issues/12 ⚠️
基础模型 | 模型 | 模态 | 备注 |
---|---|---|---|
mistralai/Mistral-7B-Instruct-v0.1 | sshh12/Mistral-7B-LoRA-DocumentGTE-16K-x8 | 长文档 将文档编码为一系列 <document> 并使用documents 。 | ⚠️📚 在维基百科上预训练并在LongAlpaca和Long-Data-Collections上微调的压缩模型。使用gte-large将512个token的块压缩为64个,结果可能会有较大损失。其性能与x128版本相似,表明瓶颈可能是嵌入模型本身。 计算资源:约100个A6000 GPU小时 |
mistralai/Mistral-7B-Instruct-v0.1 | sshh12/Mistral-7B-LoRA-DocumentGTE-260K-x128 | 长文档 将文档编码为一系列 <document> 并使用documents 。 | ⚠️📚 在维基百科上预训练并在LongAlpaca和Long-Data-Collections上微调的压缩模型。使用gte-large将512个token的块压缩为仅4个,结果可能会有较大损失。 计算资源:约50个A6000 GPU小时 |
mistralai/Mistral-7B-Instruct-v0.1 | sshh12/Mistral-7B-LoRA-ImageBind-LLAVA | ImageBind(视觉/音频/文本) 将音频或图像文件名编码为 <imagebind> 并使用imagebinds 。 | ⚠️🖼️🔊📚 在增强的LLaVA数据集上预训练和微调的模型。可能会从音频中幻想出颜色,需要明确提及输入是声音/图像/文档。 计算资源:约180个4090 GPU小时 |
mistralai/Mistral-7B-Instruct-v0.1 | sshh12/Mistral-7B-LoRA-VisionCLIP-LLAVA | 视觉 将图像编码为 <image> 并使用images 。 | ⭐🖼️ 在LLaVA数据集上预训练和微调的模型。应该可与BakLLaVA和LLaVA 1.5相媲美。 计算资源:约160个3090 Ti GPU小时 |
mistralai/Mistral-7B-Instruct-v0.1 | sshh12/Mistral-7B-LoRA-VisionCLIPPool-LLAVA | 视觉 将图像编码为 <image> 并使用images 。 | ⭐🖼️ 在LLaVA数据集上预训练和微调的模型。应该可与BakLLaVA和LLaVA 1.5相媲美。使用CLIP的最后一层编码为10个token(而不是原始的576个)。 计算资源:约100个A6000 GPU小时 |
mistralai/Mistral-7B-Instruct-v0.1 | sshh12/Mistral-7B-LoRA-Multi-VisionCLIPPool-LLAVA | 视觉 将图像编码为 <image><image>... 并使用 images 。 | ⭐🖼️🖼️ 一个在LLaVA数据集和合成多图像数据集上预训练和微调的模型。每张图像编码为10个标记,最多支持6张图像。 计算量:约100个A6000 GPU小时 |
mistralai/Mistral-7B-Instruct-v0.1 | sshh12/Mistral-7B-CLIP-LoRA-captions-only-demo | 视觉 将图像编码为 <image> 并使用 images 。 | ⚠️🖼️ 这是一个__非常有限__的图像模型,仅在少量__仅包含描述__的示例上训练,目的是展示概念验证。 计算量:约10个3090 Ti GPU小时 |
mistralai/Mistral-7B-Instruct-v0.1 | sshh12/Mistral-7B-LoRA-XCLIP | 视频 将视频编码为 <video> 并使用 videos 。 | ⚠️🎥 这是一个__非常有限__的视频模型。很难找到好的视频描述数据集,所以这个模型训练不足。 计算量:约50个A6000 GPU小时 |
mistralai/Mistral-7B-Instruct-v0.1 | sshh12/Mistral-7B-LoRA-AudioWhisper | 音频(语音) 将音频编码为 <speech> 并使用 speech_audios 。 | ⚠️🔊 一个在commonvoice上预训练并在GPT3.5合成数据集上微调的模型。这个模型训练不足,效果不是很好(也基于whisper-small),但勉强可用。 计算量:约60个A6000 GPU小时 |
mistralai/Mistral-7B-Instruct-v0.1 | sshh12/Mistral-7B-LoRA-AudioCLAP | 音频(声音) 将音频编码为 <sound> 并使用 sounds 。 | ⚠️🔊 一个在 Chr0my/Epidemic_sounds 上预训练并在GPT3.5合成数据集上微调的模型。这个模型训练不足,但效果似乎还可以。 计算量:约30个A6000 GPU小时 |
视觉
LLaVA等效
python scripts/serve_model.py \
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
--model_lora_path sshh12/Mistral-7B-LoRA-VisionCLIP-LLAVA \
--load_bits 4 \
--port 7860
requests.post(
"http://localhost:7860/generate",
json={
"messages": [{"role": "user", "content": "我在参观这个地方时应该注意什么?<image>"}],
"images": ["https://github.com/sshh12/multi_token/raw/main/.demo/llava-view.jpg"],
},
).json()
# {'output': '在参观这个有木质码头的湖泊时,有几点需要注意。首先,要注意水深和是否有隐藏的障碍物,如岩石或水下碎片,这些可能会威胁到您的安全。其次,要留意天气状况,因为天气的突然变化可能会使水域变得不可预测和潜在危险。最后,要警惕该地区的任何野生动物或海洋生物,它们可能会威胁到您的安全或损坏码头。'}
多图像
python scripts/serve_model.py \
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
--model_lora_path sshh12/Mistral-7B-LoRA-Multi-VisionCLIPPool-LLAVA \
--port 7860
requests.post(
"http://localhost:7860/generate",
json={
"messages": [{"role": "user", "content": "<image><image> 这两张图片在颜色上有什么区别?"}],
"images": ["https://github.com/sshh12/multi_token/raw/main/.demo/wiki-pink-flower.jpg", "https://github.com/sshh12/multi_token/raw/main/.demo/wiki-yellow-flower.jpg"],
},
).json()
# {'output': '第一张图片是粉色花朵,而第二张图片是黄色花朵。'}
语音
python scripts/serve_model.py \
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
--model_lora_path sshh12/Mistral-7B-LoRA-AudioWhisper \
--port 7860
requests.post(
"http://localhost:7860/generate",
json={
"messages": [{"role": "user", "content": "说的是什么?<speech>"}],
"speech_audios": ["https://github.com/sshh12/multi_token/raw/main/.demo/test.mp3"],
},
).json()
# {'output': '这是一个测试。'}
声音
python scripts/serve_model.py \
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
--model_lora_path sshh12/Mistral-7B-LoRA-AudioCLAP \
--port 7860
requests.post(
"http://localhost:7860/generate",
json={
"messages": [{"role": "user", "content": "是什么在发出这个声音?<sound>"}],
"sounds": ["https://github.com/sshh12/multi_token/raw/main/.demo/imagebind-dog-audio.wav"],
},
).json()
# {'output': '这个声音是吉娃娃在叫。'}
视频
python scripts/serve_model.py \
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
--model_lora_path sshh12/Mistral-7B-LoRA-XCLIP \
--port 7860
requests.post(
"http://localhost:7860/generate",
json={
"messages": [{"role": "user", "content": "<video> 视频中展示了什么乐器?"}],
"videos": ["https://www.youtube.com/watch?v=3569sBBgVsc"],
},
).json()
# {'output': '一个男人在房间里弹钢琴'}
ImageBind(视觉/音频/文本)
python scripts/serve_model.py \
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
--model_lora_path sshh12/Mistral-7B-LoRA-ImageBind-LLAVA \
--port 7860
requests.post(
"http://localhost:7860/generate",
json={
"messages": [{"role": "user", "content": "<imagebind> 这个声音中的动物是什么?"}],
"imagebinds": ["https://github.com/sshh12/multi_token/raw/main/.demo/imagebind-dog-audio.wav"],
},
).json()
# {'output': '这个声音中的动物是狗。'}
长文档
python scripts/serve_model.py \
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
--model_lora_path sshh12/Mistral-7B-LoRA-DocumentGTE-260K-x128 \
--port 7860
from multi_token.modalities.document_gte import (
split_text_into_documents,
)
with open(".demo/llava-paper.txt", "r") as f:
docs = split_text_into_documents(f.read())
requests.post(
"http://localhost:7860/generate",
json={
"messages": [{"role": "user", "content": "阅读论文" + "<document>" * len(docs) + "。给我一个摘要。"}],
"documents": docs,
},
).json()
# {'output': '以下是该论文的主要观点摘要:\n\n- 该论文提出了一个名为LAML的新数据集,包含100,000对图像-文本对,涵盖100种不同语言。该数据集旨在为训练多语言视觉-语言模型提供大规模资源。\n\n- 作者发现现有的多语言视觉-语言模型在为之前未见过的语言生成高质量图像说明时表现不佳。这是因为这些模型缺乏生成特定语言知识的能力...'}
训练
添加一种模态
你可以通过实现multi_token.modalities.base_modality.Modality
的实例来完成此操作(参见视觉CLIP示例)。
查看带注释的示例
class MyModality(Modality):
def __init__(
self,
):
# ...
def build_projector(self, lm_hidden_size: int) -> nn.Module:
# 一个pytorch模块,将预处理后的项目(在`forward`之后)转换为张量`(批量大小 x 令牌宽度 x lm_hidden_size)`
@property
def name(self) -> str:
# 此模态的名称/ID
return "my_modality"
@property
def token(self) -> str:
# 你将在文本中用来表示此模态的令牌
return "<my-modality>"
@property
def data_key(self) -> str:
# 数据集行中原始实例的键
return "my_modality_items"
@property
def token_width(self) -> int:
# 我们应该使用多少个令牌来表示此模态的实例?
# 太小则描述不够,太大则会占用上下文窗口
return 1
def preprocess_rows(self, row: List[Dict]) -> List[Optional[Any]]:
# 将原始数据集行转换为任意张量以传递给`forward`
@torch.no_grad()
def forward(self, encoded_values: List[Any]) -> List[torch.Tensor]:
# 将`preprocess_rows`输出值编码为将被输入到投影器的格式
通过将这个新模态添加到multi_token.modalities.MODALITY_BUILDERS
来注册它。
MODALITY_BUILDERS = {
...,
"my_modality": lambda: [MyModality()],
}
数据集
你可以查看一些现有的脚本以了解如何将数据转换为正确的数据集格式。
模式:
// LLaVA/CLIP示例
{
"id": "arbitrary-id-123",
"images": ["/path/to/image.png"],
"messages": [{"role": "user", "content": "描述<image>"}, {"role": "assistant", "content": "这是一个土豆。"}],
}
// 自定义
{
"id": "arbitrary-id-123",
"my_modality_items": ["/path/to/data 或直接是完整文档"],
"messages": [{"role": "user", "content": "描述<my-modality>"}, {"role": "assistant", "content": "这是..."}],
}
然后使用dataset.save_to_disk(output_folder)
保存。
预训练
使用此命令,并带有标准的huggingface训练参数:
deepspeed scripts/train_model.py \
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
--model_cls MistralLMMForCausalLM \
--modality_builder vision_clip \
--dataset_path /data/llava-chat-captions \
--output_dir /data/output/my_lmm_pretrain \
--pretrain_projectors \
--lora_enable True \
--bf16 True \
--tf32 True \
--num_train_epochs 1 \
--gradient_checkpointing True \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 32 \
--model_max_length 2048 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 1 \
--learning_rate 1e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--dataloader_num_workers 2 \
--logging_steps 1 \
--report_to wandb \
--deepspeed ./configs/zero2.json
关键参数包括:
--modality_builder
:要使用的模态构建器的名称(参见MODALITY_BUILDERS
)--pretrain_projectors
:冻结语言模型,仅训练投影器--model_cls
:要使用的模型类(这应与你的基础模型匹配)
微调
使用以下命令和标准的huggingface训练参数:
deepspeed scripts/train_model.py \
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
--model_cls MistralLMMForCausalLM \
--modality_builder vision_clip \
--pretrained_projectors_path /data/output/my_lmm_pretrain/checkpoint-4000/non_lora_trainables.bin \
--dataset_path /data/llava-chat-captions \
--output_dir /data/output/my_lmm_pretrain \
--pretrain_projectors \
--lora_enable True \
--bf16 True \
--tf32 True \
--num_train_epochs 1 \
--gradient_checkpointing True \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 32 \
--model_max_length 2048 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 1 \
--learning_rate 1e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--dataloader_num_workers 2 \
--logging_steps 1 \
--report_to wandb \
--deepspeed ./configs/zero2.json
关键参数是:
--modality_builder
:要使用的模态构建器的名称(参见MODALITY_BUILDERS
)--pretrained_projectors_path
:预训练投影器的路径(来自预训练步骤)--model_cls
:要使用的模型类(这应该与你的基础模型匹配)
你也可以省略pretrained_projectors_path
来从头开始训练完整模型。根据LLaVA论文,这不如先训练投影器好(但它会有效)。
推理
使用以下命令运行本地flask服务器进行推理:
python scripts/serve_model.py \
--model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
--model_lora_path /data/output/lmm_just_trained_folder \
--port 7860
你可以使用这个工具将模型上传到huggingface:
python scripts/upload_model.py \
-r username/my-new-lmm \
-m /data/output/lmm_just_trained_folder
与LLaVA的比较
LLaVA:大型语言和视觉助手
这个项目的灵感和大部分源代码来自原始的LLaVA实现(apache 2.0)。
核心差异
- 这个库设计得更加模块化,用于添加自定义编码器/投影器。在某些方面,LLaVA的实现被简化了(例如,去掉了大量的评估、预处理代码和非LLAMA部分),而在其他方面则更复杂(处理多种类型的模态)。
- 将投影编码注入语言模型的token空间的标记化和注入过程是从头编写的,但在理论上做的是完全相同的事情。我认为这个项目的
prepare_inputs_labels_for_multimodal
比原始版本更容易理解和操作。 - 你可以使用来自相同或不同模态的多个token实例(而LLaVA只用于单个图像)。例如,
Given <image> and <image>, answer the question asked in <audio>
。
如果使用这个库训练一个模型,使用与LLaVA-1.5相同的基础模型和投影配置,我预计性能会几乎相同(除非这个实现中有任何bug)。
待办事项
- 多GPU支持
- 完整(非LoRA)训练
- 训练量化(QLoRA)
- 高效批处理预处理
- 高效批处理投影
- 高效批处理整理(基于示例长度)
- 高效批处理推理
- 允许非
INST
基础的指令格式和系统token - 支持更多基础语言模型
开发
Windows Docker 开发环境
我的本地开发环境是 Windows + WSL + Docker + 3090 Ti(24GB 显存)。F:/
被配置为一个大容量数据驱动器,我在容器间共享它。
docker build -t multi-token-dev .
docker run -it --gpus all -p 7860:7860 --mount type=bind,source=F:/docker-hf-cache,target=/root/.cache/huggingface --mount type=bind,source=F:/docker-data,target=/data --name multi-token-dev multi-token-dev
Vast.ai 开发环境
对于某些模型,我使用 vast.ai 上相对便宜的 GPU 实例。
vastai create instance $ID --image pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel --disk 512
ssh -p $PORT root@$HOST
curl -o- https://raw.githubusercontent.com/sshh12/multi_token/main/scripts/vastai_setup.sh | bash
在训练过程中,我运行:source ./scripts/vastai_sync.sh $INSTANCE_ID
来将输出文件夹同步到我的本地机器。