Skip to content

Commit

Permalink
[MRG] fix refit=False error in LogisticRegressionCV (scikit-learn#14087)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored and ogrisel committed Jun 14, 2019
1 parent 3ec339a commit 36b688e
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
6 changes: 6 additions & 0 deletions doc/whats_new/v0.21.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ Version 0.21.3
Changelog
---------

:mod:`sklearn.linear_model`
...........................
- |Fix| Fixed a bug in :class:`linear_model.LogisticRegressionCV` where
  ``refit=False`` would fail depending on the ``multi_class`` and
  ``penalty`` parameters (regression introduced in 0.21). :pr:`14087` by
  `Nicolas Hug`_.

.. _changes_0_21_2:

Expand Down
9 changes: 6 additions & 3 deletions sklearn/linear_model/logistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2170,7 +2170,7 @@ def fit(self, X, y, sample_weight=None):
# Take the best scores across every fold and the average of
# all coefficients corresponding to the best scores.
best_indices = np.argmax(scores, axis=1)
if self.multi_class == 'ovr':
if multi_class == 'ovr':
w = np.mean([coefs_paths[i, best_indices[i], :]
for i in range(len(folds))], axis=0)
else:
Expand All @@ -2180,8 +2180,11 @@ def fit(self, X, y, sample_weight=None):
best_indices_C = best_indices % len(self.Cs_)
self.C_.append(np.mean(self.Cs_[best_indices_C]))

best_indices_l1 = best_indices // len(self.Cs_)
self.l1_ratio_.append(np.mean(l1_ratios_[best_indices_l1]))
if self.penalty == 'elasticnet':
best_indices_l1 = best_indices // len(self.Cs_)
self.l1_ratio_.append(np.mean(l1_ratios_[best_indices_l1]))
else:
self.l1_ratio_.append(None)

if multi_class == 'multinomial':
self.C_ = np.tile(self.C_, n_classes)
Expand Down
12 changes: 8 additions & 4 deletions sklearn/linear_model/tests/test_logistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,8 +1532,9 @@ def test_LogisticRegressionCV_GridSearchCV_elastic_net_ovr():
assert (lrcv.predict(X_test) == gs.predict(X_test)).mean() >= .8


@pytest.mark.parametrize('multi_class', ('ovr', 'multinomial'))
def test_LogisticRegressionCV_no_refit(multi_class):
@pytest.mark.parametrize('penalty', ('l2', 'elasticnet'))
@pytest.mark.parametrize('multi_class', ('ovr', 'multinomial', 'auto'))
def test_LogisticRegressionCV_no_refit(penalty, multi_class):
# Test LogisticRegressionCV attribute shapes when refit is False

n_classes = 3
Expand All @@ -1543,9 +1544,12 @@ def test_LogisticRegressionCV_no_refit(multi_class):
random_state=0)

Cs = np.logspace(-4, 4, 3)
l1_ratios = np.linspace(0, 1, 2)
if penalty == 'elasticnet':
l1_ratios = np.linspace(0, 1, 2)
else:
l1_ratios = None

lrcv = LogisticRegressionCV(penalty='elasticnet', Cs=Cs, solver='saga',
lrcv = LogisticRegressionCV(penalty=penalty, Cs=Cs, solver='saga',
l1_ratios=l1_ratios, random_state=0,
multi_class=multi_class, refit=False)
lrcv.fit(X, y)
Expand Down

0 comments on commit 36b688e

Please sign in to comment.