项目介绍:pytorch-3dunet
pytorch-3dunet 是一个基于 PyTorch 的项目,专注于 3D U-Net 和其变体的实现。该项目为图像分割和回归任务提供了强大的工具支持,特别适用于医疗图像处理和生物学研究领域。
项目背景
3D U-Net 是一种扩展的卷积神经网络,特别适合处理三维图像数据。它最初设计为从稀疏的标注中学习密集的体积分割。pytorch-3dunet 提供了不同的 3D U-Net 实现,包括标准 3D U-Net、残差 3D U-Net,以及加入了“Squeeze and Excitation”模块的变体。
功能特色
-
支持的模型类型
UNet3D
:传统的标准 3D U-Net。ResidualUNet3D
:基于残差网络的 3D U-Net。ResidualUNetSE3D
:增强了“Squeeze and Excitation”功能的 3D U-Net。
-
训练任务
- 支持语义分割(包括二分类和多分类)。
- 支持回归问题,如去噪和去卷积学习。
-
2D U-Net 支持
- 2D U-Net 也是支持的,优化了性能以适应二维卷积层。
数据格式
输入数据应存储在 HDF5 文件中,通常包括 raw
和 label
两个数据集,用于存储输入数据和相应的真值标注。如果使用像素级交叉熵损失,还可以选择性提供 weight
数据集。
数据格式需根据是否为2D或3D,以及通道数来进行设置,确保兼容性。
安装
pytorch-3dunet 可以通过以下方式进行安装:
-
通过 conda/mamba 安装
- 快速安装并调用
train3dunet
和predict3dunet
命令。
- 快速安装并调用
-
从源代码安装
- 通过运行
python setup.py install
命令即可安装。
- 通过运行
使用指南
训练模型
通过指定配置文件,可以灵活地训练模型。用户需要提供训练数据和验证数据的路径。训练过程中可以通过 Tensorboard 监控进度。
预测
安装后,用户可以使用现有模型对新数据进行预测,只需提供模型路径和测试数据路径。可以通过配置来改善大数据集的预测效率。
多 GPU 支持
项目默认利用所有可用的 GPU 进行数据并行计算,从而优化计算效率。用户可以通过环境变量限制使用的 GPU 数量。
支持的损失函数
项目提供多种损失函数来支持不同的任务需求,包括:
- 对于语义分割:如二值交叉熵、Dice 损失等。
- 对于回归:如均方误差损失、L1损失等。
支持的评估指标
提供了详细的评价指标配置,用于语义分割(如 MeanIoU 和 Dice 系数)和回归(如 PSNR 和 MSE)。
示例项目
项目中提供了不同场景下的实际应用示例,包括对不同植物器官的边界和核的预测任务,并公开了预训练模型供使用和微调。
贡献与引用
如果用户想为本项目贡献,可以通过提供 Pull Request 的方式参与。同时,如果在研究中使用了此代码,请根据提供的格式进行引用。
pytorch-3dunet 项目通过其丰富的功能和易用性,成为了图像分割研究的重要工具,尤其在高分辨率医学体积图像的深度学习语义分割领域大放异彩。