项目介绍:convnext_base.fb_in22k_ft_in1k_384
convnext_base.fb_in22k_ft_in1k_384 是一个用于图像分类的模型。这一模型基于ConvNeXt架构,首先在ImageNet-22k数据集上进行预训练,然后在ImageNet-1k数据集上进行微调,由研究论文的作者开发。
模型详情
- 模型类型: 图像分类/特征骨干网络
- 模型参数:
- 参数量(百万): 88.6
- GMACs: 45.2
- 活动参数量(百万): 84.5
- 图像尺寸: 384 x 384
- 相关论文: A ConvNet for the 2020s
- 项目地址: GitHub
- 数据集: 使用ImageNet-1k进行训练
- 预训练数据集: 使用ImageNet-22k进行预训练
模型使用
图像分类
使用此模型进行图像分类非常简单。代码示例展示了如何下载一张图片,并使用timm库加载预训练模型。
from urllib.request import urlopen
from PIL import Image
import timm
img = Image.open(urlopen(
'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))
model = timm.create_model('convnext_base.fb_in22k_ft_in1k_384', pretrained=True)
model = model.eval()
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
output = model(transforms(img).unsqueeze(0))
top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)
特征图提取
除此之外,还可以使用该模型提取特征图。这样可以在不包含分类头的情况下,获得特征图,用于进一步的分析。
model = timm.create_model(
'convnext_base.fb_in22k_ft_in1k_384',
pretrained=True,
features_only=True,
)
output = model(transforms(img).unsqueeze(0))
for o in output:
print(o.shape)
图像嵌入
获取图像嵌入的过程类似于特征图提取,只不过去掉了分类层。
model = timm.create_model(
'convnext_base.fb_in22k_ft_in1k_384',
pretrained=True,
num_classes=0,
)
output = model(transforms(img).unsqueeze(0))
模型比较
这个模型在一系列复杂度和性能上都有所优化,与其他类似模型相比有着不俗的表现。模型性能上的一些相关数据可以在timm模型结果中查看。例如,与其他模型的对比中,它在批处理大小为256时的样本处理速度为366.54张图像每秒。
此外,通过这款模型可以轻松实现现代化的图像分类任务,对于希望在大规模数据上训练和微调模型的研究者们来说,它是一个理想的选择。