Skip to content

Commit

Permalink
fix configs
Browse files Browse the repository at this point in the history
  • Loading branch information
6sy666 committed Jun 16, 2024
1 parent ddf05d3 commit 94a4fe5
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 39 deletions.
19 changes: 19 additions & 0 deletions TabBench/configs/classical_configs.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"dataset": "Bank_Customer_Churn_Dataset",
"model_type": "xgboost",
"normalization": "standard",
"num_nan_policy": "mean",
"cat_nan_policy": "new",
"cat_policy": "ordinal",
"num_policy": "none",
"n_bins": 2,
"cat_min_frequency": 0.0,
"n_trials": 100,
"seed_num": 15,
"gpu": "0",
"tune": false,
"retune": false,
"dataset_path": "data",
"model_path": "results_model",
"evaluate_option": "best-val"
}
22 changes: 22 additions & 0 deletions TabBench/configs/deep_configs.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"dataset": "Bank_Customer_Churn_Dataset",
"model_type": "mlp",
"max_epoch": 200,
"batch_size": 1024,
"normalization": "standard",
"num_nan_policy": "mean",
"cat_nan_policy": "new",
"cat_policy": "ordinal",
"num_policy": "none",
"n_bins": 2,
"cat_min_frequency": 0.0,
"n_trials": 100,
"seed_num": 15,
"workers": 0,
"gpu": "0",
"tune": false,
"retune": false,
"evaluate_option": "best-val",
"dataset_path": "data",
"model_path": "results_model"
}
82 changes: 43 additions & 39 deletions TabBench/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,35 +183,36 @@ def get_classical_args():
import argparse
import warnings
warnings.filterwarnings("ignore")

with open('configs/classical_configs.json','r') as file:
default_args = json.load(file)
parser = argparse.ArgumentParser()
# basic parameters
parser.add_argument('--dataset', type=str, default='Bank_Customer_Churn_Dataset')
parser.add_argument('--dataset', type=str, default=default_args['dataset'])
parser.add_argument('--model_type', type=str,
default='xgboost',
default=default_args['model_type'],
choices=['LogReg', 'NCM', 'RandomForest',
'xgboost', 'catboost', 'lightgbm',
'svm','knn', 'NaiveBayes',"dummy","LinearRegression"
])

# optimization parameters
parser.add_argument('--normalization', type=str, default='standard', choices=['none', 'standard', 'minmax', 'quantile', 'maxabs', 'power', 'robust'])
parser.add_argument('--num_nan_policy', type=str, default='mean', choices=['mean', 'median'])
parser.add_argument('--cat_nan_policy', type=str, default='new', choices=['new', 'most_frequent'])
parser.add_argument('--cat_policy', type=str, default='ohe', choices=['indices', 'ordinal', 'ohe', 'binary', 'hash', 'loo', 'target', 'catboost'])
parser.add_argument('--num_policy',type=str, default='none',choices=['none','Q_PLE','T_PLE','Q_Unary','T_Unary','Q_bins','T_bins','Q_Johnson','T_Johnson'])
parser.add_argument('--n_bins', type=int, default=2)
parser.add_argument('--cat_min_frequency', type=float, default=0.0)
parser.add_argument('--normalization', type=str, default=default_args['normalization'], choices=['none', 'standard', 'minmax', 'quantile', 'maxabs', 'power', 'robust'])
parser.add_argument('--num_nan_policy', type=str, default=default_args['num_nan_policy'], choices=['mean', 'median'])
parser.add_argument('--cat_nan_policy', type=str, default=default_args['cat_nan_policy'], choices=['new', 'most_frequent'])
parser.add_argument('--cat_policy', type=str, default=default_args['cat_policy'], choices=['indices', 'ordinal', 'ohe', 'binary', 'hash', 'loo', 'target', 'catboost'])
parser.add_argument('--num_policy',type=str, default=default_args['num_policy'],choices=['none','Q_PLE','T_PLE','Q_Unary','T_Unary','Q_bins','T_bins','Q_Johnson','T_Johnson'])
parser.add_argument('--n_bins', type=int, default=default_args['n_bins'])
parser.add_argument('--cat_min_frequency', type=float, default=default_args['cat_min_frequency'])

# other choices
parser.add_argument('--n_trials', type=int, default=50)
parser.add_argument('--seed_num', type=int, default=10)
parser.add_argument('--gpu', default='0')
parser.add_argument('--tune', action='store_true', default=False)
parser.add_argument('--retune', action='store_true', default=False)
parser.add_argument('--dataset_path', type=str, default='data')
parser.add_argument('--model_path', type=str, default='results_model')
parser.add_argument('--evaluate_option', type=str, default='best-val')
parser.add_argument('--n_trials', type=int, default=default_args['n_trials'])
parser.add_argument('--seed_num', type=int, default=default_args['seed_num'])
parser.add_argument('--gpu', default=default_args['gpu'])
parser.add_argument('--tune', action='store_true', default=default_args['tune'])
parser.add_argument('--retune', action='store_true', default=default_args['retune'])
parser.add_argument('--dataset_path', type=str, default=default_args['dataset_path'])
parser.add_argument('--model_path', type=str, default=default_args['model_path'])
parser.add_argument('--evaluate_option', type=str, default=default_args['evaluate_option'])
args = parser.parse_args()

