项目介绍:mobilenetv4_conv_aa_large.e230_r448_in12k_ft_in1k
项目背景
mobilenetv4_conv_aa_large.e230_r448_in12k_ft_in1k是一种用于图像分类的模型,它基于MobileNet-V4架构,并通过平均池化反锯齿技术进行了优化。该模型由Ross Wightman预训练于ImageNet-12k数据集,并在ImageNet-1k数据集上进行了微调。MobileNet-V4是一种通用的移动生态系统模型,适合在移动设备上执行。
模型详情
- 模型类别: 图像分类/特征骨干
- 模型参数:
- 参数数量(百万):32.6
- GMACs(十亿次乘加运算):9.6
- 激活数量(百万):43.9
- 图片大小:训练时448x448,测试时544x544
- 数据集:
- 微调数据集:ImageNet-1k
- 预训练数据集:ImageNet-12k
- 相关论文:
模型用途
图像分类
该模型可以应用于图像分类。使用timm
库加载预训练的MobileNet-V4模型,并对输入的图像进行预测。模型的输出包含预测类别的概率分布,可以用于识别和分类物体。
from urllib.request import urlopen
from PIL import Image
import timm
img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))
model = timm.create_model('mobilenetv4_conv_aa_large.e230_r448_in12k_ft_in1k', pretrained=True)
model = model.eval()
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
output = model(transforms(img).unsqueeze(0))
top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)
特征图提取
除了分类,该模型还可以用于提取特征图,这在图像分析和其他计算机视觉任务中非常有用。提取的特征图可以反映图像的局部和全局信息。
from urllib.request import urlopen
from PIL import Image
import timm
img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))
model = timm.create_model(
'mobilenetv4_conv_aa_large.e230_r448_in12k_ft_in1k',
pretrained=True,
features_only=True,
)
model = model.eval()
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
output = model(transforms(img).unsqueeze(0))
for o in output:
print(o.shape)
图像嵌入
此外,模型支持图像嵌入输出,这对图像检索和相似性计算等应用非常有帮助。使用嵌入可以对图像进行高效的特征表示。
from urllib.request import urlopen
from PIL import Image
import timm
img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))
model = timm.create_model(
'mobilenetv4_conv_aa_large.e230_r448_in12k_ft_in1k',
pretrained=True,
num_classes=0,
)
model = model.eval()
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
output = model(transforms(img).unsqueeze(0))
模型对比
通过比较不同的MobileNetV4模型版本,可以发现图像的分类准确率(top1和top5)、参数数量、图像大小等差异。在此版本中,该模型实现了84.99%的top1准确率和97.294%的top5准确率,参数数量为32.59百万,测试图片大小为544。
参考文献
如果需要了解更多关于该模型的信息,可以参考以下文献:
- Wightman, R. (2019). PyTorch Image Models. GitHub repository.
- Qin, D., et al. (2024). MobileNetV4 -- Universal Models for the Mobile Ecosystem. arXiv preprint arXiv:2404.10518.