基于PyTorch的图像分割神经网络Python库
该库的主要特点包括:
- 高级API(只需两行代码即可创建神经网络)
- 9种用于二分类和多分类分割的模型架构(包括著名的Unet)
- 124种可用编码器(以及来自timm的500多种编码器)
- 所有编码器都有预训练权重,以实现更快更好的收敛
- 用于训练的常用指标和损失函数
📚 项目文档 📚
访问Read The Docs项目页面或阅读以下README,了解更多关于Segmentation Models Pytorch(简称SMP)库的信息
📋 目录
⏳ 快速开始
1. 使用SMP创建你的第一个分割模型
分割模型只是一个PyTorch nn.Module,可以像这样轻松创建:
import segmentation_models_pytorch as smp
model = smp.Unet(
encoder_name="resnet34", # 选择编码器,例如 mobilenet_v2 或 efficientnet-b7
encoder_weights="imagenet", # 使用`imagenet`预训练权重进行编码器初始化
in_channels=1, # 模型输入通道数(1用于灰度图像,3用于RGB等)
classes=3, # 模型输出通道数(数据集中的类别数)
)
2. 配置数据预处理
所有编码器都有预训练权重。以与权重预训练期间相同的方式准备数据可能会给你带来更好的结果(更高的指标分数和更快的收敛)。如果你训练整个模型而不仅仅是解码器,这不是必需的。
from segmentation_models_pytorch.encoders import get_preprocessing_fn
preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
恭喜!你已完成!现在你可以使用你喜欢的框架来训练你的模型了!
💡 示例
- 使用Pytorch-Lightning训练宠物二分类分割模型的notebook和
- 在CamVid数据集上训练汽车分割模型在此。
- 使用Catalyst(PyTorch的高级框架)、TTAch(PyTorch的TTA库)和Albumentations(快速图像增强库)训练SMP模型 - 在此
- 使用Pytorch-Lightning框架训练SMP模型 - 在此(由@ternaus提供的服装二分类分割)。
📦 模型
架构
- Unet [论文] [文档]
- Unet++ [论文] [文档]
- MAnet [论文] [文档]
- Linknet [论文] [文档]
- FPN [论文] [文档]
- PSPNet [论文] [文档]
- PAN [论文] [文档]
- DeepLabV3 [论文] [文档]
- DeepLabV3+ [论文] [文档]
编码器
以下是SMP支持的编码器列表。选择适当的编码器系列并点击展开表格,选择特定的编码器及其预训练权重(encoder_name
和encoder_weights
参数)。
ResNet
编码器 | 权重 | 参数量, M |
---|---|---|
resnet18 | imagenet / ssl / swsl | 11M |
resnet34 | imagenet | 21M |
resnet50 | imagenet / ssl / swsl | 23M |
resnet101 | imagenet | 42M |
resnet152 | imagenet | 58M |
ResNeXt
编码器 | 权重 | 参数量, M |
---|---|---|
resnext50_32x4d | imagenet / ssl / swsl | 22M |
resnext101_32x4d | ssl / swsl | 42M |
resnext101_32x8d | imagenet / instagram / ssl / swsl | 86M |
resnext101_32x16d | instagram / ssl / swsl | 191M |
resnext101_32x32d | 466M | |
resnext101_32x48d | 826M |
ResNeSt
编码器 | 权重 | 参数量, M |
---|---|---|
timm-resnest14d | imagenet | 8M |
timm-resnest26d | imagenet | 15M |
timm-resnest50d | imagenet | 25M |
timm-resnest101e | imagenet | 46M |
timm-resnest200e | imagenet | 68M |
timm-resnest269e | imagenet | 108M |
timm-resnest50d_4s2x40d | imagenet | 28M |
timm-resnest50d_1s4x24d | imagenet | 23M |
Res2Ne(X)t
编码器 | 权重 | 参数量, M |
---|---|---|
timm-res2net50_26w_4s | imagenet | 23M |
timm-res2net101_26w_4s | imagenet | 43M |
timm-res2net50_26w_6s | imagenet | 35M |
timm-res2net50_26w_8s | imagenet | 46M |
timm-res2net50_48w_2s | imagenet | 23M |
timm-res2net50_14w_8s | imagenet | 23M |
timm-res2next50 | imagenet | 22M |
RegNet(x/y)
GERNet
编码器 | 权重 | 参数量(百万) |
---|---|---|
timm-gernet_s | imagenet | 6M |
timm-gernet_m | imagenet | 18M |
timm-gernet_l | imagenet | 28M |
SE-Net
编码器 | 权重 | 参数量(百万) |
---|---|---|
senet154 | imagenet | 113M |
se_resnet50 | imagenet | 26M |
se_resnet101 | imagenet | 47M |
se_resnet152 | imagenet | 64M |
se_resnext50_32x4d | imagenet | 25M |
se_resnext101_32x4d | imagenet | 46M |
SK-ResNe(X)t
编码器 | 权重 | 参数量(百万) |
---|---|---|
timm-skresnet18 | imagenet | 11M |
timm-skresnet34 | imagenet | 21M |
timm-skresnext50_32x4d | imagenet | 25M |
DenseNet
编码器 | 权重 | 参数量(百万) |
---|---|---|
densenet121 | imagenet | 6M |
densenet169 | imagenet | 12M |
densenet201 | imagenet | 18M |
densenet161 | imagenet | 26M |
Inception
编码器 | 权重 | 参数量(百万) |
---|---|---|
inceptionresnetv2 | imagenet / imagenet+background | 54M |
inceptionv4 | imagenet / imagenet+background | 41M |
xception | imagenet | 22M |
EfficientNet
编码器 | 权重 | 参数量(百万) |
---|---|---|
efficientnet-b0 | imagenet | 4M |
efficientnet-b1 | imagenet | 6M |
efficientnet-b2 | imagenet | 7M |
efficientnet-b3 | imagenet | 10M |
efficientnet-b4 | imagenet | 17M |
efficientnet-b5 | imagenet | 28M |
efficientnet-b6 | imagenet | 40M |
efficientnet-b7 | imagenet | 63M |
timm-efficientnet-b0 | imagenet / advprop / noisy-student | 4M |
timm-efficientnet-b1 | imagenet / advprop / noisy-student | 6M |
timm-efficientnet-b2 | imagenet / advprop / noisy-student | 7M |
timm-efficientnet-b3 | imagenet / advprop / noisy-student | 10M |
timm-efficientnet-b4 | imagenet / advprop / noisy-student | 17M |
timm-efficientnet-b5 | imagenet / advprop / noisy-student | 28M |
timm-efficientnet-b6 | imagenet / advprop / noisy-student | 40M |
timm-efficientnet-b7 | imagenet / advprop / noisy-student | 63M |
timm-efficientnet-b8 | imagenet / advprop | 84M |
timm-efficientnet-l2 | noisy-student | 474M |
timm-efficientnet-lite0 | imagenet | 4M |
timm-efficientnet-lite1 | imagenet | 5M |
timm-efficientnet-lite2 | imagenet | 6M |
timm-efficientnet-lite3 | imagenet | 8M |
timm-efficientnet-lite4 | imagenet | 13M |
MobileNet
DPN
编码器 | 权重 | 参数量(百万) |
---|---|---|
dpn68 | imagenet | 11M |
dpn68b | imagenet+5k | 11M |
dpn92 | imagenet+5k | 34M |
dpn98 | imagenet | 58M |
dpn107 | imagenet+5k | 84M |
dpn131 | imagenet | 76M |
VGG
编码器 | 权重 | 参数量(百万) |
---|---|---|
vgg11 | imagenet | 9M |
vgg11_bn | imagenet | 9M |
vgg13 | imagenet | 9M |
vgg13_bn | imagenet | 9M |
vgg16 | imagenet | 14M |
vgg16_bn | imagenet | 14M |
vgg19 | imagenet | 20M |
vgg19_bn | imagenet | 20M |
混合视觉Transformer
SegFormer在Imagenet上预训练的骨干网络!可以与包中的其他解码器一起使用,您可以将混合视觉Transformer与Unet、FPN等结合使用!
限制:
- 编码器不支持Linknet、Unet++
- 编码器仅支持FPN,且仅限编码器深度为5的情况
编码器 | 权重 | 参数量(百万) |
---|---|---|
mit_b0 | imagenet | 3M |
mit_b1 | imagenet | 13M |
mit_b2 | imagenet | 24M |
mit_b3 | imagenet | 44M |
mit_b4 | imagenet | 60M |
mit_b5 | imagenet | 81M |
MobileOne
苹果公司的"亚毫秒级"骨干网络在Imagenet上预训练!可以与所有解码器一起使用。
注意:在官方GitHub仓库中,s0变体有额外的num_conv_branches,导致比s1有更多的参数。
编码器 | 权重 | 参数量(百万) |
---|---|---|
mobileone_s0 | imagenet | 4.6M |
mobileone_s1 | imagenet | 4.0M |
mobileone_s2 | imagenet | 6.5M |
mobileone_s3 | imagenet | 8.8M |
mobileone_s4 | imagenet | 13.6M |
* ssl
、swsl
- 在ImageNet上进行半监督和弱监督学习(仓库)。
Timm编码器
Pytorch Image Models(简称timm)有许多预训练模型和接口,可以将这些模型用作smp中的编码器,但并非所有模型都受支持
- 并非所有transformer模型都实现了编码器所需的
features_only
功能 - 一些模型的步幅不适合
支持的编码器总数:549
🔁 模型API
model.encoder
- 预训练的骨干网络,用于提取不同空间分辨率的特征model.decoder
- 取决于模型架构(Unet
/Linknet
/PSPNet
/FPN
)model.segmentation_head
- 最后一个块,用于生成所需数量的掩码通道(还包括可选的上采样和激活)model.classification_head
- 可选块,在编码器顶部创建分类头model.forward(x)
- 顺序地将x
通过模型的编码器、解码器和分割头(如果指定,还包括分类头)
输入通道
输入通道参数允许您创建可处理任意数量通道张量的模型。如果您使用来自imagenet的预训练权重 - 第一个卷积的权重将被重用。对于1通道情况,它将是第一个卷积层权重的总和,否则通道将使用如下方式填充权重:new_weight[:, i] = pretrained_weight[:, i % 3]
,然后用new_weight * 3 / new_in_channels
进行缩放。
model = smp.FPN('resnet34', in_channels=1)
mask = model(torch.ones([1, 1, 64, 64]))
辅助分类输出
所有模型都支持aux_params
参数,默认设置为None
。如果aux_params = None
,则不创建分类辅助输出,否则模型不仅产生mask
,还产生形状为NC
的label
输出。分类头由GlobalPooling->Dropout(可选)->Linear->Activation(可选)层组成,可以通过aux_params
进行配置,如下所示:
aux_params=dict(
pooling='avg', # 'avg'或'max'之一
dropout=0.5, # dropout比率,默认为None
activation='sigmoid', # 激活函数,默认为None
classes=4, # 定义输出标签的数量
)
model = smp.Unet('resnet34', classes=4, aux_params=aux_params)
mask, label = model(x)
深度
深度参数指定编码器中的下采样操作数量,因此如果指定较小的depth
,可以使模型更轻量。
model = smp.Unet('resnet34', encoder_depth=4)
🛠 安装
PyPI 版本:
$ pip install segmentation-models-pytorch
从源代码安装最新版本:
$ pip install git+https://github.com/qubvel/segmentation_models.pytorch
🏆 使用本库获得的比赛胜利
Segmentation Models
包在图像分割比赛中被广泛使用。
这里你可以找到比赛名称、获胜者姓名以及他们解决方案的链接。
🤝 贡献
安装 SMP
make install_dev # 创建 .venv,以开发模式安装 SMP
运行测试和代码检查
make fixup # 使用 Ruff 进行格式化和代码检查
更新编码器表格
make table # 生成编码器表格并输出到标准输出
📝 引用
@misc{Iakubovskii:2019,
Author = {Pavel Iakubovskii},
Title = {Segmentation Models Pytorch},
Year = {2019},
Publisher = {GitHub},
Journal = {GitHub repository},
Howpublished = {\url{https://github.com/qubvel/segmentation_models.pytorch}}
}
🛡️ 许可证
项目在 MIT 许可证 下分发