Project Icon

multi_token

将多模态嵌入到大语言模型的开源框架

multi_token是一个开源项目,旨在扩展大语言模型的多模态处理能力。该框架支持将图像、音频、文档和视频等多种模态编码为统一格式,并嵌入到单个模型中。它提供了简便的实现方法,使开发者能够轻松构建支持长文档、图像、音频和视频等多模态输入的语言模型。

multi_token

将任意模态(图像、音频、文档等)嵌入到大型语言模型中。

这个库旨在作为LLaVA的扩展,用于将✨任何东西✨(图像、声音、文档、视频、动作捕捉、屏幕截图、语音录音等)编码成可用于大型语言模型的格式。它的主要贡献在于能够将多个实例和模态嵌入到单个模型中,并提供了一个相对简单的框架来实现这一目标。

利用这个库,你可能可以向大型多模态模型(LMMs)提出以下问题:

  • 阅读<文档>并给我一个摘要。

  • 听<音频>并回答口头提出的问题。

  • 比较和对比<图像>和<图像>

  • 根据<屏幕截图>和<游戏状态>,我应该按哪个键?

想了解这是如何工作的吗?请查看这篇博客文章

使用方法

git clone https://github.com/sshh12/multi_token \
        && cd multi_token \
        && pip install -r requirements.txt \
        && pip install -e .

pip install flash-attn --no-build-isolation

模型库

⚠️ 如果遇到缺少adapters.bin的问题,请参见 https://github.com/sshh12/multi_token/issues/12 ⚠️

基础模型模型模态备注
mistralai/Mistral-7B-Instruct-v0.1sshh12/Mistral-7B-LoRA-DocumentGTE-16K-x8长文档

将文档编码为一系列<document>并使用documents
⚠️📚 在维基百科上预训练并在LongAlpaca和Long-Data-Collections上微调的压缩模型。使用gte-large将512个token的块压缩为64个,结果可能会有较大损失。其性能与x128版本相似,表明瓶颈可能是嵌入模型本身。

计算资源:约100个A6000 GPU小时
mistralai/Mistral-7B-Instruct-v0.1sshh12/Mistral-7B-LoRA-DocumentGTE-260K-x128长文档

将文档编码为一系列<document>并使用documents
⚠️📚 在维基百科上预训练并在LongAlpaca和Long-Data-Collections上微调的压缩模型。使用gte-large将512个token的块压缩为仅4个,结果可能会有较大损失。

计算资源:约50个A6000 GPU小时
mistralai/Mistral-7B-Instruct-v0.1sshh12/Mistral-7B-LoRA-ImageBind-LLAVAImageBind(视觉/音频/文本)

将音频或图像文件名编码为<imagebind>并使用imagebinds
⚠️🖼️🔊📚 在增强的LLaVA数据集上预训练和微调的模型。可能会从音频中幻想出颜色,需要明确提及输入是声音/图像/文档。

计算资源:约180个4090 GPU小时
mistralai/Mistral-7B-Instruct-v0.1sshh12/Mistral-7B-LoRA-VisionCLIP-LLAVA视觉

将图像编码为<image>并使用images
⭐🖼️ 在LLaVA数据集上预训练和微调的模型。应该可与BakLLaVALLaVA 1.5相媲美。

计算资源:约160个3090 Ti GPU小时
mistralai/Mistral-7B-Instruct-v0.1sshh12/Mistral-7B-LoRA-VisionCLIPPool-LLAVA视觉

将图像编码为<image>并使用images
⭐🖼️ 在LLaVA数据集上预训练和微调的模型。应该可与BakLLaVALLaVA 1.5相媲美。使用CLIP的最后一层编码为10个token(而不是原始的576个)。

计算资源:约100个A6000 GPU小时
mistralai/Mistral-7B-Instruct-v0.1sshh12/Mistral-7B-LoRA-Multi-VisionCLIPPool-LLAVA视觉

将图像编码为 <image><image>... 并使用 images
⭐🖼️🖼️ 一个在LLaVA数据集和合成多图像数据集上预训练和微调的模型。每张图像编码为10个标记,最多支持6张图像。

计算量:约100个A6000 GPU小时
mistralai/Mistral-7B-Instruct-v0.1sshh12/Mistral-7B-CLIP-LoRA-captions-only-demo视觉

将图像编码为 <image> 并使用 images
⚠️🖼️ 这是一个__非常有限__的图像模型,仅在少量__仅包含描述__的示例上训练,目的是展示概念验证。

