MambaVision: 一种混合Mamba-Transformer视觉主干网络
MambaVision官方PyTorch实现:MambaVision: 一种混合Mamba-Transformer视觉主干网络。
如需商业咨询,请访问我们的网站并提交表单:NVIDIA研究许可
MambaVision在Top-1准确率和吞吐量方面达到了新的最优帕累托前沿,展现出强大的性能。
我们通过创建一个不带SSM的对称路径来引入新的混合块,以增强全局上下文的建模:
MambaVision具有层次化架构,同时使用自注意力和混合块:
💥 新闻 💥
-
[2024.07.24] MambaVision Hugging Face模型已发布!
-
[2024.07.14] 我们增加了对任意分辨率图像处理的支持。
-
[2024.07.12] 论文现已在arXiv上发布!
-
[2024.07.11] Mambavision pip包已发布!
-
[2024.07.10] 我们发布了Mambavision的代码和模型检查点!
快速开始
Hugging Face (分类 + 特征提取)
预训练的MambaVision模型可以通过Hugging Face库用几行代码简单使用。首先安装依赖:
pip install mambavision
可以简单地导入模型:
>>> from transformers import AutoModelForImageClassification
>>> model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)
我们在下面演示一个端到端的图像分类示例。
给定以下来自COCO数据集验证集的图像作为输入:
可以使用以下代码片段:
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)
# 推理时设置为评估模式
model.cuda().eval()
# 准备图像
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224) # MambaVision支持任意输入分辨率
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
# 模型推理
outputs = model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print("预测类别:", model.config.id2label[predicted_class_idx])
预测的标签是棕熊,灰熊,Ursus arctos。
您还可以使用Hugging Face MambaVision模型进行特征提取。该模型提供模型每个阶段的输出(4个阶段的分层多尺度特征)以及经过平均池化和展平的最终特征。前者用于下游任务,如分类和检测。
以下代码片段可用于特征提取:
from transformers import AutoModel
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModel.from_pretrained("nvidia/MambaVision-T-1K", trust_remote_code=True)
# 推理时设置为评估模式
model.cuda().eval()
# 准备图像
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224) # MambaVision支持任意输入分辨率
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
# 模型推理
out_avg_pool, features = model(inputs)
print("平均池化特征的大小:", out_avg_pool.size()) # torch.Size([1, 640])
print("提取特征的阶段数:", len(features)) # 4个阶段
print("第1阶段提取特征的大小:", features[0].size()) # torch.Size([1, 80, 56, 56])
print("第4阶段提取特征的大小:", features[3].size()) # torch.Size([1, 640, 7, 7])
目前,我们在Hugging Face上提供MambaVision-T-1K、MambaVision-T2-1K、MambaVision-S-1K、MambaVision-B-1K、MambaVision-L-1K和MambaVision-L2-1K。所有模型也可以在这里查看。
分类 (pip包)
我们还可以用几行代码从pip包中导入预训练的MambaVision模型:
pip install mambavision
可以按如下方式创建具有默认超参数的预训练MambaVision模型:
>>> from mambavision import create_model
# 定义mamba_vision_T模型
>>> model = create_model('mamba_vision_T', pretrained=True, model_path="/tmp/mambavision_tiny_1k.pth.tar")
可用的预训练模型列表包括mamba_vision_T
、mamba_vision_T2
、mamba_vision_S
、mamba_vision_B
、mamba_vision_L
和mamba_vision_L2
。
我们还可以通过传递任意分辨率的虚拟图像来简单测试模型。输出是logits:
>>> import torch
>>> image = torch.rand(1, 3, 512, 224).cuda() # 将图像放在cuda上
>>> model = model.cuda() # 将模型放在cuda上
>>> output = model(image) # 输出logit大小为[1, 1000]
使用我们pip包中的预训练模型,您可以简单地运行验证:
python validate_pip_model.py --model mamba_vision_T --data_dir=$DATA_PATH --batch-size $BS
常见问题
- MambaVision支持处理任意输入分辨率的图像吗?
是的!您可以传递任意分辨率的图像,无需更改模型。
-
我可以将MambaVision应用于检测、分割等下游任务吗? 是的!我们正在努力尽快发布。但是使用MambaVision主干网络进行这些任务与
mmseg
或mmdet
包中的其他模型非常相似。此外,MambaVision Hugging Face模型提供特征提取功能,可用于下游任务。请参见上面的例子。 -
我有兴趣在自己的仓库中重新实现MambaVision。我们可以使用预训练权重吗?
可以!预训练权重在CC-BY-NC-SA-4.0许可下发布。请在此仓库中提交一个issue,我们将把您的仓库添加到我们代码库的README中,并适当地致谢您的努力。
结果 + 预训练模型
ImageNet-1K
MambaVision ImageNet-1K 预训练模型
<表格内容>
安装
我们提供了一个docker文件。此外,假设已安装最新的PyTorch包,可以通过运行以下命令安装依赖项:
pip install -r requirements.txt
评估
可以使用以下方法在ImageNet-1K验证集上评估MambaVision模型:
python validate.py \
--model <模型名称>
--checkpoint <检查点路径>
--data_dir <imagenet路径>
--batch-size <每个gpu的批量大小>
这里--model
是MambaVision变体(例如mambavision_tiny_1k
),--checkpoint
是预训练模型权重的路径,--data_dir
是ImageNet-1K验证集的路径,--batch-size
是批量大小。我们还在这里提供了一个示例脚本。
引用
如果您发现MambaVision对您的工作有用,请考虑引用我们的论文:
@article{hatamizadeh2024mambavision,
title={MambaVision: A Hybrid Mamba-Transformer Vision Backbone},
author={Hatamizadeh, Ali and Kautz, Jan},
journal={arXiv preprint arXiv:2407.08083},
year={2024}
}
Star历史
<图表内容>
许可证
版权所有 © 2024, NVIDIA Corporation。保留所有权利。
本作品根据NVIDIA Source Code License-NC提供。点击这里查看此许可证的副本。
预训练模型在CC-BY-NC-SA-4.0下共享。如果您重新混合、转换或基于该材料进行创作,您必须在相同的许可证下分发您的贡献。
有关timm仓库的许可信息,请参阅其仓库。
有关ImageNet数据集的许可信息,请参阅ImageNet官方网站。
致谢
这个仓库是建立在timm仓库之上的。我们感谢Ross Wrightman创建和维护这个高质量的库。