MedSegDiff - Pytorch
在Pytorch中实现MedSegDiff - 百度推出的最先进医学分割方法,使用DDPM和在特征级别上的增强条件,并在傅里叶空间中进行特征过滤。
致谢
-
感谢StabilityAI的慷慨赞助,以及其他所有赞助者
安装
$ pip install med-seg-diff-pytorch
使用方法
import torch
from med_seg_diff_pytorch import Unet, MedSegDiff
model = Unet(
dim = 64,
image_size = 128,
mask_channels = 1, # 分割有1个通道
input_img_channels = 3, # 输入图像有3个通道
dim_mults = (1, 2, 4, 8)
)
diffusion = MedSegDiff(
model,
timesteps = 1000
).cuda()
segmented_imgs = torch.rand(8, 1, 128, 128) # 输入归一化为0到1
input_imgs = torch.rand(8, 3, 128, 128)
loss = diffusion(segmented_imgs, input_imgs)
loss.backward()
# 经过大量训练后
pred = diffusion.sample(input_imgs) # 传入未分割的图像
pred.shape # 预测的分割图像 - (8, 3, 128, 128)
训练
运行命令
accelerate launch driver.py --mask_channels=1 --input_img_channels=3 --image_size=64 --data_path='./data' --dim=64 --epochs=100 --batch_size=1 --scale_lr --gradient_accumulation_steps=4
如果你想添加自条件(使用目前为止的掩码进行条件设置),请添加 --self_condition
待办事项
- 一些基本的训练代码,Trainer接收针对医学图像格式定制的自定义数据集 - 感谢@isamu-isozaki
- 在中间添加任意深度的完整transformer,如simple diffusion中所做的那样
引用
@article{Wu2022MedSegDiffMI,
title = {MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model},
author = {Junde Wu and Huihui Fang and Yu Zhang and Yehui Yang and Yanwu Xu},
journal = {ArXiv},
year = {2022},
volume = {abs/2211.00611}
}
@inproceedings{Hoogeboom2023simpleDE,
title = {simple diffusion: End-to-end diffusion for high resolution images},
author = {Emiel Hoogeboom and Jonathan Heek and Tim Salimans},
year = {2023}
}