Project Icon

segmentation_models.pytorch

基于PyTorch的神经网络图像分割库

segmentation_models.pytorch 是一个基于 PyTorch 的图像分割库,提供9种分割模型架构和124种编码器。该库 API 简洁,支持预训练权重,并包含常用评估指标和损失函数。它适用于研究和实际应用中的各种图像分割任务,是图像分割领域的实用工具。

logo
基于PyTorch的图像分割神经网络Python库

Generic badge GitHub Workflow Status (branch) Read the Docs
PyPI PyPI - Downloads
PyTorch - Version Python - Version

该库的主要特点包括:

  • 高级API(只需两行代码即可创建神经网络)
  • 9种用于二分类和多分类分割的模型架构(包括著名的Unet)
  • 124种可用编码器(以及来自timm的500多种编码器)
  • 所有编码器都有预训练权重,以实现更快更好的收敛
  • 用于训练的常用指标和损失函数

📚 项目文档 📚

访问Read The Docs项目页面或阅读以下README,了解更多关于Segmentation Models Pytorch(简称SMP)库的信息

📋 目录

  1. 快速开始
  2. 示例
  3. 模型
    1. 架构
    2. 编码器
    3. Timm编码器
  4. 模型API
    1. 输入通道
    2. 辅助分类输出
    3. 深度
  5. 安装
  6. 使用该库获胜的比赛
  7. 贡献
  8. 引用
  9. 许可证

⏳ 快速开始

1. 使用SMP创建你的第一个分割模型

分割模型只是一个PyTorch nn.Module,可以像这样轻松创建:

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",        # 选择编码器,例如 mobilenet_v2 或 efficientnet-b7
    encoder_weights="imagenet",     # 使用`imagenet`预训练权重进行编码器初始化
    in_channels=1,                  # 模型输入通道数(1用于灰度图像,3用于RGB等)
    classes=3,                      # 模型输出通道数(数据集中的类别数)
)
  • 查看表格了解可用的模型架构
  • 查看表格了解可用的编码器及其对应的权重

2. 配置数据预处理

所有编码器都有预训练权重。以与权重预训练期间相同的方式准备数据可能会给你带来更好的结果(更高的指标分数和更快的收敛)。如果你训练整个模型而不仅仅是解码器,这不是必需的

from segmentation_models_pytorch.encoders import get_preprocessing_fn

preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')

恭喜!你已完成!现在你可以使用你喜欢的框架来训练你的模型了!

💡 示例

📦 模型

架构

编码器

以下是SMP支持的编码器列表。选择适当的编码器系列并点击展开表格,选择特定的编码器及其预训练权重(encoder_nameencoder_weights参数)。