计算量:约10个3090 Ti GPU小时
mistralai/Mistral-7B-Instruct-v0.1sshh12/Mistral-7B-LoRA-XCLIP视频

将视频编码为 <video> 并使用 videos
⚠️🎥 这是一个__非常有限__的视频模型。很难找到好的视频描述数据集,所以这个模型训练不足。

计算量:约50个A6000 GPU小时
mistralai/Mistral-7B-Instruct-v0.1sshh12/Mistral-7B-LoRA-AudioWhisper音频(语音)

将音频编码为 <speech> 并使用 speech_audios
⚠️🔊 一个在commonvoice上预训练并在GPT3.5合成数据集上微调的模型。这个模型训练不足,效果不是很好(也基于whisper-small),但勉强可用。

计算量:约60个A6000 GPU小时
mistralai/Mistral-7B-Instruct-v0.1sshh12/Mistral-7B-LoRA-AudioCLAP音频(声音)

将音频编码为 <sound> 并使用 sounds
⚠️🔊 一个在 Chr0my/Epidemic_sounds 上预训练并在GPT3.5合成数据集上微调的模型。这个模型训练不足,但效果似乎还可以。

计算量:约30个A6000 GPU小时

视觉

LLaVA等效
python scripts/serve_model.py \
    --model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
    --model_lora_path sshh12/Mistral-7B-LoRA-VisionCLIP-LLAVA \
    --load_bits 4 \
    --port 7860
requests.post(
    "http://localhost:7860/generate",
    json={
        "messages": [{"role": "user", "content": "我在参观这个地方时应该注意什么?<image>"}],
        "images": ["https://github.com/sshh12/multi_token/raw/main/.demo/llava-view.jpg"],
    },
).json()
# {'output': '在参观这个有木质码头的湖泊时,有几点需要注意。首先,要注意水深和是否有隐藏的障碍物,如岩石或水下碎片,这些可能会威胁到您的安全。其次,要留意天气状况,因为天气的突然变化可能会使水域变得不可预测和潜在危险。最后,要警惕该地区的任何野生动物或海洋生物,它们可能会威胁到您的安全或损坏码头。'}
多图像
python scripts/serve_model.py \
    --model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
    --model_lora_path sshh12/Mistral-7B-LoRA-Multi-VisionCLIPPool-LLAVA \
    --port 7860
requests.post(
    "http://localhost:7860/generate",
    json={
        "messages": [{"role": "user", "content": "<image><image> 这两张图片在颜色上有什么区别?"}],
        "images": ["https://github.com/sshh12/multi_token/raw/main/.demo/wiki-pink-flower.jpg", "https://github.com/sshh12/multi_token/raw/main/.demo/wiki-yellow-flower.jpg"],
    },
).json()
# {'output': '第一张图片是粉色花朵,而第二张图片是黄色花朵。'}

语音

python scripts/serve_model.py \
    --model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
    --model_lora_path sshh12/Mistral-7B-LoRA-AudioWhisper \
    --port 7860
requests.post(
    "http://localhost:7860/generate",
    json={
        "messages": [{"role": "user", "content": "说的是什么?<speech>"}],
        "speech_audios": ["https://github.com/sshh12/multi_token/raw/main/.demo/test.mp3"],
    },
).json()
# {'output': '这是一个测试。'}

声音

python scripts/serve_model.py \
    --model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
    --model_lora_path sshh12/Mistral-7B-LoRA-AudioCLAP \
    --port 7860
requests.post(
    "http://localhost:7860/generate",
    json={
        "messages": [{"role": "user", "content": "是什么在发出这个声音?<sound>"}],
        "sounds": ["https://github.com/sshh12/multi_token/raw/main/.demo/imagebind-dog-audio.wav"],
    },
).json()
# {'output': '这个声音是吉娃娃在叫。'}

视频

python scripts/serve_model.py \
    --model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
    --model_lora_path sshh12/Mistral-7B-LoRA-XCLIP \
    --port 7860
requests.post(
    "http://localhost:7860/generate",
    json={
        "messages": [{"role": "user", "content": "<video> 视频中展示了什么乐器?"}],
        "videos": ["https://www.youtube.com/watch?v=3569sBBgVsc"],
    },
).json()
# {'output': '一个男人在房间里弹钢琴'}

ImageBind(视觉/音频/文本)

python scripts/serve_model.py \
    --model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
    --model_lora_path sshh12/Mistral-7B-LoRA-ImageBind-LLAVA \
    --port 7860