set_gpu(args.gpu)
Expand Down Expand Up @@ -257,35 +258,37 @@ def get_deep_args():

parser = argparse.ArgumentParser()
# basic parameters
parser.add_argument('--dataset', type=str, default='Bank_Customer_Churn_Dataset')
with open('configs/deep_configs.json','r') as file:
default_args = json.load(file)
parser.add_argument('--dataset', type=str, default=default_args['dataset'])
parser.add_argument('--model_type', type=str,
default='mlp',
default=default_args['model_type'],
choices=['mlp', 'resnet', 'ftt', 'node', 'autoint',
'tabpfn', 'tangos', 'saint', 'tabcaps', 'tabnet',
'snn','ptarl','danets','dcn2','tabtransformer',
'dnnr', 'switchtab', 'grownet','tabr','modernNCA']) #

# optimization parameters
parser.add_argument('--max_epoch', type=int, default=200)
parser.add_argument('--batch_size', type=int, default=1024)
parser.add_argument('--normalization', type=str, default='standard', choices=['none', 'standard', 'minmax', 'quantile', 'maxabs', 'power', 'robust'])
parser.add_argument('--num_nan_policy', type=str, default='mean', choices=['mean', 'median'])
parser.add_argument('--cat_nan_policy', type=str, default='new', choices=['new', 'most_frequent'])
parser.add_argument('--cat_policy', type=str, default='ohe', choices=['indices', 'ordinal', 'ohe', 'binary', 'hash', 'loo', 'target', 'catboost','tabr_ohe'])
parser.add_argument('--num_policy',type=str, default='none',choices=['none','Q_PLE','T_PLE','Q_Unary','T_Unary','Q_bins','T_bins','Q_Johnson','T_Johnson'])
parser.add_argument('--n_bins', type=int, default=2)
parser.add_argument('--cat_min_frequency', type=float, default=0.0)
parser.add_argument('--max_epoch', type=int, default=default_args['max_epoch'])
parser.add_argument('--batch_size', type=int, default=default_args['batch_size'])
parser.add_argument('--normalization', type=str, default=default_args['normalization'], choices=['none', 'standard', 'minmax', 'quantile', 'maxabs', 'power', 'robust'])
parser.add_argument('--num_nan_policy', type=str, default=default_args['num_nan_policy'], choices=['mean', 'median'])
parser.add_argument('--cat_nan_policy', type=str, default=default_args['cat_nan_policy'], choices=['new', 'most_frequent'])
parser.add_argument('--cat_policy', type=str, default=default_args['cat_policy'], choices=['indices', 'ordinal', 'ohe', 'binary', 'hash', 'loo', 'target', 'catboost','tabr_ohe'])
parser.add_argument('--num_policy',type=str, default=default_args['num_policy'],choices=['none','Q_PLE','T_PLE','Q_Unary','T_Unary','Q_bins','T_bins','Q_Johnson','T_Johnson'])
parser.add_argument('--n_bins', type=int, default=default_args['n_bins'])
parser.add_argument('--cat_min_frequency', type=float, default=default_args['cat_min_frequency'])

# other choices
parser.add_argument('--n_trials', type=int, default=50)
parser.add_argument('--seed_num', type=int, default=10)
parser.add_argument('--workers', type=int, default=0)
parser.add_argument('--gpu', default='0')
parser.add_argument('--tune', action='store_true', default=False)
parser.add_argument('--retune', action='store_true', default=False)
parser.add_argument('--evaluate_option', type=str, default='best-val')
parser.add_argument('--dataset_path', type=str, default='data')
parser.add_argument('--model_path', type=str, default='results_model')
parser.add_argument('--n_trials', type=int, default=default_args['n_trials'])
parser.add_argument('--seed_num', type=int, default=default_args['seed_num'])
parser.add_argument('--workers', type=int, default=default_args['workers'])
parser.add_argument('--gpu', default=default_args['gpu'])
parser.add_argument('--tune', action='store_true', default=default_args['tune'])
parser.add_argument('--retune', action='store_true', default=default_args['retune'])
parser.add_argument('--evaluate_option', type=str, default=default_args['evaluate_option'])
parser.add_argument('--dataset_path', type=str, default=default_args['dataset_path'])
parser.add_argument('--model_path', type=str, default=default_args['model_path'])
args = parser.parse_args()

set_gpu(args.gpu)
Expand All @@ -305,7 +308,6 @@ def get_deep_args():
mkdir(args.save_path)

# load config parameters
args.seed = 0
config_default_path = os.path.join('configs','default',args.model_type+'.json')
config_opt_path = os.path.join('configs','opt_space',args.model_type+'.json')
with open(config_default_path,'r') as file:
Expand All @@ -314,6 +316,8 @@ def get_deep_args():
with open(config_opt_path,'r') as file:
opt_space = json.load(file)
args.config = default_para[args.model_type]

args.seed = 0
set_seeds(args.seed)
if torch.cuda.is_available():
torch.backends.cudnn.benchmark = True
Expand Down

0 comments on commit 94a4fe5

Please sign in to comment.