项目介绍:repvit_m1.dist_in1k
项目背景
RepViT是一个图像分类模型,专注于提升移动设备上的卷积神经网络(CNN)的性能,同时借鉴了视觉Transformer(ViT)的理念。这个模型经过ImageNet-1k数据集的训练,并通过蒸馏技术得到优化,适用于图像分类任务。
模型详情
- 模型类型: 图像分类 / 特征提取骨干网络
- 模型参数:
- 参数数量(百万):5.5
- 计算量(十亿倍操作数,GMACs):0.8
- 活跃参数数量(百万):7.4
- 图像尺寸:224 x 224 像素
- 相关论文:
- RepViT: Revisiting Mobile CNN From ViT Perspective: 论文链接
- 项目主页: GitHub 原始项目
- 数据集: ImageNet-1k
模型使用方法
RepViT模型可以用于多种图像处理任务,如图像分类、特征图提取和图像嵌入。
图像分类
RepViT模型能够高效地对图像进行分类预测。在Python中,可以使用timm
库加载预训练模型,对输入图像进行处理并输出分类结果。代码示例展示了如何获取图像的前五名分类概率和对应的类别索引。
from urllib.request import urlopen
from PIL import Image
import timm
import torch
# 加载图像
img = Image.open(urlopen(
'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))
# 创建和准备模型
model = timm.create_model('repvit_m1.dist_in1k', pretrained=True)
model = model.eval()
# 获取模型特定的转换(规范化,调整大小)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
# 预测
output = model(transforms(img).unsqueeze(0)) # 将单张图像处理为批处理格式(批大小为1)
# 获取前五名预测概率和类别
top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)
特征图提取
RepViT模型还能提取输入图像的特征图,用于深度学习中的进一步分析和应用。通过设置features_only=True
,可以获取每层输出的特征图。
from urllib.request import urlopen
from PIL import Image
import timm
# 加载图像
img = Image.open(urlopen(
'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))
# 创建和准备模型
model = timm.create_model(
'repvit_m1.dist_in1k',
pretrained=True,
features_only=True,
)
model = model.eval()
# 获取模型特定的转换
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
# 提取特征图
output = model(transforms(img).unsqueeze(0)) # 处理成批格式
# 打印每个特征图的形状
for o in output:
print(o.shape) # 示例输出: torch.Size([1, 384, 7, 7])
图像嵌入
通过移除分类层,RepViT可以用来生成图像嵌入。这样的嵌入可以应用于图像检索等任务。
from urllib.request import urlopen
from PIL import Image
import timm
# 加载图像
img = Image.open(urlopen(
'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))
# 创建和准备模型(移除分类层)
model = timm.create_model(
'repvit_m1.dist_in1k',
pretrained=True,
num_classes=0, # 移除分类层
)
model = model.eval()
# 获取模型特定的转换
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
# 输出图像的嵌入
output = model(transforms(img).unsqueeze(0)) # 输出为(批大小,特征数量)的张量
# 直接获得特征
output = model.forward_features(transforms(img).unsqueeze(0))
output = model.forward_head(output, pre_logits=True)
引用格式
如果在学术论文中使用了RepViT模型,请使用以下引用:
@misc{wang2023repvit,
title={RepViT: Revisiting Mobile CNN From ViT Perspective},
author={Ao Wang and Hui Chen and Zijia Lin and Hengjun Pu and Guiguang Ding},
year={2023},
eprint={2307.09283},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
RepViT模型通过简单而高效的架构设计,在移动设备上实现了优异的图像处理能力,非常适合设备资源受限的场景。