🎉 This is a PyTorch/GPU implementation of the paper Kolmogorov–Arnold Transformer (KAT), which replaces the MLP layers in vision transformers with Kolmogorov–Arnold Network (KAN) layers.
Kolmogorov–Arnold Transformer
Xingyi Yang, Xinchao Wang
National University of Singapore
The KAT model integrates KANs into transformers for large-scale training scenarios such as ImageNet, showing significant performance improvements.
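To make the idea concrete, here is a minimal, self-contained sketch of a transformer block whose MLP is replaced by a rational-function KAN layer. This is illustrative only, not the repository's implementation: the actual GR-KAN layer uses group-wise shared rational coefficients and the fused CUDA kernels from rational_kat_cu, whereas this toy version shares one set of coefficients across all channels.

```python
import torch
import torch.nn as nn

class RationalKAN(nn.Module):
    """Toy stand-in for a KAN layer: a learnable "safe" rational function
    phi(x) = P(x) / (1 + |Q(x)|) applied element-wise, followed by a linear
    projection. The repo implements this with grouped coefficients and
    fused CUDA kernels."""
    def __init__(self, dim_in, dim_out, p_order=5, q_order=4):
        super().__init__()
        self.p = nn.Parameter(torch.randn(p_order + 1) * 0.1)  # numerator coeffs
        self.q = nn.Parameter(torch.randn(q_order) * 0.1)      # denominator coeffs
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        num = sum(c * x ** i for i, c in enumerate(self.p))
        den = 1 + torch.abs(sum(c * x ** (i + 1) for i, c in enumerate(self.q)))
        return self.proj(num / den)

class KATBlock(nn.Module):
    """Transformer block with the usual two-layer MLP swapped for KAN layers."""
    def __init__(self, dim=192, heads=3):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.kan = nn.Sequential(RationalKAN(dim, dim * 4),
                                 RationalKAN(dim * 4, dim))

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.kan(self.norm2(x))  # KAN replaces the MLP sub-layer
        return x

x = torch.randn(2, 197, 192)   # (batch, tokens, dim), ViT-Tiny sizes
print(KATBlock()(x).shape)     # torch.Size([2, 197, 192])
```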
```bash
# install timm (pulls in torch) and the CUDA rational function kernels
pip install timm==1.0.3
git clone https://github.com/Adamdad/rational_kat_cu.git
cd rational_kat_cu
pip install -e .
```
Please refer to https://github.com/Adamdad/rational_kat_cu.git for instructions on installing the CUDA rational function kernels.
📦 Data preparation: prepare ImageNet with the following folder structure; you can extract ImageNet with this script.
```
imagenet/
├── train/
│   ├── n01440764/
│   │   ├── n01440764_10026.JPEG
│   │   ├── n01440764_10027.JPEG
│   │   ├── ......
│   ├── ......
├── val/
│   ├── n01440764/
│   │   ├── ILSVRC2012_val_00000293.JPEG
│   │   ├── ILSVRC2012_val_00002138.JPEG
│   │   ├── ......
│   ├── ......
```
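This is the standard class-per-folder layout expected by torchvision's ImageFolder (and hence by timm's dataset loaders). A quick sanity check, with the dataset path taken from the training script below:

```python
# Verify the ImageNet layout (adjust the path to your setup).
from torchvision.datasets import ImageFolder

train_set = ImageFolder("/local_home/dataset/imagenet/train")
val_set = ImageFolder("/local_home/dataset/imagenet/val")
print(len(train_set.classes), "classes")  # expect 1000 for ImageNet-1k
print(len(train_set), "train /", len(val_set), "val images")
```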
Refer to example.py for a detailed use case demonstrating how to use KAT with timm to classify an image.
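If you just want the gist without opening the file, a minimal sketch of the same flow is below. It assumes the KAT architectures are registered with timm (importing the repo's model definitions does this); the checkpoint and image paths are placeholders.

```python
# Minimal sketch of classifying an image with a KAT model via timm.
# Assumes the repo's model file has been imported so that
# "kat_tiny_patch16_224" is registered with timm.
import torch
import timm
from PIL import Image

model = timm.create_model("kat_tiny_patch16_224")
state = torch.load("kat_tiny_patch16_224.pth", map_location="cpu")
model.load_state_dict(state.get("state_dict", state))
model.eval()

# Build the eval transform (bicubic resize, 0.875 center crop,
# ImageNet mean/std) from the model's own data config.
config = timm.data.resolve_model_data_config(model)
transform = timm.data.create_transform(**config, is_training=False)

img = Image.open("cat.jpg").convert("RGB")  # placeholder image
with torch.no_grad():
    logits = model(transform(img).unsqueeze(0))
prob, idx = logits.softmax(dim=-1).topk(5)
print(idx[0].tolist(), prob[0].tolist())  # top-5 class indices and probabilities
```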
Download pre-trained models or access training checkpoints:
| 🏷️ Model | ⚙️ Setup | 📦 Params | 📈 Top-1 (%) | 🔗 Link |
|---|---|---|---|---|
| KAT-T | From Scratch | 5.7M | 74.6 | link/huggingface |
| KAT-T | From ViT | 5.7M | 75.7 | link/huggingface |
| KAT-S | From Scratch | 22.1M | 81.2 | link/huggingface |
| KAT-S | From ViT | 22.1M | 82.0 | link/huggingface |
| KAT-B | From Scratch | 86.6M | 82.3 | link/huggingface |
| KAT-B | From ViT | 86.6M | 82.8 | link/huggingface |
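As a quick consistency check, the parameter counts in the table can be reproduced by instantiating a model and counting its parameters (again assuming the KAT models are registered with timm; kat_tiny_patch16_224 is the name used by validate.py below, and the other variants follow the repo's naming):

```python
# Count parameters of KAT-Tiny; the validation log below reports 5718328.
import timm

model = timm.create_model("kat_tiny_patch16_224")
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")  # -> 5.7M, matching the table
```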
To train KAT-Tiny (8 GPUs × 128 images per GPU, an effective batch size of 1024), run:

```bash
bash scripts/train_kat_tiny_8x128.sh
```

To change the hyper-parameters, edit the script:
```bash
#!/bin/bash
DATA_PATH=/local_home/dataset/imagenet/

# Rationals are initialized as swish functions
bash ./dist_train.sh 8 $DATA_PATH \
    --model kat_tiny_swish_patch16_224 \
    -b 128 \
    --opt adamw \
    --lr 1e-3 \
    --weight-decay 0.05 \
    --epochs 300 \
    --mixup 0.8 \
    --cutmix 1.0 \
    --sched cosine \
    --smoothing 0.1 \
    --drop-path 0.1 \
    --aa rand-m9-mstd0.5 \
    --remode pixel --reprob 0.25 \
    --amp \
    --crop-pct 0.875 \
    --mean 0.485 0.456 0.406 \
    --std 0.229 0.224 0.225 \
    --model-ema \
    --model-ema-decay 0.9999 \
    --output output/kat_tiny_swish_patch16_224 \
    --log-wandb
```
To evaluate our kat_tiny_patch16_224 model, run:
```bash
DATA_PATH=/local_home/dataset/imagenet/
CHECKPOINT_PATH=kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth

python validate.py $DATA_PATH --model kat_tiny_patch16_224 \
    --checkpoint $CHECKPOINT_PATH -b 512
```
This should produce output like the following; the numbers in parentheses on the summary line are the corresponding error rates (100 − accuracy):

```
Validating in float32. AMP not enabled.
Loaded state_dict from checkpoint 'kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth'
Model kat_tiny_patch16_224 created, param count: 5718328
Data processing configuration for current model + dataset:
    input_size: (3, 224, 224)
    interpolation: bicubic
    mean: (0.485, 0.456, 0.406)
    std: (0.229, 0.224, 0.225)
    crop_pct: 0.875
    crop_mode: center
Test: [ 0/98]  Time: 3.453s (3.453s,  148.28/s)  Loss: 0.6989 (0.6989)  Acc@1: 84.375 (84.375)  Acc@5: 96.875 (96.875)
......
Test: [90/98]  Time: 0.212s (0.592s,  864.23/s)  Loss: 1.1640 (1.1143)  Acc@1: 71.875 (74.270)  Acc@5: 93.750 (92.220)
 * Acc@1 74.558 (25.442)  Acc@5 92.390 (7.610)
--result
{
    "model": "kat_tiny_patch16_224",
    "top1": 74.558,
    "top1_err": 25.442,
    "top5": 92.39,
    "top5_err": 7.61,
    "param_count": 5.72,
    "img_size": 224,
    "crop_pct": 0.875,
    "interpolation": "bicubic"
}
```
We extend our gratitude to the authors of rational_activations for their contributions to CUDA rational function implementations that inspired parts of this work.
If you use this repository, please cite:
```bibtex
@misc{yang2024compositional,
    title={Kolmogorov–Arnold Transformer},
    author={Xingyi Yang and Xinchao Wang},
    year={2024},
    eprint={XXXX},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```