Welcome to TALENT, a benchmark and comprehensive machine learning toolbox designed to enhance model performance on tabular data. TALENT integrates advanced deep learning models, classical algorithms, and efficient hyperparameter tuning, offering robust preprocessing capabilities to optimize learning from tabular datasets. The toolbox is user-friendly and adaptable, catering to both novice and expert data scientists.
TALENT offers the following advantages:
- Diverse Methods: Includes various classical methods, tree-based methods, and the latest popular deep learning methods.
- Extensive Dataset Collection: Equipped with 300 datasets, covering a wide range of task types, size distributions, and dataset domains.
- Customizability: Easily allows the addition of datasets and methods.
- Versatile Support: Supports diverse normalization, encoding, and metrics.
If you use any content of this repo for your work, please cite the following bib entry:
TODO
TALENT integrates an extensive array of 20+ deep learning architectures for tabular data, including but not limited to:
- MLP: A multi-layer neural network, which is implemented according to RTDL.
- ResNet: A DNN that uses skip connections across many layers, which is implemented according to RTDL.
- SNN: An MLP-like architecture utilizing the SELU activation, which facilitates the training of deeper neural networks.
- DANets: A neural network designed to enhance tabular data processing by grouping correlated features and reducing computational complexity.
- TabCaps: A capsule network that encapsulates all feature values of a record into vectorial features.
- DCNv2: Consists of an MLP-like module combined with a feature crossing module, which includes both linear layers and multiplications.
- NODE: A tree-mimic method that generalizes oblivious decision trees, combining gradient-based optimization with hierarchical representation learning.
- GrowNet: A gradient boosting framework that uses shallow neural networks as weak learners.
- TabNet: A tree-mimic method using sequential attention for feature selection, offering interpretability and self-supervised learning capabilities.
- TabR: A deep learning model that integrates a KNN component to enhance tabular data predictions through an efficient attention-like mechanism.
- ModernNCA: A deep tabular model inspired by traditional Neighbor Component Analysis, which makes predictions based on the relationships with neighbors in a learned embedding space.
- DNNR: Enhances KNN by using local gradients and Taylor approximations for more accurate and interpretable predictions.
- AutoInt: A token-based method that uses a multi-head self-attentive neural network to automatically learn high-order feature interactions.
- Saint: A token-based method that leverages row and column attention mechanisms for tabular data.
- TabTransformer: A token-based method that enhances tabular data modeling by transforming categorical features into contextual embeddings.
- FT-Transformer: A token-based method which transforms features to embeddings and applies a series of attention-based transformations to the embeddings.
- TANGOS: A regularization-based method for tabular data that uses gradient attributions to encourage neuron specialization and orthogonalization.
- SwitchTab: A regularization-based method tailored for tabular data that improves representation learning through an asymmetric encoder-decoder framework.
- PTaRL: A regularization-based framework that enhances prediction by constructing and projecting into a prototype-based space.
- TabPFN: A general model built on a pre-trained deep neural network that can be applied directly to any tabular task without task-specific training.
- HyperFast: A meta-trained hypernetwork that generates task-specific neural networks for instant classification of tabular data.
- TabPTM: A general method for tabular data that standardizes heterogeneous datasets using meta-representations, allowing a pre-trained model to generalize to unseen datasets without additional training.
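Among these, the neighbor-based prediction idea behind ModernNCA can be illustrated with a schematic sketch. This is a hand-written toy in pure numpy, not TALENT's actual implementation: the function name `nca_predict` is ours, and the 2-D points below stand in for a learned embedding space.

```python
import numpy as np

def nca_predict(query_emb, support_emb, support_labels, n_classes):
    """Schematic NCA-style prediction: softmax over negative squared
    distances in an embedding space, then a weighted vote of neighbor labels."""
    # pairwise squared Euclidean distances: shape (n_query, n_support)
    d2 = ((query_emb[:, None, :] - support_emb[None, :, :]) ** 2).sum(-1)
    w = np.exp(-d2)
    w /= w.sum(axis=1, keepdims=True)           # attention-like neighbor weights
    onehot = np.eye(n_classes)[support_labels]  # (n_support, n_classes)
    return w @ onehot                           # class probabilities per query

# toy example: two well-separated clusters in a 2-D "learned" embedding
support = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
labels = np.array([0, 0, 1, 1])
probs = nca_predict(np.array([[0.05, 0.0], [5.05, 5.0]]), support, labels, 2)
print(probs.argmax(axis=1))  # [0 1]
```

In the real method, the embedding is learned end to end so that these distance-weighted votes produce good predictions; the sketch only shows the inference rule.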
Clone this GitHub repository:

```bash
git clone https://github.com/qile2000/LAMDA-TALENT
cd LAMDA-TALENT/TabBench
```

- Edit the `[MODEL_NAME].json` file for global settings and hyperparameters.
- Run `python train_model_deep.py --model_type MODEL_NAME` for deep methods, or `python train_model_classical.py --model_type MODEL_NAME` for classical methods.
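As a purely illustrative sketch of what a `[MODEL_NAME].json` settings file might contain, the field names below are hypothetical, not TALENT's actual schema; consult the shipped config files for the real keys:

```json
{
  "model": {
    "d_layers": [256, 256],
    "dropout": 0.1
  },
  "training": {
    "lr": 0.001,
    "batch_size": 256,
    "max_epochs": 200
  }
}
```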
TODO
Datasets are available on Google Drive.
TODO
We thank the following repos for providing helpful components/functions in our work:
- Rtdl-revisiting-models
- Rtdl-num-embeddings
- Tabular-dl-tabr
- DANet
- TabCaps
- DNNR
- PTaRL
- Saint
- SwitchTab
- TabNet
- TabPFN
- Tabtransformer-pytorch
- TANGOS
- GrowNet
- HyperFast
We provide comprehensive evaluations of classical and deep tabular methods based on our toolbox in a fair manner in the Figure. Three tabular prediction tasks, namely, binary classification, multi-class classification, and regression, are considered, and each subfigure represents a different task type.
We use accuracy and RMSE as the metrics for classification and regression, respectively. To calibrate the metrics, we choose the average performance rank to compare all methods, where a lower rank indicates better performance, following Sheskin (2003), *Handbook of Parametric and Nonparametric Statistical Procedures*. Efficiency is calculated by the average training time in seconds, with lower values denoting better time efficiency. The model size is visually indicated by the radius of the circles, offering a quick glance at the trade-off between model complexity and performance.
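The average performance rank is computed by ranking the methods on each dataset and then averaging the ranks over datasets. A minimal numpy sketch on hypothetical accuracies (ties are ignored for simplicity; for RMSE one would rank in ascending order instead):

```python
import numpy as np

# Hypothetical accuracies of 3 methods (columns) on 4 datasets (rows).
acc = np.array([[0.90, 0.85, 0.88],
                [0.70, 0.72, 0.71],
                [0.95, 0.94, 0.96],
                [0.60, 0.65, 0.63]])

# Rank methods per dataset (rank 1 = highest accuracy), then average.
order = (-acc).argsort(axis=1)      # method indices sorted by descending accuracy
ranks = order.argsort(axis=1) + 1   # position of each method in that ordering
avg_rank = ranks.mean(axis=0)
print(avg_rank)  # lower is better
```

A tie-aware version would assign tied methods the mean of their ranks (e.g. `scipy.stats.rankdata`), which matters when many methods score identically on a dataset.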
From the comparison, we observe that CatBoost achieves the best average rank in most classification and regression tasks. Among all deep tabular methods, ModernNCA performs the best in most cases while maintaining an acceptable training cost. These results highlight the effectiveness of CatBoost and ModernNCA in handling various tabular prediction tasks, making them suitable choices for practitioners seeking high performance and efficiency.
These visualizations serve as an effective tool for quickly and fairly assessing the strengths and weaknesses of various tabular methods across different task types, enabling researchers and practitioners to make informed decisions when selecting suitable modeling techniques for their specific needs.
If you have any questions, please feel free to open an issue to propose new features, or contact the authors: Siyang Liu ([email protected]), Haorun Cai ([email protected]), Qile Zhou ([email protected]), and Han-Jia Ye ([email protected]). Enjoy the code!