Skip to content

Commit

Permalink
auc added
Browse files Browse the repository at this point in the history
  • Loading branch information
PiotrekGa committed Aug 19, 2019
1 parent 415a523 commit d2f42ec
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
9 changes: 7 additions & 2 deletions prunedcv/src.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,8 @@ def cross_val_score(self,

if metric not in ['mse',
'mae',
'accuracy']:
'accuracy',
'auc']:
raise ValueError

if metric in ['mse',
Expand All @@ -389,7 +390,8 @@ def cross_val_score(self,
shuffle=shuffle,
random_state=random_state)

elif metric in ['accuracy']:
elif metric in ['accuracy',
'auc']:

kf = StratifiedKFold(n_splits=self.cv,
shuffle=shuffle,
Expand Down Expand Up @@ -426,6 +428,9 @@ def cross_val_score(self,
elif metric == 'accuracy':
self._add_split_value_and_prun(metrics.accuracy_score(y_test,
y_test_teor))
elif metric == 'auc':
self._add_split_value_and_prun(metrics.roc_auc_score(y_test,
y_test_teor))

self.prune = False
return self.cross_val_score_value
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name='prunedcv',
author='Piotr Gabryś',
author_email='[email protected]',
version='0.0.2',
version='0.0.3',
packages=['prunedcv',],
install_requires=[
'scikit-learn>=0.20.2',
Expand Down

0 comments on commit d2f42ec

Please sign in to comment.