Kolmogorov-Arnold网络的高效实现
本仓库包含了Kolmogorov-Arnold网络(KAN)的高效实现。KAN的原始实现可在此处获取。
原始实现的性能问题主要是因为它需要展开所有中间变量以执行不同的激活函数。对于一个具有in_features
输入和out_features
输出的层,原始实现需要将输入展开为形状为(batch_size, out_features, in_features)
的张量来执行激活函数。然而,所有激活函数都是固定的B样条基函数集的线性组合;鉴于此,我们可以将计算重新表述为用不同的基函数激活输入,然后线性组合它们。这种重新表述可以显著减少内存消耗,使计算成为简单的矩阵乘法,并自然地适用于前向和反向传播。
问题在于稀疏化,这被认为对KAN的可解释性至关重要。作者提出了一种基于输入样本的L1正则化,这需要对(batch_size, out_features, in_features)
张量进行非线性操作,因此与重新表述不兼容。我改为使用权重的L1正则化来替代,这在神经网络中更为常见,并且与重新表述兼容。作者的实现实际上也包含了这种正则化,与论文中描述的方法并存,所以我认为这可能会有帮助。需要更多实验来验证这一点;但至少原始方法在追求效率的情况下是不可行的。
另一个区别是,除了可学习的激活函数(B样条)外,原始实现还包括每个激活函数的可学习缩放。我提供了一个默认为True
的enable_standalone_scale_spline
选项来包含这个特性;禁用它会使模型更高效,但可能会影响结果。这需要更多的实验。
2024年5月4日更新:@xiaol提示base_weight
参数的常数初始化可能在MNIST上存在问题。目前我已将base_weight
和spline_scaler
矩阵的初始化方式改为kaiming_uniform_
,遵循nn.Linear
的初始化方式。这似乎在MNIST上效果好得多(从约20%提高到约97%),但我不确定这是否普遍适用。