Kolmogorov–Arnold Transformer: A PyTorch Implementation


This is a PyTorch/GPU implementation of the paper Kolmogorov–Arnold Transformer (KAT), which replaces the MLP layers in vision transformers with KAN (Kolmogorov–Arnold Network) layers.

Kolmogorov–Arnold Transformer

📝[Paper] </>[code]

Xingyi Yang, Xinchao Wang

National University of Singapore

Key Insight:

The KAT model integrates KANs into transformers for large-scale training scenarios such as ImageNet, showing significant performance improvements.
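To make the structure concrete, below is a minimal pure-PyTorch sketch of a transformer block whose MLP is swapped for a KAN-style layer built from a learnable rational activation. This is illustrative only and is not the repository's implementation: KAT uses grouped rational functions backed by the fused CUDA kernel from rational_kat_cu, and the RationalActivation / KATBlockSketch names and coefficient orders below are made up for this example.

import torch
import torch.nn as nn

class RationalActivation(nn.Module):
    # Learnable element-wise rational function P(x) / (1 + |Q(x)|), illustrative only.
    def __init__(self, p_order=5, q_order=4):
        super().__init__()
        self.p = nn.Parameter(torch.randn(p_order + 1) * 0.1)  # numerator coefficients
        self.q = nn.Parameter(torch.randn(q_order) * 0.1)      # denominator coefficients

    def forward(self, x):
        num = sum(c * x ** i for i, c in enumerate(self.p))
        den = 1 + torch.abs(sum(c * x ** (i + 1) for i, c in enumerate(self.q)))
        return num / den

class KATBlockSketch(nn.Module):
    # Attention + KAN-style layer, in place of the usual attention + MLP block.
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        # "KAN layer": linear -> learnable rational -> linear, replacing Linear-GELU-Linear.
        self.kan = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            RationalActivation(),
            nn.Linear(4 * dim, dim),
        )

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.kan(self.norm2(x))
        return x

block = KATBlockSketch(dim=64)
tokens = torch.randn(2, 197, 64)  # (batch, tokens, dim), ViT-style patch tokens plus class token
print(block(tokens).shape)        # torch.Size([2, 197, 64])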

Installation and Dataset

# install timm (PyTorch should already be installed)
pip install timm==1.0.3

# install the CUDA rational function kernel used by the KAN layers
git clone https://github.com/Adamdad/rational_kat_cu.git
cd rational_kat_cu
pip install -e .

Please refer to https://github.com/Adamdad/rational_kat_cu.git for details on installing the CUDA rational function kernel.
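Once both packages are installed, a quick forward pass can confirm the setup. This is a minimal sketch and assumes the repository's model definitions have been imported so that the kat_* architectures are registered with timm; the exact module to import depends on the repo layout.

import torch
import timm

# Assumes the kat_* models are registered with timm (import this repo's model file first).
model = timm.create_model('kat_tiny_swish_patch16_224', pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000]) for ImageNet-1k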

Data preparation: set up ImageNet with the following folder structure; you can extract ImageNet with this script. A quick check of the layout is sketched after the tree below.

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
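As a quick check of the layout above, torchvision's ImageFolder simply walks the class subdirectories; the paths below are placeholders.

from torchvision import datasets

# Replace the paths with your ImageNet root; this only verifies the folder layout.
train_set = datasets.ImageFolder('/path/to/imagenet/train')
val_set = datasets.ImageFolder('/path/to/imagenet/val')
print(len(train_set.classes), 'train classes;', len(val_set), 'val images')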

Model Checkpoints

Download pre-trained models or access training checkpoints:

| Model | Setup        | Params | Top-1 Acc. (%) | Link |
|-------|--------------|--------|----------------|------|
| KAT-T | From Scratch | 5.7M   | 74.6           | link |
| KAT-T | From ViT     | 5.7M   | 75.7           | link |
| KAT-S | From Scratch | 22.1M  | 81.2           | link |
| KAT-S | From ViT     | 22.1M  | 82.0           | link |
| KAT-B | From Scratch | 86.6M  | 82.3           | link |
| KAT-B | From ViT     | 86.6M  | 82.8           |      |
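A downloaded checkpoint can be loaded into the matching architecture roughly as follows. This is a sketch: the file name is a placeholder, and the key layout of the checkpoint ('state_dict' / 'model' vs. a bare state dict) is an assumption based on common timm conventions.

import torch
import timm

# Build the architecture that matches the checkpoint (KAT-T used here as an example).
model = timm.create_model('kat_tiny_swish_patch16_224', pretrained=False)

# 'kat_tiny.pth' is a placeholder for the downloaded file.
ckpt = torch.load('kat_tiny.pth', map_location='cpu')
# timm-style checkpoints usually nest weights under 'state_dict' or 'model'.
state_dict = ckpt.get('state_dict', ckpt.get('model', ckpt)) if isinstance(ckpt, dict) else ckpt
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print('missing keys:', missing)
print('unexpected keys:', unexpected)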

Model Training

bash scripts/train_kat_tiny_8x128.sh

If you want to change the hyper-parameters, you can edit the script:

#!/bin/bash
DATA_PATH=/local_home/dataset/imagenet/

# Rationals are initialized to be swish functions
bash ./dist_train.sh 8 $DATA_PATH \
    --model kat_tiny_swish_patch16_224 \
    -b 128 \
    --opt adamw \
    --lr 1e-3 \
    --weight-decay 0.05 \
    --epochs 300 \
    --mixup 0.8 \
    --cutmix 1.0 \
    --sched cosine \
    --smoothing 0.1 \
    --drop-path 0.1 \
    --aa rand-m9-mstd0.5 \
    --remode pixel --reprob 0.25 \
    --amp \
    --crop-pct 0.875 \
    --mean 0.485 0.456 0.406 \
    --std 0.229 0.224 0.225 \
    --model-ema \
    --model-ema-decay 0.9999 \
    --output output/kat_tiny_swish_patch16_224 \
    --log-wandb

Acknowledgments

We extend our gratitude to the authors of rational_activations for their contributions to CUDA rational function implementations that inspired parts of this work.

Bibtex

@misc{yang2024compositional,
    title={Kolmogorov–Arnold Transformer},
    author={Xingyi Yang and Xinchao Wang},
    year={2024},
    eprint={XXXX},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