ResNet
编码器权重参数量, M
resnet18imagenet / ssl / swsl11M
resnet34imagenet21M
resnet50imagenet / ssl / swsl23M
resnet101imagenet42M
resnet152imagenet58M
ResNeXt
编码器权重参数量, M
resnext50_32x4dimagenet / ssl / swsl22M
resnext101_32x4dssl / swsl42M
resnext101_32x8dimagenet / instagram / ssl / swsl86M
resnext101_32x16dinstagram / ssl / swsl191M
resnext101_32x32dinstagram466M
resnext101_32x48dinstagram826M
ResNeSt
编码器权重参数量, M
timm-resnest14dimagenet8M
timm-resnest26dimagenet15M
timm-resnest50dimagenet25M
timm-resnest101eimagenet46M
timm-resnest200eimagenet68M
timm-resnest269eimagenet108M
timm-resnest50d_4s2x40dimagenet28M
timm-resnest50d_1s4x24dimagenet23M
Res2Ne(X)t
编码器权重参数量, M
timm-res2net50_26w_4simagenet23M
timm-res2net101_26w_4simagenet43M
timm-res2net50_26w_6simagenet35M
timm-res2net50_26w_8simagenet46M
timm-res2net50_48w_2simagenet23M
timm-res2net50_14w_8simagenet23M
timm-res2next50imagenet22M
RegNet(x/y)
|编码器 |权重 |参数量(百万) | |--------------------------------|:------------------------------:|:------------------------------:| |timm-regnetx_002 |imagenet |2M | |timm-regnetx_004 |imagenet |4M | |timm-regnetx_006 |imagenet |5M | |timm-regnetx_008 |imagenet |6M | |timm-regnetx_016 |imagenet |8M | |timm-regnetx_032 |imagenet |14M | |timm-regnetx_040 |imagenet |20M | |timm-regnetx_064 |imagenet |24M | |timm-regnetx_080 |imagenet |37M | |timm-regnetx_120 |imagenet |43M | |timm-regnetx_160 |imagenet |52M | |timm-regnetx_320 |imagenet |105M | |timm-regnety_002 |imagenet |2M | |timm-regnety_004 |imagenet |3M | |timm-regnety_006 |imagenet |5M | |timm-regnety_008 |imagenet |5M | |timm-regnety_016 |imagenet |10M | |timm-regnety_032 |imagenet |17M | |timm-regnety_040 |imagenet |19M | |timm-regnety_064 |imagenet |29M | |timm-regnety_080 |imagenet |37M | |timm-regnety_120 |imagenet |49M | |timm-regnety_160 |imagenet |80M | |timm-regnety_320 |imagenet |141M |
GERNet
编码器权重参数量(百万)
timm-gernet_simagenet6M
timm-gernet_mimagenet18M
timm-gernet_limagenet28M
SE-Net
编码器权重参数量(百万)
senet154imagenet113M
se_resnet50imagenet26M
se_resnet101imagenet47M
se_resnet152imagenet64M
se_resnext50_32x4dimagenet25M
se_resnext101_32x4dimagenet46M
SK-ResNe(X)t
编码器权重参数量(百万)
timm-skresnet18imagenet11M
timm-skresnet34imagenet21M
timm-skresnext50_32x4dimagenet25M
DenseNet
编码器权重参数量(百万)
densenet121imagenet6M
densenet169imagenet12M
densenet201imagenet18M
densenet161imagenet26M
Inception
编码器权重参数量(百万)
inceptionresnetv2imagenet / imagenet+background54M
inceptionv4imagenet / imagenet+background41M
xceptionimagenet22M
EfficientNet
编码器权重参数量(百万)
efficientnet-b0imagenet4M
efficientnet-b1imagenet6M
efficientnet-b2imagenet7M
efficientnet-b3imagenet10M
efficientnet-b4imagenet17M
efficientnet-b5imagenet28M
efficientnet-b6imagenet40M
efficientnet-b7imagenet63M
timm-efficientnet-b0imagenet / advprop / noisy-student4M
timm-efficientnet-b1imagenet / advprop / noisy-student6M
timm-efficientnet-b2imagenet / advprop / noisy-student7M
timm-efficientnet-b3imagenet / advprop / noisy-student10M
timm-efficientnet-b4imagenet / advprop / noisy-student17M
timm-efficientnet-b5imagenet / advprop / noisy-student28M
timm-efficientnet-b6imagenet / advprop / noisy-student40M
timm-efficientnet-b7imagenet / advprop / noisy-student63M
timm-efficientnet-b8imagenet / advprop84M
timm-efficientnet-l2noisy-student474M
timm-efficientnet-lite0imagenet4M
timm-efficientnet-lite1imagenet5M
timm-efficientnet-lite2imagenet6M
timm-efficientnet-lite3imagenet8M
timm-efficientnet-lite4imagenet13M
MobileNet
|编码器 |权重 |参数量(百万) | |--------------------------------|:------------------------------:|:------------------------------:| |mobilenet_v2 |imagenet |2M | |timm-mobilenetv3_large_075 |imagenet |1.78M | |timm-mobilenetv3_large_100 |imagenet |2.97M | |timm-mobilenetv3_large_minimal_100|imagenet |1.41M | |timm-mobilenetv3_small_075 |imagenet |0.57M | |timm-mobilenetv3_small_100 |imagenet |0.93M | |timm-mobilenetv3_small_minimal_100|imagenet |0.43M |
DPN
编码器权重参数量(百万)
dpn68imagenet11M
dpn68bimagenet+5k11M
dpn92imagenet+5k34M
dpn98imagenet58M
dpn107imagenet+5k84M
dpn131imagenet76M
VGG
编码器权重参数量(百万)
vgg11imagenet9M
vgg11_bnimagenet9M
vgg13imagenet9M
vgg13_bnimagenet9M
vgg16imagenet14M
vgg16_bnimagenet14M
vgg19imagenet20M
vgg19_bnimagenet20M
混合视觉Transformer

