# Implementation, Testing and Comparison of Various Basis Functions of the KAN Model
The `base-kan` project focuses on the implementation, testing, and comparison of various basis functions of the KAN model. The KAN model is a neural network architecture distinct from the MLP, as detailed in the paper "KAN: Kolmogorov-Arnold Networks". This repository aims to provide a comprehensive set of tools and examples to demonstrate the effectiveness of different basis functions within the KAN model.
To set up the project and install the necessary dependencies, follow these steps:
- Clone the repository:

  ```bash
  git clone https://github.com/Ivans-11/base-kan.git
  cd base-kan
  ```
- Create a virtual environment:

  ```bash
  python -m venv venv
  source venv/bin/activate  # On Windows use `venv\Scripts\activate`
  ```
- Install the required packages:

  ```bash
  pip install -r requirements.txt
  ```
Start with the notebooks in the `notebook` folder:

- `b_kan_train.ipynb`, `f_kan_train.ipynb`, `g_kan_train.ipynb`, `j_kan_train.ipynb`, `r_kan_train.ipynb`, `t_kan_train.ipynb`, `w_kan_train.ipynb`, `be_kan_train.ipynb`: These notebooks are specialized for training the BSplineKAN, FourierKAN, GaussianKAN, JacobiKAN, RationalKAN, TaylorKAN, WaveletKAN and BernsteinKAN models respectively, showcasing their specific configurations and training procedures.
- `train&test_on_XX.ipynb`: These notebooks demonstrate how to train, test and compare the various KAN models and an MLP using PyTorch, including examples on the XX datasets.
- `series_predict_on_sunspots.ipynb`: This time-series prediction example clearly demonstrates the advantages of the KAN model over the MLP.
Model architecture (referenced from pykan):

- Kolmogorov-Arnold layer: each output is a sum of learnable univariate functions of the inputs, $x_{l+1,j} = \sum_{i} \phi_{l,j,i}(x_{l,i})$.
- Kolmogorov-Arnold network: a composition of such layers, $\mathrm{KAN}(x) = (\Phi_{L-1} \circ \cdots \circ \Phi_0)(x)$.
The KAN model's underlying architecture code is located in the `kans` folder. The basis functions of the available models are as follows:
### BSplineKAN

- Base function: $\phi(x) = \sum_{i=1}^{n} c_i B_{i,k}(x)$
- Learnable parameters: the control-point coefficients $c_1, \ldots, c_n$
- Configurable parameters: grid count $n$, order $k$ of the B-spline function; the control points are determined by the grid count and grid range.
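As a concreteness check, the B-spline base function above can be sketched in plain Python using the Cox-de Boor recursion. This is an illustrative sketch, not the repo's `BSplineKAN` code; the function names and the uniform knot vector are assumptions.

```python
def bspline_basis(x, i, k, t):
    """Cox-de Boor recursion: the i-th B-spline basis function of order k
    (degree k-1) on knot vector t, evaluated at x."""
    if k == 1:
        return 1.0 if t[i] <= x < t[i + 1] else 0.0
    left = 0.0
    if t[i + k - 1] != t[i]:
        left = (x - t[i]) / (t[i + k - 1] - t[i]) * bspline_basis(x, i, k - 1, t)
    right = 0.0
    if t[i + k] != t[i + 1]:
        right = (t[i + k] - x) / (t[i + k] - t[i + 1]) * bspline_basis(x, i + 1, k - 1, t)
    return left + right

def phi(x, c, k, t):
    # phi(x) = sum_i c_i * B_{i,k}(x)
    return sum(c[i] * bspline_basis(x, i, k, t) for i in range(len(c)))

# Uniform knot grid: n = 5 basis functions of order k = 3 need n + k = 8 knots.
knots = [float(v) for v in range(8)]
```

With all coefficients $c_i = 1$, $\phi$ reduces to the partition of unity inside the valid knot span, which is a handy sanity check on any implementation.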
### FourierKAN

- Base function: $\phi(x) = a_0 + \sum_{k=1}^{n} \left( a_k \cos(kx) + b_k \sin(kx) \right)$
- Learnable parameters: the Fourier coefficients $a_0, a_1, b_1, \ldots, a_n, b_n$
- Configurable parameter: frequency limit $n$
- The frequency limit can also be increased dynamically during training to improve accuracy.
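A minimal sketch of the Fourier base function, assuming coefficient lists `a` and `b` of equal length $n$ (names are illustrative, not the repo's API):

```python
import math

def phi(x, a0, a, b):
    # phi(x) = a0 + sum_{k=1}^{n} (a_k cos(kx) + b_k sin(kx))
    return a0 + sum(ak * math.cos(k * x) + bk * math.sin(k * x)
                    for k, (ak, bk) in enumerate(zip(a, b), start=1))
```

Dynamically raising the frequency limit then simply means appending new `a_k`, `b_k` entries (typically initialized near zero) during training.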
### GaussianKAN

- Base function: $\phi(x) = \sum_{i=1}^{n} a_i \exp\left(-\frac{(x - \mu_i)^2}{2 \sigma_i^2}\right)$
- Learnable parameters: the coefficients $a_1, \ldots, a_n$
- Configurable parameters: grid count $n$; the parameters $\mu_i, \sigma_i$ are determined by the grid count and grid range.
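The Gaussian radial-basis form above can be sketched as follows; treating $\mu_i, \sigma_i$ as plain arguments here is a simplification (in the model they are fixed by the grid):

```python
import math

def phi(x, a, mu, sigma):
    # phi(x) = sum_i a_i * exp(-(x - mu_i)^2 / (2 sigma_i^2))
    return sum(ai * math.exp(-((x - mi) ** 2) / (2 * si ** 2))
               for ai, mi, si in zip(a, mu, sigma))
```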
### JacobiKAN

- Base function: $\phi(x) = \sum_{k=0}^{n} c_k P_k^{(\alpha, \beta)}(x)$
- Learnable parameters: the coefficients of the Jacobi polynomials $c_0, \ldots, c_n$
- Configurable parameters: maximum order $n$; Jacobi polynomial parameters $\alpha, \beta$
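The Jacobi polynomials $P_k^{(\alpha,\beta)}$ satisfy a standard three-term recurrence, so the base function can be sketched without any special-function library (an illustrative sketch, not the repo's implementation):

```python
def jacobi(n, alpha, beta, x):
    """Evaluate P_n^{(alpha,beta)}(x) via the three-term recurrence."""
    if n == 0:
        return 1.0
    p_prev, p = 1.0, (alpha + 1) + (alpha + beta + 2) * (x - 1) / 2
    for m in range(2, n + 1):
        a1 = 2 * m * (m + alpha + beta) * (2 * m + alpha + beta - 2)
        a2 = (2 * m + alpha + beta - 1) * (alpha ** 2 - beta ** 2)
        a3 = (2 * m + alpha + beta - 1) * (2 * m + alpha + beta) * (2 * m + alpha + beta - 2)
        a4 = 2 * (m + alpha - 1) * (m + beta - 1) * (2 * m + alpha + beta)
        p_prev, p = p, ((a2 + a3 * x) * p - a4 * p_prev) / a1
    return p

def phi(x, c, alpha, beta):
    # phi(x) = sum_{k=0}^{n} c_k * P_k^{(alpha,beta)}(x)
    return sum(ck * jacobi(k, alpha, beta, x) for k, ck in enumerate(c))
```

For $\alpha = \beta = 0$ the recurrence reduces to the Legendre polynomials, which gives an easy correctness check.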
### RationalKAN

- Base function: $\phi(x) = \frac{\sum_{i=0}^{m} a_i x^i}{1 + \lvert\sum_{j=1}^{n} b_j x^j\rvert}$
- Learnable parameters: the polynomial coefficients $a_i, b_j$
- Configurable parameters: numerator order $m$; denominator order $n$
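A sketch of the rational base function; note that the $1 + \lvert\cdot\rvert$ denominator keeps the function pole-free for any learned $b_j$ (argument names are illustrative):

```python
def phi(x, a, b):
    # phi(x) = (sum_{i=0}^{m} a_i x^i) / (1 + |sum_{j=1}^{n} b_j x^j|)
    num = sum(ai * x ** i for i, ai in enumerate(a))
    den = 1 + abs(sum(bj * x ** (j + 1) for j, bj in enumerate(b)))
    return num / den
```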
### TaylorKAN

- Base function: $\phi(x) = \sum_{k=0}^{n} c_k x^k$
- Learnable parameters: the polynomial coefficients $c_0, \ldots, c_n$
- Configurable parameter: polynomial order $n$
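The Taylor base function is a plain polynomial; a sketch using Horner's rule for numerically stable evaluation (names are illustrative):

```python
def phi(x, c):
    # phi(x) = sum_{k=0}^{n} c_k x^k, evaluated by Horner's rule
    result = 0.0
    for ck in reversed(c):
        result = result * x + ck
    return result
```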
### WaveletKAN

- Base function: $\phi(x) = \sum_{i=1}^{n} a_i \psi\left(\frac{x - b_i}{s_i}\right)$
- Learnable parameters: magnitude, translation and scale parameters $a_i, b_i, s_i$
- Configurable parameters: wave number $n$; type of $\psi(\cdot)$, one of `'mexican_hat'`, `'morlet'`, `'dog'`
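A sketch of the wavelet base function with the `'mexican_hat'` mother wavelet; the unnormalized Ricker form used here is an assumption (the repo may include a normalization constant):

```python
import math

def mexican_hat(t):
    # Ricker ("Mexican hat") mother wavelet, unnormalized
    return (1 - t * t) * math.exp(-t * t / 2)

def phi(x, a, b, s):
    # phi(x) = sum_i a_i * psi((x - b_i) / s_i)
    return sum(ai * mexican_hat((x - bi) / si) for ai, bi, si in zip(a, b, s))
```

Swapping in `'morlet'` or `'dog'` only changes the mother wavelet $\psi$; the magnitude/translation/scale structure is identical.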
### BernsteinKAN

- Base function: $\phi(x) = \sum_{k=0}^{n} c_k (x-a)^k (b-x)^{n-k}$
- Learnable parameters: the coefficients of the Bernstein polynomials $c_0, \ldots, c_n$
- Configurable parameters: polynomial order $n$; interpolation range $a, b$
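A sketch of the Bernstein-style base function exactly as written above; the binomial factors of the classical Bernstein basis are absorbed into the learnable $c_k$ (argument names are illustrative):

```python
def phi(x, c, a, b):
    # phi(x) = sum_{k=0}^{n} c_k (x - a)^k (b - x)^(n - k)
    n = len(c) - 1
    return sum(ck * (x - a) ** k * (b - x) ** (n - k) for k, ck in enumerate(c))
```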
- Based on our test results across the datasets, the relative performance of the basis functions varies with the type of dataset. For example, the TaylorKAN and RationalKAN models significantly outperform the other models on the wine dataset but perform poorly on the California Housing dataset; the WaveletKAN model has a clear advantage on the Iris dataset but performs poorly on the digits, wine and California Housing datasets. By contrast, JacobiKAN, FourierKAN and GaussianKAN perform consistently well across the various datasets, ranking moderately good or better throughout.
- Based on these observations, we propose a new model architecture: HB-KAN (Hybrid KAN). Multiple basis functions are evaluated separately, and their results are combined in a weighted sum to produce the final output. The weights are learnable parameters that are adjusted automatically during training, so that better-performing basis functions receive higher weights. We design two HB-KAN architectures at different levels: HybridKAN by Layer and HybridKAN by Net.
### HybridKAN by Layer

Weighting at the layer level ($K$ is the number of basis-function types). You can create it from a list of basis functions, as in the example.

### HybridKAN by Net

Weighting at the net level ($K$ is the number of basis-function types). You can create it from a list of basis functions, as in the example, or from a list of pre-trained KAN models, as in the example.
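The weighted combination behind HB-KAN can be sketched as follows. The softmax normalization of the learnable weights is an assumption for this sketch (it keeps the weights positive and summing to 1); the repo's HybridKAN may normalize differently.

```python
import math

def hybrid_output(basis_outputs, logits):
    """Weighted sum of K basis-function outputs.
    logits: K learnable parameters; softmax turns them into mixture weights,
    so training can shift weight toward the better-performing basis functions."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    weights = [e / total for e in exps]
    return sum(w * out for w, out in zip(weights, basis_outputs))
```

"By Layer" applies this combination inside every layer; "by Net" runs whole single-basis networks in parallel and combines their final outputs once.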