forked from scikit-learn/scikit-learn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_logistic.py
2030 lines (1724 loc) · 76.3 KB
/
_logistic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Logistic Regression
"""
# Author: Gael Varoquaux <[email protected]>
# Fabian Pedregosa <[email protected]>
# Alexandre Gramfort <[email protected]>
# Manoj Kumar <[email protected]>
# Lars Buitinck
# Simon Wu <[email protected]>
# Arthur Mensch <[email protected]>
import numbers
import warnings
import numpy as np
from scipy import optimize
from joblib import Parallel, effective_n_jobs
from ._base import LinearClassifierMixin, SparseCoefMixin, BaseEstimator
from ._linear_loss import LinearModelLoss
from ._sag import sag_solver
from .._loss.loss import HalfBinomialLoss, HalfMultinomialLoss
from ..preprocessing import LabelEncoder, LabelBinarizer
from ..svm._base import _fit_liblinear
from ..utils import check_array, check_consistent_length, compute_class_weight
from ..utils import check_random_state
from ..utils.extmath import softmax
from ..utils.extmath import row_norms
from ..utils.optimize import _newton_cg, _check_optimize_result
from ..utils.validation import check_is_fitted, _check_sample_weight
from ..utils.multiclass import check_classification_targets
from ..utils.fixes import delayed
from ..model_selection import check_cv
from ..metrics import get_scorer
_LOGISTIC_SOLVER_CONVERGENCE_MSG = (
"Please also refer to the documentation for alternative solver options:\n"
" https://scikit-learn.org/stable/modules/linear_model.html"
"#logistic-regression"
)
def _check_solver(solver, penalty, dual):
all_solvers = ["liblinear", "newton-cg", "lbfgs", "sag", "saga"]
if solver not in all_solvers:
raise ValueError(
"Logistic Regression supports only solvers in %s, got %s."
% (all_solvers, solver)
)
all_penalties = ["l1", "l2", "elasticnet", "none"]
if penalty not in all_penalties:
raise ValueError(
"Logistic Regression supports only penalties in %s, got %s."
% (all_penalties, penalty)
)
if solver not in ["liblinear", "saga"] and penalty not in ("l2", "none"):
raise ValueError(
"Solver %s supports only 'l2' or 'none' penalties, got %s penalty."
% (solver, penalty)
)
if solver != "liblinear" and dual:
raise ValueError(
"Solver %s supports only dual=False, got dual=%s" % (solver, dual)
)
if penalty == "elasticnet" and solver != "saga":
raise ValueError(
"Only 'saga' solver supports elasticnet penalty, got solver={}.".format(
solver
)
)
if solver == "liblinear" and penalty == "none":
raise ValueError("penalty='none' is not supported for the liblinear solver")
return solver
def _check_multi_class(multi_class, solver, n_classes):
if multi_class == "auto":
if solver == "liblinear":
multi_class = "ovr"
elif n_classes > 2:
multi_class = "multinomial"
else:
multi_class = "ovr"
if multi_class not in ("multinomial", "ovr"):
raise ValueError(
"multi_class should be 'multinomial', 'ovr' or 'auto'. Got %s."
% multi_class
)
if multi_class == "multinomial" and solver == "liblinear":
raise ValueError("Solver %s does not support a multinomial backend." % solver)
return multi_class
def _logistic_regression_path(
    X,
    y,
    pos_class=None,
    Cs=10,
    fit_intercept=True,
    max_iter=100,
    tol=1e-4,
    verbose=0,
    solver="lbfgs",
    coef=None,
    class_weight=None,
    dual=False,
    penalty="l2",
    intercept_scaling=1.0,
    multi_class="auto",
    random_state=None,
    check_input=True,
    max_squared_sum=None,
    sample_weight=None,
    l1_ratio=None,
    n_threads=1,
):
    """Compute a Logistic Regression model for a list of regularization
    parameters.

    This is an implementation that uses the result of the previous model
    to speed up computations along the set of solutions, making it faster
    than sequentially calling LogisticRegression for the different parameters.
    Note that there will be no speedup with liblinear solver, since it does
    not handle warm-starting.

    Read more in the :ref:`User Guide <logistic_regression>`.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input data.

    y : array-like of shape (n_samples,) or (n_samples, n_targets)
        Input data, target values.

    pos_class : int, default=None
        The class with respect to which we perform a one-vs-all fit.
        If None, then it is assumed that the given problem is binary.

    Cs : int or array-like of shape (n_cs,), default=10
        List of values for the regularization parameter or integer specifying
        the number of regularization parameters that should be used. In this
        case, the parameters will be chosen in a logarithmic scale between
        1e-4 and 1e4.

    fit_intercept : bool, default=True
        Whether to fit an intercept for the model. In this case the shape of
        the returned array is (n_cs, n_features + 1).

    max_iter : int, default=100
        Maximum number of iterations for the solver.

    tol : float, default=1e-4
        Stopping criterion. For the newton-cg and lbfgs solvers, the iteration
        will stop when ``max{|g_i | i = 1, ..., n} <= tol``
        where ``g_i`` is the i-th component of the gradient.

    verbose : int, default=0
        For the liblinear and lbfgs solvers set verbose to any positive
        number for verbosity.

    solver : {'lbfgs', 'newton-cg', 'liblinear', 'sag', 'saga'}, \
            default='lbfgs'
        Numerical solver to use.

    coef : array-like, default=None
        Initialization value for coefficients of logistic regression.
        For one-vs-rest, a 1d array with ``n_features`` or
        ``n_features + 1`` elements. For multinomial, a 2d array of shape
        (n_classes, n_features) or (n_classes, n_features + 1), where the
        first dimension must be 1 for binary problems.
        Useless for liblinear solver.

    class_weight : dict or 'balanced', default=None
        Weights associated with classes in the form ``{class_label: weight}``.
        If not given, all classes are supposed to have weight one.

        The "balanced" mode uses the values of y to automatically adjust
        weights inversely proportional to class frequencies in the input data
        as ``n_samples / (n_classes * np.bincount(y))``.

        Note that these weights will be multiplied with sample_weight (passed
        through the fit method) if sample_weight is specified.

    dual : bool, default=False
        Dual or primal formulation. Dual formulation is only implemented for
        l2 penalty with liblinear solver. Prefer dual=False when
        n_samples > n_features.

    penalty : {'l1', 'l2', 'elasticnet', 'none'}, default='l2'
        Used to specify the norm used in the penalization. The 'newton-cg',
        'sag' and 'lbfgs' solvers support only l2 penalties. 'elasticnet' is
        only supported by the 'saga' solver. 'none' is not supported by the
        'liblinear' solver.

    intercept_scaling : float, default=1.
        Useful only when the solver 'liblinear' is used
        and self.fit_intercept is set to True. In this case, x becomes
        [x, self.intercept_scaling],
        i.e. a "synthetic" feature with constant value equal to
        intercept_scaling is appended to the instance vector.
        The intercept becomes ``intercept_scaling * synthetic_feature_weight``.

        Note! the synthetic feature weight is subject to l1/l2 regularization
        as all other features.
        To lessen the effect of regularization on synthetic feature weight
        (and therefore on the intercept) intercept_scaling has to be increased.

    multi_class : {'ovr', 'multinomial', 'auto'}, default='auto'
        If the option chosen is 'ovr', then a binary problem is fit for each
        label. For 'multinomial' the loss minimised is the multinomial loss fit
        across the entire probability distribution, *even when the data is
        binary*. 'multinomial' is unavailable when solver='liblinear'.
        'auto' selects 'ovr' if the data is binary, or if solver='liblinear',
        and otherwise selects 'multinomial'.

        .. versionadded:: 0.18
           Stochastic Average Gradient descent solver for 'multinomial' case.
        .. versionchanged:: 0.22
            Default changed from 'ovr' to 'auto' in 0.22.

    random_state : int, RandomState instance, default=None
        Used when ``solver`` == 'sag', 'saga' or 'liblinear' to shuffle the
        data. See :term:`Glossary <random_state>` for details.

    check_input : bool, default=True
        If False, the input arrays X and y will not be checked.

    max_squared_sum : float, default=None
        Maximum squared sum of X over samples. Used only in SAG solver.
        If None, it will be computed, going through all the samples.
        The value should be precomputed to speed up cross validation.

    sample_weight : array-like of shape(n_samples,), default=None
        Array of weights that are assigned to individual samples.
        If not provided, then each sample is given unit weight.

    l1_ratio : float, default=None
        The Elastic-Net mixing parameter, with ``0 <= l1_ratio <= 1``. Only
        used if ``penalty='elasticnet'``. Setting ``l1_ratio=0`` is equivalent
        to using ``penalty='l2'``, while setting ``l1_ratio=1`` is equivalent
        to using ``penalty='l1'``. For ``0 < l1_ratio <1``, the penalty is a
        combination of L1 and L2.

    n_threads : int, default=1
       Number of OpenMP threads to use.

    Returns
    -------
    coefs : ndarray of shape (n_cs, n_features) or (n_cs, n_features + 1)
        List of coefficients for the Logistic Regression model. If
        fit_intercept is set to True then the second dimension will be
        n_features + 1, where the last item represents the intercept. For
        ``multi_class='multinomial'``, the shape is (n_cs, n_classes,
        n_features) or (n_cs, n_classes, n_features + 1), where n_classes
        is 1 for binary problems.

    Cs : ndarray
        Grid of Cs used for cross-validation.

    n_iter : array of shape (n_cs,)
        Actual number of iteration for each Cs.

    Raises
    ------
    ValueError
        If the solver/penalty/dual/multi_class combination is invalid, if
        ``pos_class`` is None for a one-vs-rest fit on more than two
        classes, or if ``coef`` has an incompatible shape.

    Notes
    -----
    You might get slightly different results with the solver liblinear than
    with the others since this uses LIBLINEAR which penalizes the intercept.

    .. versionchanged:: 0.19
        The "copy" parameter was removed.
    """
    if isinstance(Cs, numbers.Integral):
        Cs = np.logspace(-4, 4, Cs)

    solver = _check_solver(solver, penalty, dual)

    # Preprocessing.
    if check_input:
        X = check_array(
            X,
            accept_sparse="csr",
            dtype=np.float64,
            accept_large_sparse=solver not in ["liblinear", "sag", "saga"],
        )
        y = check_array(y, ensure_2d=False, dtype=None)
        check_consistent_length(X, y)
    _, n_features = X.shape

    classes = np.unique(y)
    random_state = check_random_state(random_state)

    multi_class = _check_multi_class(multi_class, solver, len(classes))
    if pos_class is None and multi_class != "multinomial":
        if classes.size > 2:
            raise ValueError("To fit OvR, use the pos_class argument")
        # np.unique(y) gives labels in sorted order.
        pos_class = classes[1]

    # If sample weights exist, convert them to array (support for lists)
    # and check length. Otherwise set them to 1 for all examples.
    # copy=True because the weights are rescaled in place below.
    sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype, copy=True)

    # If class_weights is a dict (provided by the user), the weights
    # are assigned to the original labels. If it is "balanced", then
    # the class_weights are assigned after masking the labels with a OvR.
    le = LabelEncoder()
    if isinstance(class_weight, dict) or multi_class == "multinomial":
        class_weight_ = compute_class_weight(class_weight, classes=classes, y=y)
        sample_weight *= class_weight_[le.fit_transform(y)]

    # For doing a ovr, we need to mask the labels first. For the
    # multinomial case this is not necessary.
    if multi_class == "ovr":
        w0 = np.zeros(n_features + int(fit_intercept), dtype=X.dtype)
        mask = y == pos_class
        y_bin = np.ones(y.shape, dtype=X.dtype)
        if solver in ["lbfgs", "newton-cg"]:
            # HalfBinomialLoss, used for those solvers, represents y in [0, 1]
            # instead of in [-1, 1].
            mask_classes = np.array([0, 1])
            y_bin[~mask] = 0.0
        else:
            mask_classes = np.array([-1, 1])
            y_bin[~mask] = -1.0

        # "balanced" weights are computed on the binarized OvR labels.
        if class_weight == "balanced":
            class_weight_ = compute_class_weight(
                class_weight, classes=mask_classes, y=y_bin
            )
            sample_weight *= class_weight_[le.fit_transform(y_bin)]
    else:
        if solver in ["sag", "saga", "lbfgs", "newton-cg"]:
            # SAG, lbfgs and newton-cg multinomial solvers need LabelEncoder,
            # not LabelBinarizer, i.e. y as a 1d-array of integers.
            # LabelEncoder also saves memory compared to LabelBinarizer,
            # especially when n_classes is large.
            le = LabelEncoder()
            Y_multi = le.fit_transform(y).astype(X.dtype, copy=False)
        else:
            # For liblinear solver, apply LabelBinarizer, i.e. y is one-hot
            # encoded.
            lbin = LabelBinarizer()
            Y_multi = lbin.fit_transform(y)
            if Y_multi.shape[1] == 1:
                Y_multi = np.hstack([1 - Y_multi, Y_multi])

        w0 = np.zeros(
            (classes.size, n_features + int(fit_intercept)), order="F", dtype=X.dtype
        )

    if coef is not None:
        # it must work both giving the bias term and not
        if multi_class == "ovr":
            if coef.size not in (n_features, w0.size):
                raise ValueError(
                    "Initialization coef is of shape %d, expected shape %d or %d"
                    % (coef.size, n_features, w0.size)
                )
            w0[: coef.size] = coef
        else:
            # For binary problems coef.shape[0] should be 1, otherwise it
            # should be classes.size.
            n_classes = classes.size
            if n_classes == 2:
                n_classes = 1

            if coef.shape[0] != n_classes or coef.shape[1] not in (
                n_features,
                n_features + 1,
            ):
                # Report the row count actually expected (1 for binary).
                raise ValueError(
                    "Initialization coef is of shape (%d, %d), expected "
                    "shape (%d, %d) or (%d, %d)"
                    % (
                        coef.shape[0],
                        coef.shape[1],
                        n_classes,
                        n_features,
                        n_classes,
                        n_features + 1,
                    )
                )

            if n_classes == 1:
                # A single binary coefficient row is expanded into the two
                # multinomial rows as (-coef, +coef).
                w0[0, : coef.shape[1]] = -coef
                w0[1, : coef.shape[1]] = coef
            else:
                w0[:, : coef.shape[1]] = coef

    if multi_class == "multinomial":
        if solver in ["lbfgs", "newton-cg"]:
            # scipy.optimize.minimize and newton-cg accept only ravelled
            # parameters, i.e. 1d-arrays. LinearModelLoss expects classes to be
            # contiguous and reconstructs the 2d-array via
            # w0.reshape((n_classes, -1), order="F").
            # As w0 is F-contiguous, ravel(order="F") also avoids a copy.
            w0 = w0.ravel(order="F")
            loss = LinearModelLoss(
                base_loss=HalfMultinomialLoss(n_classes=classes.size),
                fit_intercept=fit_intercept,
            )
        target = Y_multi
        # Equality, not substring membership (the original used
        # ``solver in "lbfgs"``).
        if solver == "lbfgs":
            func = loss.loss_gradient
        elif solver == "newton-cg":
            func = loss.loss
            grad = loss.gradient
            hess = loss.gradient_hessian_product  # hess = [gradient, hessp]
        warm_start_sag = {"coef": w0.T}
    else:
        target = y_bin
        if solver == "lbfgs":
            loss = LinearModelLoss(
                base_loss=HalfBinomialLoss(), fit_intercept=fit_intercept
            )
            func = loss.loss_gradient
        elif solver == "newton-cg":
            loss = LinearModelLoss(
                base_loss=HalfBinomialLoss(), fit_intercept=fit_intercept
            )
            func = loss.loss
            grad = loss.gradient
            hess = loss.gradient_hessian_product  # hess = [gradient, hessp]
        warm_start_sag = {"coef": np.expand_dims(w0, axis=1)}

    coefs = []
    n_iter = np.zeros(len(Cs), dtype=np.int32)
    for i, C in enumerate(Cs):
        # Each iteration starts from the previous solution held in w0
        # (and warm_start_sag for sag/saga), i.e. a warm start along the path.
        if solver == "lbfgs":
            l2_reg_strength = 1.0 / C
            iprint = [-1, 50, 1, 100, 101][
                np.searchsorted(np.array([0, 1, 2, 3]), verbose)
            ]
            opt_res = optimize.minimize(
                func,
                w0,
                method="L-BFGS-B",
                jac=True,
                args=(X, target, sample_weight, l2_reg_strength, n_threads),
                options={"iprint": iprint, "gtol": tol, "maxiter": max_iter},
            )
            n_iter_i = _check_optimize_result(
                solver,
                opt_res,
                max_iter,
                extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG,
            )
            # Keep only the solution; the final objective value is not used,
            # and must not rebind ``loss`` (the LinearModelLoss object).
            w0 = opt_res.x
        elif solver == "newton-cg":
            l2_reg_strength = 1.0 / C
            args = (X, target, sample_weight, l2_reg_strength, n_threads)
            w0, n_iter_i = _newton_cg(
                hess, func, grad, w0, args=args, maxiter=max_iter, tol=tol
            )
        elif solver == "liblinear":
            coef_, intercept_, n_iter_i = _fit_liblinear(
                X,
                target,
                C,
                fit_intercept,
                intercept_scaling,
                # NOTE(review): presumably the class_weight slot; class weights
                # were already folded into sample_weight above — confirm
                # against _fit_liblinear's signature.
                None,
                penalty,
                dual,
                verbose,
                max_iter,
                tol,
                random_state,
                sample_weight=sample_weight,
            )
            if fit_intercept:
                w0 = np.concatenate([coef_.ravel(), intercept_])
            else:
                w0 = coef_.ravel()
        elif solver in ["sag", "saga"]:
            if multi_class == "multinomial":
                target = target.astype(X.dtype, copy=False)
                sag_loss = "multinomial"
            else:
                sag_loss = "log"
            # alpha is for L2-norm, beta is for L1-norm
            if penalty == "l1":
                alpha = 0.0
                beta = 1.0 / C
            elif penalty == "l2":
                alpha = 1.0 / C
                beta = 0.0
            else:  # Elastic-Net penalty
                alpha = (1.0 / C) * (1 - l1_ratio)
                beta = (1.0 / C) * l1_ratio

            w0, n_iter_i, warm_start_sag = sag_solver(
                X,
                target,
                sample_weight,
                sag_loss,
                alpha,
                beta,
                max_iter,
                tol,
                verbose,
                random_state,
                False,
                max_squared_sum,
                warm_start_sag,
                is_saga=(solver == "saga"),
            )
        else:
            raise ValueError(
                "solver must be one of {'liblinear', 'lbfgs', "
                "'newton-cg', 'sag', 'saga'}, got '%s' instead" % solver
            )

        if multi_class == "multinomial":
            n_classes = max(2, classes.size)
            if solver in ["lbfgs", "newton-cg"]:
                multi_w0 = np.reshape(w0, (n_classes, -1), order="F")
            else:
                multi_w0 = w0
            if n_classes == 2:
                # Binary multinomial: keep only the positive-class row.
                multi_w0 = multi_w0[1][np.newaxis, :]
            coefs.append(multi_w0.copy())
        else:
            coefs.append(w0.copy())

        n_iter[i] = n_iter_i

    return np.array(coefs), np.array(Cs), n_iter
# helper function for LogisticRegressionCV
def _log_reg_scoring_path(
    X,
    y,
    train,
    test,
    pos_class=None,
    Cs=10,
    scoring=None,
    fit_intercept=False,
    max_iter=100,
    tol=1e-4,
    class_weight=None,
    verbose=0,
    solver="lbfgs",
    penalty="l2",
    dual=False,
    intercept_scaling=1.0,
    multi_class="auto",
    random_state=None,
    max_squared_sum=None,
    sample_weight=None,
    l1_ratio=None,
):
    """Computes scores across logistic_regression_path

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Training data.

    y : array-like of shape (n_samples,) or (n_samples, n_targets)
        Target labels.

    train : list of indices
        The indices of the train set.

    test : list of indices
        The indices of the test set.

    pos_class : int, default=None
        The class with respect to which we perform a one-vs-all fit.
        If None, then it is assumed that the given problem is binary.

    Cs : int or list of floats, default=10
        Each of the values in Cs describes the inverse of
        regularization strength. If Cs is as an int, then a grid of Cs
        values are chosen in a logarithmic scale between 1e-4 and 1e4.
        If not provided, then a fixed set of values for Cs are used.

    scoring : str, callable or None, default=None
        A string (see model evaluation documentation) or
        a scorer callable object / function with signature
        ``scorer(estimator, X, y)``. For a list of scoring functions
        that can be used, look at :mod:`sklearn.metrics`. When None,
        ``LogisticRegression.score`` (accuracy) is used.

    fit_intercept : bool, default=False
        If False, then the bias term is set to zero. Else the last
        term of each coef_ gives us the intercept.

    max_iter : int, default=100
        Maximum number of iterations for the solver.

    tol : float, default=1e-4
        Tolerance for stopping criteria.

    class_weight : dict or 'balanced', default=None
        Weights associated with classes in the form ``{class_label: weight}``.
        If not given, all classes are supposed to have weight one.

        The "balanced" mode uses the values of y to automatically adjust
        weights inversely proportional to class frequencies in the input data
        as ``n_samples / (n_classes * np.bincount(y))``

        Note that these weights will be multiplied with sample_weight (passed
        through the fit method) if sample_weight is specified.

    verbose : int, default=0
        For the liblinear and lbfgs solvers set verbose to any positive
        number for verbosity.

    solver : {'lbfgs', 'newton-cg', 'liblinear', 'sag', 'saga'}, \
            default='lbfgs'
        Decides which solver to use.

    penalty : {'l1', 'l2', 'elasticnet'}, default='l2'
        Used to specify the norm used in the penalization. The 'newton-cg',
        'sag' and 'lbfgs' solvers support only l2 penalties. 'elasticnet' is
        only supported by the 'saga' solver.

    dual : bool, default=False
        Dual or primal formulation. Dual formulation is only implemented for
        l2 penalty with liblinear solver. Prefer dual=False when
        n_samples > n_features.

    intercept_scaling : float, default=1.
        Useful only when the solver 'liblinear' is used
        and self.fit_intercept is set to True. In this case, x becomes
        [x, self.intercept_scaling],
        i.e. a "synthetic" feature with constant value equal to
        intercept_scaling is appended to the instance vector.
        The intercept becomes intercept_scaling * synthetic feature weight
        Note! the synthetic feature weight is subject to l1/l2 regularization
        as all other features.
        To lessen the effect of regularization on synthetic feature weight
        (and therefore on the intercept) intercept_scaling has to be increased.

    multi_class : {'auto', 'ovr', 'multinomial'}, default='auto'
        If the option chosen is 'ovr', then a binary problem is fit for each
        label. For 'multinomial' the loss minimised is the multinomial loss fit
        across the entire probability distribution, *even when the data is
        binary*. 'multinomial' is unavailable when solver='liblinear'.
        'auto' is not resolved by this function: callers must pass an
        already-resolved 'ovr' or 'multinomial', otherwise a ValueError is
        raised.

    random_state : int, RandomState instance, default=None
        Used when ``solver`` == 'sag', 'saga' or 'liblinear' to shuffle the
        data. See :term:`Glossary <random_state>` for details.

    max_squared_sum : float, default=None
        Maximum squared sum of X over samples. Used only in SAG solver.
        If None, it will be computed, going through all the samples.
        The value should be precomputed to speed up cross validation.

    sample_weight : array-like of shape(n_samples,), default=None
        Array of weights that are assigned to individual samples.
        If not provided, then each sample is given unit weight.

    l1_ratio : float, default=None
        The Elastic-Net mixing parameter, with ``0 <= l1_ratio <= 1``. Only
        used if ``penalty='elasticnet'``. Setting ``l1_ratio=0`` is equivalent
        to using ``penalty='l2'``, while setting ``l1_ratio=1`` is equivalent
        to using ``penalty='l1'``. For ``0 < l1_ratio <1``, the penalty is a
        combination of L1 and L2.

    Returns
    -------
    coefs : ndarray of shape (n_cs, n_features) or (n_cs, n_features + 1)
        List of coefficients for the Logistic Regression model. If
        fit_intercept is set to True then the second dimension will be
        n_features + 1, where the last item represents the intercept.

    Cs : ndarray
        Grid of Cs used for cross-validation.

    scores : ndarray of shape (n_cs,)
        Scores obtained for each Cs.

    n_iter : ndarray of shape(n_cs,)
        Actual number of iteration for each Cs.

    Raises
    ------
    ValueError
        If ``multi_class`` is not 'ovr' or 'multinomial'.
    """
    # The scorer setup below must know whether it scores OvR (+/-1) labels or
    # the original labels, so 'auto' (or anything else) is rejected here,
    # before any model is fit. The original formatted this message with %d,
    # which raised TypeError for string values instead of this ValueError.
    if multi_class not in ("ovr", "multinomial"):
        raise ValueError(
            "multi_class should be either multinomial or ovr, got %s" % multi_class
        )

    X_train = X[train]
    X_test = X[test]
    y_train = y[train]
    y_test = y[test]

    if sample_weight is not None:
        sample_weight = _check_sample_weight(sample_weight, X)
        sample_weight = sample_weight[train]

    coefs, Cs, n_iter = _logistic_regression_path(
        X_train,
        y_train,
        Cs=Cs,
        l1_ratio=l1_ratio,
        fit_intercept=fit_intercept,
        solver=solver,
        max_iter=max_iter,
        class_weight=class_weight,
        pos_class=pos_class,
        multi_class=multi_class,
        tol=tol,
        verbose=verbose,
        dual=dual,
        penalty=penalty,
        intercept_scaling=intercept_scaling,
        random_state=random_state,
        check_input=False,
        max_squared_sum=max_squared_sum,
        sample_weight=sample_weight,
    )

    log_reg = LogisticRegression(solver=solver, multi_class=multi_class)

    # The score method of Logistic Regression has a classes_ attribute.
    if multi_class == "ovr":
        log_reg.classes_ = np.array([-1, 1])
    else:  # "multinomial", guaranteed by the check above
        log_reg.classes_ = np.unique(y_train)

    if pos_class is not None:
        # Binarize the test labels to match the OvR classes_ = [-1, 1].
        mask = y_test == pos_class
        y_test = np.ones(y_test.shape, dtype=np.float64)
        y_test[~mask] = -1.0

    scores = []

    scoring = get_scorer(scoring)
    for w in coefs:
        if multi_class == "ovr":
            w = w[np.newaxis, :]
        if fit_intercept:
            log_reg.coef_ = w[:, :-1]
            log_reg.intercept_ = w[:, -1]
        else:
            log_reg.coef_ = w
            log_reg.intercept_ = 0.0

        if scoring is None:
            scores.append(log_reg.score(X_test, y_test))
        else:
            scores.append(scoring(log_reg, X_test, y_test))

    return coefs, Cs, np.array(scores), n_iter
class LogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstimator):
"""
Logistic Regression (aka logit, MaxEnt) classifier.
In the multiclass case, the training algorithm uses the one-vs-rest (OvR)
scheme if the 'multi_class' option is set to 'ovr', and uses the
cross-entropy loss if the 'multi_class' option is set to 'multinomial'.
(Currently the 'multinomial' option is supported only by the 'lbfgs',
'sag', 'saga' and 'newton-cg' solvers.)
This class implements regularized logistic regression using the
'liblinear' library, 'newton-cg', 'sag', 'saga' and 'lbfgs' solvers. **Note
that regularization is applied by default**. It can handle both dense
and sparse input. Use C-ordered arrays or CSR matrices containing 64-bit
floats for optimal performance; any other input format will be converted
(and copied).
The 'newton-cg', 'sag', and 'lbfgs' solvers support only L2 regularization
with primal formulation, or no regularization. The 'liblinear' solver
supports both L1 and L2 regularization, with a dual formulation only for
the L2 penalty. The Elastic-Net regularization is only supported by the
'saga' solver.
Read more in the :ref:`User Guide <logistic_regression>`.
Parameters
----------
penalty : {'l1', 'l2', 'elasticnet', 'none'}, default='l2'
Specify the norm of the penalty:
- `'none'`: no penalty is added;
- `'l2'`: add a L2 penalty term and it is the default choice;
- `'l1'`: add a L1 penalty term;
- `'elasticnet'`: both L1 and L2 penalty terms are added.
.. warning::
Some penalties may not work with some solvers. See the parameter
`solver` below, to know the compatibility between the penalty and
solver.
.. versionadded:: 0.19
l1 penalty with SAGA solver (allowing 'multinomial' + L1)
dual : bool, default=False
Dual or primal formulation. Dual formulation is only implemented for
l2 penalty with liblinear solver. Prefer dual=False when
n_samples > n_features.
tol : float, default=1e-4
Tolerance for stopping criteria.
C : float, default=1.0
Inverse of regularization strength; must be a positive float.
Like in support vector machines, smaller values specify stronger
regularization.
fit_intercept : bool, default=True
Specifies if a constant (a.k.a. bias or intercept) should be
added to the decision function.
intercept_scaling : float, default=1
Useful only when the solver 'liblinear' is used
and self.fit_intercept is set to True. In this case, x becomes
[x, self.intercept_scaling],
i.e. a "synthetic" feature with constant value equal to
intercept_scaling is appended to the instance vector.
The intercept becomes ``intercept_scaling * synthetic_feature_weight``.
Note! the synthetic feature weight is subject to l1/l2 regularization
as all other features.
To lessen the effect of regularization on synthetic feature weight
(and therefore on the intercept) intercept_scaling has to be increased.
class_weight : dict or 'balanced', default=None
Weights associated with classes in the form ``{class_label: weight}``.
If not given, all classes are supposed to have weight one.
The "balanced" mode uses the values of y to automatically adjust
weights inversely proportional to class frequencies in the input data
as ``n_samples / (n_classes * np.bincount(y))``.
Note that these weights will be multiplied with sample_weight (passed
through the fit method) if sample_weight is specified.
.. versionadded:: 0.17
*class_weight='balanced'*
random_state : int, RandomState instance, default=None
Used when ``solver`` == 'sag', 'saga' or 'liblinear' to shuffle the
data. See :term:`Glossary <random_state>` for details.
solver : {'newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga'}, \
default='lbfgs'
Algorithm to use in the optimization problem. Default is 'lbfgs'.
To choose a solver, you might want to consider the following aspects:
- For small datasets, 'liblinear' is a good choice, whereas 'sag'
and 'saga' are faster for large ones;
- For multiclass problems, only 'newton-cg', 'sag', 'saga' and
'lbfgs' handle multinomial loss;
- 'liblinear' is limited to one-versus-rest schemes.
.. warning::
The choice of the algorithm depends on the penalty chosen:
Supported penalties by solver:
- 'newton-cg' - ['l2', 'none']
- 'lbfgs' - ['l2', 'none']
- 'liblinear' - ['l1', 'l2']
- 'sag' - ['l2', 'none']
- 'saga' - ['elasticnet', 'l1', 'l2', 'none']
.. note::
'sag' and 'saga' fast convergence is only guaranteed on
features with approximately the same scale. You can
preprocess the data with a scaler from :mod:`sklearn.preprocessing`.
.. seealso::
Refer to the User Guide for more information regarding
:class:`LogisticRegression` and more specifically the
`Table <https://scikit-learn.org/dev/modules/linear_model.html#logistic-regression>`_
summarazing solver/penalty supports.
.. versionadded:: 0.17
Stochastic Average Gradient descent solver.
.. versionadded:: 0.19
SAGA solver.
.. versionchanged:: 0.22
The default solver changed from 'liblinear' to 'lbfgs' in 0.22.
max_iter : int, default=100
Maximum number of iterations taken for the solvers to converge.
multi_class : {'auto', 'ovr', 'multinomial'}, default='auto'
If the option chosen is 'ovr', then a binary problem is fit for each
label. For 'multinomial' the loss minimised is the multinomial loss fit
across the entire probability distribution, *even when the data is
binary*. 'multinomial' is unavailable when solver='liblinear'.
'auto' selects 'ovr' if the data is binary, or if solver='liblinear',
and otherwise selects 'multinomial'.
.. versionadded:: 0.18
Stochastic Average Gradient descent solver for 'multinomial' case.
.. versionchanged:: 0.22
Default changed from 'ovr' to 'auto' in 0.22.
verbose : int, default=0
For the liblinear and lbfgs solvers set verbose to any positive
number for verbosity.
warm_start : bool, default=False
When set to True, reuse the solution of the previous call to fit as
initialization, otherwise, just erase the previous solution.
Useless for liblinear solver. See :term:`the Glossary <warm_start>`.
.. versionadded:: 0.17
*warm_start* to support *lbfgs*, *newton-cg*, *sag*, *saga* solvers.
n_jobs : int, default=None
Number of CPU cores used when parallelizing over classes if
multi_class='ovr'". This parameter is ignored when the ``solver`` is
set to 'liblinear' regardless of whether 'multi_class' is specified or
not. ``None`` means 1 unless in a :obj:`joblib.parallel_backend`
context. ``-1`` means using all processors.
See :term:`Glossary <n_jobs>` for more details.
l1_ratio : float, default=None
The Elastic-Net mixing parameter, with ``0 <= l1_ratio <= 1``. Only
used if ``penalty='elasticnet'``. Setting ``l1_ratio=0`` is equivalent
to using ``penalty='l2'``, while setting ``l1_ratio=1`` is equivalent
to using ``penalty='l1'``. For ``0 < l1_ratio <1``, the penalty is a
combination of L1 and L2.
Attributes
----------
classes_ : ndarray of shape (n_classes, )
A list of class labels known to the classifier.
coef_ : ndarray of shape (1, n_features) or (n_classes, n_features)
Coefficient of the features in the decision function.
`coef_` is of shape (1, n_features) when the given problem is binary.
In particular, when `multi_class='multinomial'`, `coef_` corresponds
to outcome 1 (True) and `-coef_` corresponds to outcome 0 (False).
intercept_ : ndarray of shape (1,) or (n_classes,)
Intercept (a.k.a. bias) added to the decision function.
If `fit_intercept` is set to False, the intercept is set to zero.
`intercept_` is of shape (1,) when the given problem is binary.
In particular, when `multi_class='multinomial'`, `intercept_`
corresponds to outcome 1 (True) and `-intercept_` corresponds to
outcome 0 (False).
n_features_in_ : int
Number of features seen during :term:`fit`.
.. versionadded:: 0.24
feature_names_in_ : ndarray of shape (`n_features_in_`,)
Names of features seen during :term:`fit`. Defined only when `X`
has feature names that are all strings.
.. versionadded:: 1.0
n_iter_ : ndarray of shape (n_classes,) or (1, )
Actual number of iterations for all classes. If binary or multinomial,
it returns only 1 element. For liblinear solver, only the maximum
number of iteration across all classes is given.
.. versionchanged:: 0.20
In SciPy <= 1.0.0 the number of lbfgs iterations may exceed
``max_iter``. ``n_iter_`` will now report at most ``max_iter``.
See Also
--------
SGDClassifier : Incrementally trained logistic regression (when given
the parameter ``loss="log"``).
LogisticRegressionCV : Logistic regression with built-in cross validation.
Notes
-----
The underlying C implementation uses a random number generator to
select features when fitting the model. It is thus not uncommon,
to have slightly different results for the same input data. If
that happens, try with a smaller tol parameter.
Predict output may not match that of standalone liblinear in certain
cases. See :ref:`differences from liblinear <liblinear_differences>`
in the narrative documentation.
References
----------
L-BFGS-B -- Software for Large-scale Bound-constrained Optimization
Ciyou Zhu, Richard Byrd, Jorge Nocedal and Jose Luis Morales.
http://users.iems.northwestern.edu/~nocedal/lbfgsb.html
LIBLINEAR -- A Library for Large Linear Classification
https://www.csie.ntu.edu.tw/~cjlin/liblinear/
SAG -- Mark Schmidt, Nicolas Le Roux, and Francis Bach
Minimizing Finite Sums with the Stochastic Average Gradient
https://hal.inria.fr/hal-00860051/document