requests.post(
    "http://localhost:7860/generate",
    json={
        "messages": [{"role": "user", "content": "<imagebind> 这个声音中的动物是什么?"}],
        "imagebinds": ["https://github.com/sshh12/multi_token/raw/main/.demo/imagebind-dog-audio.wav"],
    },
).json()
# {'output': '这个声音中的动物是狗。'}

长文档

python scripts/serve_model.py \
    --model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
    --model_lora_path sshh12/Mistral-7B-LoRA-DocumentGTE-260K-x128 \
    --port 7860
from multi_token.modalities.document_gte import (
    split_text_into_documents,
)

with open(".demo/llava-paper.txt", "r") as f:
    docs = split_text_into_documents(f.read())

requests.post(
    "http://localhost:7860/generate",
    json={
        "messages": [{"role": "user", "content": "阅读论文" + "<document>" * len(docs) + "。给我一个摘要。"}],
        "documents": docs,
    },
).json()
# {'output': '以下是该论文的主要观点摘要:\n\n- 该论文提出了一个名为LAML的新数据集,包含100,000对图像-文本对,涵盖100种不同语言。该数据集旨在为训练多语言视觉-语言模型提供大规模资源。\n\n- 作者发现现有的多语言视觉-语言模型在为之前未见过的语言生成高质量图像说明时表现不佳。这是因为这些模型缺乏生成特定语言知识的能力...'}

训练

添加一种模态

你可以通过实现multi_token.modalities.base_modality.Modality的实例来完成此操作(参见视觉CLIP示例)。

查看带注释的示例
class MyModality(Modality):
    def __init__(
        self,
    ):
        # ...

    def build_projector(self, lm_hidden_size: int) -> nn.Module:
        # 一个pytorch模块,将预处理后的项目(在`forward`之后)转换为张量`(批量大小 x 令牌宽度 x lm_hidden_size)`

    @property
    def name(self) -> str:
        # 此模态的名称/ID
        return "my_modality"

    @property
    def token(self) -> str:
        # 你将在文本中用来表示此模态的令牌
        return "<my-modality>"

    @property
    def data_key(self) -> str:
        # 数据集行中原始实例的键
        return "my_modality_items"

    @property
    def token_width(self) -> int:
        # 我们应该使用多少个令牌来表示此模态的实例?
        # 太小则描述不够,太大则会占用上下文窗口
        return 1

    def preprocess_rows(self, row: List[Dict]) -> List[Optional[Any]]:
        # 将原始数据集行转换为任意张量以传递给`forward`

    @torch.no_grad()
    def forward(self, encoded_values: List[Any]) -> List[torch.Tensor]:
        # 将`preprocess_rows`输出值编码为将被输入到投影器的格式

通过将这个新模态添加到multi_token.modalities.MODALITY_BUILDERS来注册它。

MODALITY_BUILDERS = {
    ...,
    "my_modality": lambda: [MyModality()],
}

数据集

你可以查看一些现有的脚本以了解如何将数据转换为正确的数据集格式。

模式:

// LLaVA/CLIP示例
{
    "id": "arbitrary-id-123",
    "images": ["/path/to/image.png"],
    "messages": [{"role": "user", "content": "描述<image>"}, {"role": "assistant", "content": "这是一个土豆。"}],
}

// 自定义
{
    "id": "arbitrary-id-123",
    "my_modality_items": ["/path/to/data 或直接是完整文档"],
    "messages": [{"role": "user", "content": "描述<my-modality>"}, {"role": "assistant", "content": "这是..."}],
}

然后使用dataset.save_to_disk(output_folder)保存。

预训练

使用此命令,并带有标准的huggingface训练参数:

deepspeed scripts/train_model.py \
    --model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
    --model_cls MistralLMMForCausalLM \
    --modality_builder vision_clip \
    --dataset_path /data/llava-chat-captions \
    --output_dir /data/output/my_lmm_pretrain \
    --pretrain_projectors \
    --lora_enable True \
    --bf16 True \
    --tf32 True \
    --num_train_epochs 1 \
    --gradient_checkpointing True \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 32 \
    --model_max_length 2048 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1000 \
    --save_total_limit 1 \
    --learning_rate 1e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --dataloader_num_workers 2 \
    --logging_steps 1 \
    --report_to wandb \
    --deepspeed ./configs/zero2.json