SegFormer在Imagenet上预训练的骨干网络!可以与包中的其他解码器一起使用,您可以将混合视觉Transformer与Unet、FPN等结合使用!

限制:

  • 编码器支持Linknet、Unet++
  • 编码器仅支持FPN,且仅限编码器深度为5的情况
编码器权重参数量(百万)
mit_b0imagenet3M
mit_b1imagenet13M
mit_b2imagenet24M
mit_b3imagenet44M
mit_b4imagenet60M
mit_b5imagenet81M
MobileOne

苹果公司的"亚毫秒级"骨干网络在Imagenet上预训练!可以与所有解码器一起使用。

注意:在官方GitHub仓库中,s0变体有额外的num_conv_branches,导致比s1有更多的参数。

编码器权重参数量(百万)
mobileone_s0imagenet4.6M
mobileone_s1imagenet4.0M
mobileone_s2imagenet6.5M
mobileone_s3imagenet8.8M
mobileone_s4imagenet13.6M

* sslswsl - 在ImageNet上进行半监督和弱监督学习(仓库)。

Timm编码器

文档

Pytorch Image Models(简称timm)有许多预训练模型和接口,可以将这些模型用作smp中的编码器,但并非所有模型都受支持

  • 并非所有transformer模型都实现了编码器所需的features_only功能
  • 一些模型的步幅不适合

支持的编码器总数:549

🔁 模型API

  • model.encoder - 预训练的骨干网络,用于提取不同空间分辨率的特征
  • model.decoder - 取决于模型架构(Unet/Linknet/PSPNet/FPN
  • model.segmentation_head - 最后一个块,用于生成所需数量的掩码通道(还包括可选的上采样和激活)
  • model.classification_head - 可选块,在编码器顶部创建分类头
  • model.forward(x) - 顺序地将x通过模型的编码器、解码器和分割头(如果指定,还包括分类头)
输入通道

输入通道参数允许您创建可处理任意数量通道张量的模型。如果您使用来自imagenet的预训练权重 - 第一个卷积的权重将被重用。对于1通道情况,它将是第一个卷积层权重的总和,否则通道将使用如下方式填充权重:new_weight[:, i] = pretrained_weight[:, i % 3],然后用new_weight * 3 / new_in_channels进行缩放。

model = smp.FPN('resnet34', in_channels=1)
mask = model(torch.ones([1, 1, 64, 64]))
辅助分类输出

所有模型都支持aux_params参数,默认设置为None。如果aux_params = None,则不创建分类辅助输出,否则模型不仅产生mask,还产生形状为NClabel输出。分类头由GlobalPooling->Dropout(可选)->Linear->Activation(可选)层组成,可以通过aux_params进行配置,如下所示:

aux_params=dict(
    pooling='avg',             # 'avg'或'max'之一
    dropout=0.5,               # dropout比率,默认为None
    activation='sigmoid',      # 激活函数,默认为None
    classes=4,                 # 定义输出标签的数量
)
model = smp.Unet('resnet34', classes=4, aux_params=aux_params)
mask, label = model(x)
深度

深度参数指定编码器中的下采样操作数量,因此如果指定较小的depth,可以使模型更轻量。

model = smp.Unet('resnet34', encoder_depth=4)

🛠 安装

PyPI 版本:

$ pip install segmentation-models-pytorch

从源代码安装最新版本:

$ pip install git+https://github.com/qubvel/segmentation_models.pytorch

🏆 使用本库获得的比赛胜利

Segmentation Models包在图像分割比赛中被广泛使用。 这里你可以找到比赛名称、获胜者姓名以及他们解决方案的链接。

🤝 贡献

安装 SMP

make install_dev  # 创建 .venv,以开发模式安装 SMP

运行测试和代码检查

make fixup         # 使用 Ruff 进行格式化和代码检查

更新编码器表格

make table        # 生成编码器表格并输出到标准输出

📝 引用

@misc{Iakubovskii:2019,
  Author = {Pavel Iakubovskii},
  Title = {Segmentation Models Pytorch},
  Year = {2019},
  Publisher = {GitHub},
  Journal = {GitHub repository},
  Howpublished = {\url{https://github.com/qubvel/segmentation_models.pytorch}}
}

🛡️ 许可证

项目在 MIT 许可证 下分发

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号