TorchConv KAN:卷积Kolmogorov-Arnold网络集合
本项目介绍并展示了使用PyTorch和CUDA加速来训练、验证和量化卷积KAN模型。torch-conv-kan
在MNIST、CIFAR、TinyImagenet和Imagenet1k数据集上评估性能。
项目状态:开发中
更新
-
✅ [2024/05/13] 卷积KALN层已可用
-
✅ [2024/05/14] 卷积KAN和快速KAN层已可用
-
✅ [2024/05/15] 卷积ChebyKAN现已可用。添加了MNIST、CIFAR10和CIFAR100基准测试。
-
✅ [2024/05/19] 发布了类ResNet、类U-net和基于MoE的模型(别问为什么=)),以及基于accelerate的训练代码。
-
✅ [2024/05/21] 发布了类VGG和类DenseNet模型!添加了Gram KAN卷积层。
-
✅ [2024/05/23] 添加了WavKAN卷积层。修复了
trainer.py
中输出钩子的bug。 -
✅ [2024/05/25] 添加了类U2-net模型。修复了
trainer.py
中的内存泄漏问题。 -
✅ [2024/05/27] 更新了WavKAN的实现 - 现在速度更快。添加了VGG-WavKAN。
-
✅ [2024/05/31] 修复了KACN Conv不稳定问题,添加了Lion优化器,更新了基线模型和基准测试,以及:fire::fire::fire:发布了在Imagenet1k上预训练的权重:fire::fire::fire:,还有Imagenet1k训练脚本。
-
✅ [2024/06/03] JacobiKAN卷积现已可用。
-
✅ [2024/06/05] BernsteinKAN和BernsteinKAN卷积现已可用。
-
✅ [2024/06/15] 引入瓶颈KAN卷积(目前使用Gram多项式作为基函数)。添加了LBFGS优化器支持(尚未充分测试,如果遇到任何问题,请提出问题)。发布了CIFAR 100上的正则化基准测试。发布了使用Ray Tune进行超参数调优。
-
✅ [2024/06/18] ReLU KAN卷积现已可用。
-
✅ [2024/06/20] :fire::fire::fire:发布了在Imagenet1k上的新预训练检查点:fire::fire::fire: VGG11风格的瓶颈Gram卷积。该模型在Imagenet1k验证集上达到了68.5%的Top1准确率,仅使用7.25M参数。
-
✅ [2024/07/02] :fire::fire::fire: 我们发布了我们的论文 :fire::fire::fire: Kolmogorov-Arnold卷积:设计原则和实证研究
-
✅ [2024/07/09] 发布了KAGN模型的PEFT代码,以及新的类RDNet模型、改进的Kolmogorov-Arnold-Gram卷积实现和医学图像分割脚本。
TODO列表和下一步计划
- 目前正在Imagenet1k上训练类VGG19模型
- 目前正在Imagenet1k上训练类Resnet50模型
- 正在进行其他基准测试的微调实验,以及PEFT方法的探索
- 我也在研究剪枝和可视化方法
目录:
介绍卷积KAN层
Kolmogorov-Arnold网络基于Kolmogorov-Arnold表示定理:
根据这个公式,KAN: Kolmogorov-Arnold网络的作者推导出了新的架构:边上的可学习激活和节点上的求和。相反,MLP在节点上执行固定的非线性,在边上执行可学习的线性投影。
在卷积层中,滤波器或核在2D输入数据上"滑动",执行元素级乘法。结果被求和为单个输出像素。核对其滑过的每个位置执行相同的操作,将2D(1D或3D)特征矩阵转换为不同的矩阵。虽然1D和3D卷积共享相同的概念,但它们具有不同的滤波器、输入数据和输出数据维度。然而,为简单起见,我们将重点关注2D。
通常,在卷积层之后,会应用归一化层(如BatchNorm、InstanceNorm等)和非线性激活(ReLU、LeakyReLU、SiLU等)。
更正式地说:假设我们有一个大小为N x N的输入图像y。为简单起见,我们省略通道轴,它会增加另一个求和符号。首先,我们需要用大小为m x m的核W进行卷积:
然后,我们应用批量归一化和非线性,例如ReLU:
Kolmogorov-Arnold卷积的工作方式略有不同:核由一组单变量非线性函数组成。这个核在2D输入数据上"滑动",对核的函数进行元素级应用。然后将结果求和为单个输出像素。更正式地说:假设我们有一个大小为N x N的输入图像y(再次)。为简单起见,我们省略通道轴,它会增加另一个求和符号。因此,基于KAN的卷积定义为:
每个phi是一个单变量非线性可学习函数。在原始论文中,作者建议使用这种形式的函数:
作者建议选择SiLU作为b(x)激活:
总之,"传统"卷积是一个权重矩阵,而Kolmogorov-Arnold卷积是一组函数。这是主要区别。这里的关键问题是 - 我们应该如何构造这些单变量非线性函数?答案与KAN相同:B样条、多项式、RBF、小波等。
在这个仓库中,展示了以下层的实现:
-
KANConv1DLayer
、KANConv2DLayer
、KANConv3DLayer
类代表基于Kolmogorov Arnold网络的卷积层,在[1]中引入。基线模型在models/baselines/conv_kan_baseline.py
中实现。 -
KALNConv1DLayer
、KALNConv2DLayer
、KALNConv3DLayer
类代表基于Kolmogorov Arnold勒让德网络的卷积层,在[2]中引入。基线模型在models/baselines/conv_kaln_baseline.py
中实现。 -
FastKANConv1DLayer
、FastKANConv2DLayer
、FastKANConv3DLayer
类代表基于快速Kolmogorov Arnold网络的卷积层,在[3]中引入。基线模型在models/baselines/fast_conv_kan_baseline.py
中实现。 -
KACNConv1DLayer
、KACNConv1DLayer
、KACNConv1DLayer
类代表基于Kolmogorov Arnold网络的卷积层,使用切比雪夫多项式代替B样条,在[4]中引入。基线模型在models/baselines/conv_kacn_baseline.py
中实现。 -
KAGNConv1DLayer
、KAGNConv1DLayer
、KAGNConv1DLayer
类代表基于Kolmogorov Arnold网络的卷积层,使用Gram多项式代替B样条,在[5]中引入。基线模型在models/baselines/conv_kagn_baseline.py
中实现。 -
WavKANConv1DLayer
、WavKANConv1DLayer
、WavKANConv1DLayer
类代表基于小波Kolmogorov Arnold网络的卷积层,在[6]中引入。基线模型在models/baselines/conv_wavkan_baseline.py
中实现。 -
KAJNConv1DLayer
、KAJNConv2DLayer
、KAJNConv3DLayer
类代表基于Jacobi Kolmogorov Arnold网络的卷积层,在[7]中引入并做了minor修改。 -
我们引入了
KABNConv1DLayer
、KABNConv2DLayer
、KABNConv3DLayer
类,代表基于Bernstein Kolmogorov Arnold网络的卷积层。 -
KABNConv1DLayer
、KABNConv2DLayer
、KABNConv3DLayer
类代表基于ReLU KAN的卷积层,在[8]中引入。
引入瓶颈卷积KAN层
如我们之前讨论的,phi函数由两个块组成:残差激活函数(下图左侧)和可学习非线性(样条、多项式、小波等;下图右侧)。
主要问题在右侧部分:输入数据的通道数越多,模型中引入的可学习参数就越多。因此,像ResNet中的瓶颈层一样,我们可以使用一个简单的技巧:对输入数据应用1x1的压缩卷积,在这个空间中执行样条,然后应用1x1的解压缩卷积。
假设我们有512个通道的输入x,想要执行512个过滤器的ConvKAN。首先,1x1卷积将x投影到128个通道的y。然后我们对y应用学习的非线性,最后1x1卷积将y转换回512个通道的t。现在我们可以将t与残差激活相加。
本仓库中实现了以下瓶颈层:
-
BottleNeckKAGNConv1DLayer
、BottleNeckKAGNConv2DLayer
、BottleNeckKAGNConv3DLayer
类代表基于Kolmogorov Arnold网络的瓶颈卷积层,使用Gram多项式代替B样条。 -
BottleNeckKAGNConv1DLayer
、BottleNeckKAGNConv2DLayer
、BottleNeckKAGNConv3DLayer
类代表基于Kolmogorov Arnold网络的瓶颈卷积层,使用Gram多项式代替B样条。
模型库
ResKANets
我们引入了ResKANets - 一种ResNet类似的模型,用KAN卷积代替常规卷积。主类ResKANet
可以在models/densekanet.py
中找到。我们的实现支持具有KAN、Fast KAN、KALN、KAGN和KACN卷积层的块。
在CIFAR10上训练75个epoch后,具有Kolmogorov Arnold Legendre卷积的ResKANet 18达到了84.17%的准确率和0.985的AUC(OVO)。
在Tiny Imagenet上训练75个epoch后,具有Kolmogorov Arnold Legendre卷积的ResKANet 18达到了28.62%的准确率,55.49%的top-5准确率,以及0.932的AUC(OVO)。
请注意,这些是初步结果,目前正在进行更多实验。
DenseKANets
我们引入了DenseKANets - 一种DenseNet类似的模型,用KAN卷积代替常规卷积。主类DenseKANet
可以在models/reskanet.py
中找到。我们的实现支持具有KAN、Fast KAN、KALN、KAGN和KACN卷积层的块。
在Tiny Imagenet上训练250个epoch后,具有Kolmogorov Arnold Gram卷积的DenseNet 121达到了40.61%的准确率,65.08%的top-5准确率,以及0.957的AUC(OVO)。
请注意,这些是初步结果,目前正在进行更多实验。
VGGKAN
我们引入了VGGKANs - 一种VGG类似的模型,用KAN卷积代替常规卷积。主类VGG
可以在models/vggkan.py
中找到。该模型支持所有类型的KANs卷积层。
在Imagenet1k上预训练的检查点:
模型 | 准确率, top1 | 准确率, top5 | AUC (ovo) | AUC (ovr) |
---|---|---|---|---|
VGG KAGN 11v2 | 59.1 | 82.29 | 99.43 | 99.43 |
VGG KAGN 11v4 | 61.17 | 83.26 | 99.42 | 99.43 |
VGG KAGN BN 11v4 | 68.5 | 88.46 | 99.61 | 99.61 |
更多检查点即将推出,敬请期待。我可用的计算资源相当有限,所以训练和评估所有模型需要一些时间。
UKANet和U2KANet
我们引入了UKANets和U2KANets - 一种U-net类似的模型,用KAN卷积代替常规卷积,基于resnet块,以及用KAN卷积代替常规卷积的U2-net。主类UKANet
可以在models/ukanet.py
中找到。我们的实现支持具有KAN、Fast KAN、KALN、KAGC和KACN卷积层的基本和瓶颈块。
性能指标
MNIST和CIFAR10/100上的基线模型 总结:8层SimpleKAGNConv在MNIST上达到99.68的准确率,在CIFAR 10上达到84.32,在CIFAR100上达到59.27。除了CIFAR10外,它在所有数据集上都是最好的模型:8层SimpleWavKANConv在CIFAR10上达到85.37的准确率。
正则化、缩放和超参数优化研究 总结:使用最优参数集,8层SimpleKAGNConv在CIFAR100上达到74.87%的准确率。
讨论
首先需要指出的是,获得的结果是初步的。模型架构还没有被彻底探索,只代表了许多可能设计变体中的两个。
尽管如此,实验表明Kolmogorov-Arnold卷积网络在MNIST数据集上的性能优于经典卷积架构,但在CIFAR-10和CIFAR-100上的质量显著不如经典卷积。基于ChebyKAN的卷积在训练过程中遇到稳定性问题,需要进一步调查。
作为下一步,我计划为KAN卷积寻找一种合适的架构,以在CIFAR-10/100上达到可接受的质量,并尝试将这些模型扩展到更复杂的数据集。
先决条件
确保您的系统上安装了以下内容:
- Python(3.9或更高版本)
- CUDA工具包(与您安装的PyTorch的CUDA版本相对应)
- cuDNN(与您安装的CUDA工具包兼容)
使用方法
以下是基于KAN卷积的简单模型示例:
import torch
import torch.nn as nn
from kan_convs import KANConv2DLayer
class SimpleConvKAN(nn.Module):
def __init__(
self,
layer_sizes,
num_classes: int = 10,
input_channels: int = 1,
spline_order: int = 3,
groups: int = 1):
super(SimpleConvKAN, self).__init__()
self.layers = nn.Sequential(
KANConv2DLayer(input_channels, layer_sizes[0], spline_order, kernel_size=3, groups=1, padding=1, stride=1,
dilation=1),
KANConv2DLayer(layer_sizes[0], layer_sizes[1], spline_order, kernel_size=3, groups=groups, padding=1,
stride=2, dilation=1),
KANConv2DLayer(layer_sizes[1], layer_sizes[2], spline_order, kernel_size=3, groups=groups, padding=1,
stride=2, dilation=1),
KANConv2DLayer(layer_sizes[2], layer_sizes[3], spline_order, kernel_size=3, groups=groups, padding=1,
stride=1, dilation=1),
nn.AdaptiveAvgPool2d((1, 1))
)
self.output = nn.Linear(layer_sizes[3], num_classes)
self.drop = nn.Dropout(p=0.25)
def forward(self, x):
x = self.layers(x)
x = torch.flatten(x, 1)
x = self.drop(x)
x = self.output(x)
return x
要在MNIST、CIFAR-10和CIFAR-100数据集上运行基线模型的训练和测试,执行以下代码行:
python mnist_conv.py
此脚本将在MNIST、CIFAR10或CIFAR100上训练基线模型,验证它们,量化并记录性能指标。
基于Accelerate的训练
我们引入了使用Accelerate、Hydra配置和Wandb日志记录的训练代码。
1. 克隆仓库
克隆 torch-conv-kan
仓库并设置项目环境:
git clone https://github.com/IvanDrokin/torch-conv-kan.git
cd torch-conv-kan
pip install -r requirements.txt
2. 配置 Weights & Biases (wandb)
要使用 wandb 监控实验和模型性能:
- 设置 wandb 账户:
- 在 Weights & Biases 注册或登录。
- 在账户设置中找到你的 API 密钥。
- 在你的项目中初始化 wandb:
在运行训练脚本之前,初始化 wandb:
wandb login
根据提示输入你的 API 密钥,将脚本执行与你的 wandb 账户关联。
- 将
configs/cifar10-reskanet.yaml
或configs/tiny-imagenet-reskanet.yaml
中的实体名称调整为你的用户名或团队名称
运行
更新配置中的任何参数并运行
accelerate launch cifar.py
此脚本在 CIFAR10 数据集上训练模型、验证模型,并使用 wandb 记录性能指标。
accelerate launch tiny_imagenet.py
此脚本在 Tiny Imagenet 数据集上训练模型、验证模型,并使用 wandb 记录性能指标。
使用自己的数据集或模型
如果你想使用自己的数据集,请按照以下步骤操作:
- 复制
tiny_imagenet.py
并修改get_data()
方法。如果基本的分类数据集实现不适合你的数据,请升级它或编写自己的实现。 - 如有必要,将
model = reskalnet_18x64p(...)
替换为你自己的模型。 - 在
config
文件夹中创建配置 yaml 文件,遵循提供的模板。 - 运行
accelerate launch your_script.py
引用本项目
如果你在研究中使用了本项目或想引用基准结果,请使用以下 BibTeX 条目。
@misc{drokin2024kolmogorovarnoldconvolutionsdesignprinciples,
title={Kolmogorov-Arnold Convolutions: Design Principles and Empirical Studies},
author={Ivan Drokin},
year={2024},
eprint={2407.01092},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2407.01092},
}
贡献
欢迎贡献。如有必要,请提出问题。
致谢
本仓库基于 TorchKAN、FastKAN、ChebyKAN、GRAMKAN、WavKAN、JacobiKAN 和 ReLU KAN。我们在此感谢他们的开放研究和探索。
参考文献
- [1] Ziming Liu 等,"KAN: Kolmogorov-Arnold Networks",2024,arXiv。https://arxiv.org/abs/2404.19756
- [2] https://github.com/1ssb/torchkan
- [3] https://github.com/ZiyaoLi/fast-kan
- [4] https://github.com/SynodicMonth/ChebyKAN
- [5] https://github.com/Khochawongwat/GRAMKAN
- [6] https://github.com/zavareh1/Wav-KAN
- [7] https://github.com/SpaceLearner/JacobiKAN
- [8] https://github.com/quiqi/relu_kan
- [9] https://github.com/KindXiaoming/pykan
- [10] https://github.com/Blealtan/efficient-kan