Skip to content

Commit

Permalink
add dummy method
Browse files Browse the repository at this point in the history
  • Loading branch information
6sy666 committed May 27, 2024
1 parent cb5affd commit 7a2afdc
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 3 deletions.
3 changes: 2 additions & 1 deletion default_para.json
Original file line number Diff line number Diff line change
Expand Up @@ -338,5 +338,6 @@
},
"fit": {}
},
"NaiveBayes": {"model": {}, "fit": {}}
"NaiveBayes": {"model": {}, "fit": {}},
"dummy":{"model": {}, "fit": {}}
}
59 changes: 59 additions & 0 deletions model/classical_methods/dummy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from model.classical_methods.base import classical_methods
from copy import deepcopy
import os.path as ops
import pickle
import time

class DummyMethod(classical_methods):
    """Trivial baseline built on scikit-learn's dummy estimators.

    Regression tasks use ``DummyRegressor(strategy='mean')``; all other
    tasks use a default-constructed ``DummyClassifier``. The fitted
    estimator is pickled to ``args.save_path`` by :meth:`fit` and read
    back by :meth:`predict`.
    """

    def __init__(self, args, is_regression):
        """Initialise the base method and reject ``cat_policy='indices'``.

        Args:
            args: Run configuration; this class reads ``cat_policy``,
                ``config``, ``save_path`` and ``seed`` from it.
            is_regression: Forwarded to the base class.

        Raises:
            AssertionError: If ``args.cat_policy`` is ``'indices'``.
        """
        super().__init__(args, is_regression)
        assert args.cat_policy != 'indices', (
            "DummyMethod does not accept cat_policy='indices'"
        )

    def construct_model(self, model_config=None):
        """Create the dummy estimator and store it in ``self.model``.

        Args:
            model_config: Optional model configuration. When ``None`` it is
                read from ``self.args.config['model']``.
        """
        if model_config is None:
            model_config = self.args.config['model']
        # NOTE(review): model_config is resolved (so a missing 'model' key
        # still fails here) but is not forwarded to the estimator; the
        # estimators are always built with the fixed arguments below.
        from sklearn.dummy import DummyClassifier, DummyRegressor
        if self.is_regression:
            self.model = DummyRegressor(strategy='mean')
        else:
            self.model = DummyClassifier()

    def fit(self, N, C, y, info, train=True, config=None):
        """Fit the dummy estimator and save it as a pickle checkpoint.

        All arguments are forwarded unchanged to the base class ``fit``,
        which is presumably what populates ``self.N``, ``self.y``,
        ``self.model`` and ``self.trlog`` -- TODO confirm against the base
        class.

        Returns:
            Elapsed wall-clock seconds spent fitting on the train split and
            scoring on the validation split, or ``None`` when ``train`` is
            false.
        """
        super().fit(N, C, y, info, train, config)
        # When not training (e.g. a checkpoint will be loaded and used for
        # prediction directly), skip the fitting step entirely.
        if not train:
            return
        tic = time.time()
        self.model.fit(self.N['train'], self.y['train'])
        # Estimator.score is R^2 for regressors and mean accuracy for
        # classifiers in scikit-learn.
        self.trlog['best_res'] = self.model.score(self.N['val'], self.y['val'])
        # Timing stops before serialisation so disk I/O is not counted.
        time_cost = time.time() - tic
        with open(ops.join(self.args.save_path, 'best-val-{}.pkl'.format(self.args.seed)), 'wb') as f:
            pickle.dump(self.model, f)
        return time_cost

    def predict(self, N, C, y, info, model_name):
        """Load the saved estimator and evaluate it on the given data.

        Args:
            N, C, y: Passed to ``self.data_format(False, N, C, y)``, which
                is assumed to set ``self.N_test`` and ``self.y_test`` --
                TODO confirm against the base class.
            info: Not used by this method.
            model_name: Not used by this method; the checkpoint name is
                always ``best-val-<seed>.pkl``.

        Returns:
            A tuple ``(vres, metric_name, test_logit)``: the metric values,
            their names, and the raw predictions of the estimator.
        """
        # SECURITY: pickle.load executes arbitrary code from the file. Only
        # load checkpoints from a save_path that this project wrote itself.
        with open(ops.join(self.args.save_path, 'best-val-{}.pkl'.format(self.args.seed)), 'rb') as f:
            self.model = pickle.load(f)
        self.data_format(False, N, C, y)
        test_label = self.y_test
        test_logit = self.model.predict(self.N_test)
        vres, metric_name = self.metric(test_logit, test_label, self.y_info)
        return vres, metric_name, test_logit

    def metric(self, predictions, labels, y_info):
        """Compute evaluation metrics for the given predictions.

        Args:
            predictions: Predicted targets.
            labels: Ground-truth targets.
            y_info: Mapping read only for regression; ``y_info['policy']``
                is checked and, when it equals ``'mean_std'``,
                ``y_info['std']`` is used as a scale factor.

        Returns:
            For regression ``((mae, r2, rmse), ("MAE", "R2", "RMSE"))``;
            otherwise ``((accuracy, avg_precision, avg_recall, f1),
            ("Accuracy", "Avg_Precision", "Avg_Recall", "F1"))``.
        """
        from sklearn import metrics as skm
        if self.is_regression:
            mae = skm.mean_absolute_error(labels, predictions)
            rmse = skm.mean_squared_error(labels, predictions) ** 0.5
            r2 = skm.r2_score(labels, predictions)
            if y_info['policy'] == 'mean_std':
                # Presumably undoes label standardisation so the errors are
                # reported in the original target units -- TODO confirm.
                mae *= y_info['std']
                rmse *= y_info['std']
            return (mae, r2, rmse), ("MAE", "R2", "RMSE")

        # A dummy classifier can predict a single class only, which leaves
        # per-class precision undefined for every other class. Passing
        # zero_division=0 yields the same score as the library default
        # ("warn" also scores 0) without an UndefinedMetricWarning on every
        # evaluation.
        accuracy = skm.accuracy_score(labels, predictions)
        avg_precision = skm.precision_score(
            labels, predictions, average='macro', zero_division=0
        )
        avg_recall = skm.recall_score(
            labels, predictions, average='macro', zero_division=0
        )
        f1_score = skm.f1_score(
            labels,
            predictions,
            average='binary' if self.is_binclass else 'macro',
            zero_division=0,
        )
        return (accuracy, avg_precision, avg_recall, f1_score), ("Accuracy", "Avg_Precision", "Avg_Recall", "F1")