关键参数包括:

  • --modality_builder:要使用的模态构建器的名称(参见MODALITY_BUILDERS
  • --pretrain_projectors:冻结语言模型,仅训练投影器
  • --model_cls:要使用的模型类(这应与你的基础模型匹配)

微调

使用以下命令和标准的huggingface训练参数:

deepspeed scripts/train_model.py \
    --model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
    --model_cls MistralLMMForCausalLM \
    --modality_builder vision_clip \
    --pretrained_projectors_path /data/output/my_lmm_pretrain/checkpoint-4000/non_lora_trainables.bin \
    --dataset_path /data/llava-chat-captions \
    --output_dir /data/output/my_lmm_pretrain \
    --pretrain_projectors \
    --lora_enable True \
    --bf16 True \
    --tf32 True \
    --num_train_epochs 1 \
    --gradient_checkpointing True \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 32 \
    --model_max_length 2048 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1000 \
    --save_total_limit 1 \
    --learning_rate 1e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --dataloader_num_workers 2 \
    --logging_steps 1 \
    --report_to wandb \
    --deepspeed ./configs/zero2.json

关键参数是:

  • --modality_builder:要使用的模态构建器的名称(参见MODALITY_BUILDERS
  • --pretrained_projectors_path:预训练投影器的路径(来自预训练步骤)
  • --model_cls:要使用的模型类(这应该与你的基础模型匹配)

你也可以省略pretrained_projectors_path来从头开始训练完整模型。根据LLaVA论文,这不如先训练投影器好(但它会有效)。

推理

使用以下命令运行本地flask服务器进行推理:

python scripts/serve_model.py \
    --model_name_or_path mistralai/Mistral-7B-Instruct-v0.1 \
    --model_lora_path /data/output/lmm_just_trained_folder \
    --port 7860

你可以使用这个工具将模型上传到huggingface:

python scripts/upload_model.py \
    -r username/my-new-lmm \
    -m /data/output/lmm_just_trained_folder

与LLaVA的比较

LLaVA:大型语言和视觉助手

[项目页面] [演示] [数据] [模型库]

通过视觉指令微调改进基线 [论文]
刘浩天, 李春元, 李宇恒, 李容在

视觉指令微调(NeurIPS 2023,口头报告)[论文]
刘浩天*, 李春元*, 吴清阳, 李容在 (*共同贡献)

这个项目的灵感和大部分源代码来自原始的LLaVA实现(apache 2.0)。

核心差异

  • 这个库设计得更加模块化,用于添加自定义编码器/投影器。在某些方面,LLaVA的实现被简化了(例如,去掉了大量的评估、预处理代码和非LLAMA部分),而在其他方面则更复杂(处理多种类型的模态)。
  • 将投影编码注入语言模型的token空间的标记化和注入过程是从头编写的,但在理论上做的是完全相同的事情。我认为这个项目的prepare_inputs_labels_for_multimodal比原始版本更容易理解和操作。
  • 你可以使用来自相同或不同模态的多个token实例(而LLaVA只用于单个图像)。例如,Given <image> and <image>, answer the question asked in <audio>

如果使用这个库训练一个模型,使用与LLaVA-1.5相同的基础模型和投影配置,我预计性能会几乎相同(除非这个实现中有任何bug)。

待办事项

  • 多GPU支持
  • 完整(非LoRA)训练
  • 训练量化(QLoRA)
  • 高效批处理预处理
  • 高效批处理投影
  • 高效批处理整理(基于示例长度)
  • 高效批处理推理
  • 允许非INST基础的指令格式和系统token
  • 支持更多基础语言模型

开发

Windows Docker 开发环境

我的本地开发环境是 Windows + WSL + Docker + 3090 Ti(24GB 显存)。F:/ 被配置为一个大容量数据驱动器,我在容器间共享它。

  1. docker build -t multi-token-dev .
  2. docker run -it --gpus all -p 7860:7860 --mount type=bind,source=F:/docker-hf-cache,target=/root/.cache/huggingface --mount type=bind,source=F:/docker-data,target=/data --name multi-token-dev multi-token-dev

Vast.ai 开发环境

对于某些模型,我使用 vast.ai 上相对便宜的 GPU 实例。

  1. vastai create instance $ID --image pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel --disk 512
  2. ssh -p $PORT root@$HOST
  3. curl -o- https://raw.githubusercontent.com/sshh12/multi_token/main/scripts/vastai_setup.sh | bash

在训练过程中,我运行:source ./scripts/vastai_sync.sh $INSTANCE_ID 来将输出文件夹同步到我的本地机器。

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号