3 changes: 3 additions & 0 deletions model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,5 +265,8 @@ def modeltype_to_method(model):
elif model == 'svm':
from model.classical_methods.svm import SvmMethod
return SvmMethod
elif model == 'dummy':
from model.classical_methods.dummy import DummyMethod
return DummyMethod
else:
raise NotImplementedError("Model \"" + model + "\" not yet implemented")
4 changes: 2 additions & 2 deletions train_model_classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_args():
default='xgboost',
choices=['LogReg', 'NCM', 'RandomForest',
'xgboost', 'catboost', 'lightgbm',
'svm','knn', 'NaiveBayes'
'svm','knn', 'NaiveBayes',"dummy",
])

# optimization parameters
Expand All @@ -44,7 +44,7 @@ def get_args():
parser.add_argument('--gpu', default='0')
parser.add_argument('--tune', action='store_true', default=False)
parser.add_argument('--retune', action='store_true', default=False)
parser.add_argument('--dataset_path', type=str, default='data_cls_resplit')
parser.add_argument('--dataset_path', type=str, default='data')
parser.add_argument('--model_path', type=str, default='results_model')
parser.add_argument('--evaluate_option', type=str, default='best-val')
args = parser.parse_args()
Expand Down

0 comments on commit 7a2afdc

Please sign in to comment.