
Commit 75a94f5

lorentzenchr, ogrisel, and thomasjpfan authored
ENH migrate GLMs / TweedieRegressor to linear loss (scikit-learn#22548)
Co-authored-by: Olivier Grisel <[email protected]>
Co-authored-by: Thomas J. Fan <[email protected]>
1 parent d14fd82 commit 75a94f5

19 files changed: +694 −545 lines

doc/modules/linear_model.rst

+1 −1

@@ -1032,7 +1032,7 @@ reproductive exponential dispersion model (EDM) [11]_).
 
 The minimization problem becomes:
 
-.. math:: \min_{w} \frac{1}{2 n_{\text{samples}}} \sum_i d(y_i, \hat{y}_i) + \frac{\alpha}{2} ||w||_2,
+.. math:: \min_{w} \frac{1}{2 n_{\text{samples}}} \sum_i d(y_i, \hat{y}_i) + \frac{\alpha}{2} ||w||_2^2,
 
 where :math:`\alpha` is the L2 regularization penalty. When sample weights are
 provided, the average becomes a weighted average.
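
To make the corrected objective concrete, here is a minimal NumPy sketch (illustration only, not scikit-learn internals; the function name and the Poisson/log-link choice are assumptions) that evaluates this penalized objective:

import numpy as np

def penalized_half_deviance(w, X, y, alpha):
    # (1 / (2 * n_samples)) * sum_i d(y_i, y_hat_i) + (alpha / 2) * ||w||_2^2
    y_hat = np.exp(X @ w)  # log link, as used by PoissonRegressor
    # Poisson unit deviance d(y, mu) = 2 * (y * log(y / mu) - y + mu),
    # with the convention y * log(y / mu) = 0 for y = 0.
    with np.errstate(divide="ignore", invalid="ignore"):
        ylogy = np.where(y > 0, y * np.log(y / y_hat), 0.0)
    d = 2.0 * (ylogy - y + y_hat)
    # Note the *squared* L2 norm -- the correction made in this hunk.
    return d.sum() / (2.0 * len(y)) + 0.5 * alpha * (w @ w)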

doc/whats_new/v1.1.rst

+12
@@ -600,6 +600,12 @@ Changelog
   :pr:`21808`, :pr:`20567` and :pr:`21814` by
   :user:`Christian Lorentzen <lorentzenchr>`.
 
+- |Enhancement| :class:`~linear_model.GammaRegressor`,
+  :class:`~linear_model.PoissonRegressor` and :class:`~linear_model.TweedieRegressor`
+  are faster for ``solver="lbfgs"``.
+  :pr:`22548`, :pr:`21808` and :pr:`20567` by
+  :user:`Christian Lorentzen <lorentzenchr>`.
+
 - |Enhancement| Rename parameter `base_estimator` to `estimator` in
   :class:`linear_model.RANSACRegressor` to improve readability and consistency.
   `base_estimator` is deprecated and will be removed in 1.3.

@@ -633,6 +639,12 @@ Changelog
   sub-problem while now all of them are recorded. :pr:`21998` by
   :user:`Olivier Grisel <ogrisel>`.
 
+- |Fix| The property `family` of :class:`linear_model.TweedieRegressor` is no
+  longer validated in `__init__`. Instead, this (private) property is deprecated in
+  :class:`linear_model.GammaRegressor`, :class:`linear_model.PoissonRegressor` and
+  :class:`linear_model.TweedieRegressor`, and will be removed in 1.3.
+  :pr:`22548` by :user:`Christian Lorentzen <lorentzenchr>`.
+
 - |Enhancement| :class:`linear_model.BayesianRidge` and
   :class:`linear_model.ARDRegression` now preserve float32 dtype. :pr:`9087` by
   :user:`Arthur Imbert <Henley13>` and :pr:`22525` by :user:`Meekail Zain <micky774>`.
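
As a hedged illustration of the `family` deprecation above (behavior inferred from the changelog entry; the exact warning class and message are assumptions):

import warnings
import numpy as np
from sklearn.linear_model import PoissonRegressor

X = np.array([[1.0], [2.0], [3.0], [4.0]])
y = np.array([1.0, 2.0, 3.0, 4.0])
reg = PoissonRegressor().fit(X, y)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    reg.family  # deprecated in 1.1, scheduled for removal in 1.3
print([str(w.message) for w in caught])  # expect a deprecation warning here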

sklearn/_loss/__init__.py

+2
@@ -10,6 +10,7 @@
     HalfPoissonLoss,
     HalfGammaLoss,
     HalfTweedieLoss,
+    HalfTweedieLossIdentity,
     HalfBinomialLoss,
     HalfMultinomialLoss,
 )

@@ -22,6 +23,7 @@
     "HalfPoissonLoss",
     "HalfGammaLoss",
     "HalfTweedieLoss",
+    "HalfTweedieLossIdentity",
     "HalfBinomialLoss",
     "HalfMultinomialLoss",
 ]

sklearn/_loss/_loss.pxd

+7
@@ -69,6 +69,13 @@ cdef class CyHalfTweedieLoss(CyLossFunction):
     cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil
 
 
+cdef class CyHalfTweedieLossIdentity(CyLossFunction):
+    cdef readonly double power  # readonly makes it accessible from Python
+    cdef double cy_loss(self, double y_true, double raw_prediction) nogil
+    cdef double cy_gradient(self, double y_true, double raw_prediction) nogil
+    cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil
+
+
 cdef class CyHalfBinomialLoss(CyLossFunction):
     cdef double cy_loss(self, double y_true, double raw_prediction) nogil
     cdef double cy_gradient(self, double y_true, double raw_prediction) nogil

sklearn/_loss/_loss.pyx.tp

+124 −3

@@ -1,7 +1,7 @@
 {{py:
 
 """
-Template file for easily generate loops over samples using Tempita
+Template file to easily generate loops over samples using Tempita
 (https://github.com/cython/cython/blob/master/Cython/Tempita/_tempita.py).
 
 Generated file: _loss.pyx

@@ -117,6 +117,28 @@ doc_HalfTweedieLoss = (
     """
 )
 
+doc_HalfTweedieLossIdentity = (
+    """Half Tweedie deviance loss with identity link.
+
+    Domain:
+    y_true in real numbers if p <= 0
+    y_true in non-negative real numbers if 0 < p < 2
+    y_true in positive real numbers if p >= 2
+    y_pred and power in positive real numbers, y_pred may be negative for p=0.
+
+    Link:
+    y_pred = raw_prediction
+
+    Half Tweedie deviance with identity link and p=power is
+        max(y_true, 0)**(2-p) / (1-p) / (2-p)
+        - y_true * y_pred**(1-p) / (1-p)
+        + y_pred**(2-p) / (2-p)
+
+    Notes:
+    - Here, we do not drop constant terms in contrast to the version with log-link.
+    """
+)
+
 doc_HalfBinomialLoss = (
     """Half Binomial deviance loss with logit link.

@@ -151,6 +173,9 @@ class_list = [
     ("CyHalfTweedieLoss", doc_HalfTweedieLoss, "power",
      "closs_half_tweedie", "closs_grad_half_tweedie",
      "cgradient_half_tweedie", "cgrad_hess_half_tweedie"),
+    ("CyHalfTweedieLossIdentity", doc_HalfTweedieLossIdentity, "power",
+     "closs_half_tweedie_identity", "closs_grad_half_tweedie_identity",
+     "cgradient_half_tweedie_identity", "cgrad_hess_half_tweedie_identity"),
     ("CyHalfBinomialLoss", doc_HalfBinomialLoss, None,
      "closs_half_binomial", "closs_grad_half_binomial",
      "cgradient_half_binomial", "cgrad_hess_half_binomial"),

@@ -194,7 +219,7 @@ from cython.parallel import parallel, prange
 import numpy as np
 cimport numpy as np
 
-from libc.math cimport exp, fabs, log, log1p
+from libc.math cimport exp, fabs, log, log1p, pow
 from libc.stdlib cimport malloc, free
 
 np.import_array()

@@ -420,7 +445,7 @@ cdef inline double_pair cgrad_hess_half_gamma(
 
 
 # Half Tweedie Deviance with Log-Link, dropping constant terms
-# Note that by dropping constants this is no longer smooth in parameter power.
+# Note that by dropping constants this is no longer continuous in parameter power.
 cdef inline double closs_half_tweedie(
     double y_true,
     double raw_prediction,

@@ -501,6 +526,102 @@ cdef inline double_pair cgrad_hess_half_tweedie(
     return gh
 
 
+# Half Tweedie Deviance with identity link, without dropping constant terms!
+# Therefore, best loss value is zero.
+cdef inline double closs_half_tweedie_identity(
+    double y_true,
+    double raw_prediction,
+    double power
+) nogil:
+    cdef double tmp
+    if power == 0.:
+        return closs_half_squared_error(y_true, raw_prediction)
+    elif power == 1.:
+        if y_true == 0:
+            return raw_prediction
+        else:
+            return y_true * log(y_true/raw_prediction) + raw_prediction - y_true
+    elif power == 2.:
+        return log(raw_prediction/y_true) + y_true/raw_prediction - 1.
+    else:
+        tmp = pow(raw_prediction, 1. - power)
+        tmp = raw_prediction * tmp / (2. - power) - y_true * tmp / (1. - power)
+        if y_true > 0:
+            tmp += pow(y_true, 2. - power) / ((1. - power) * (2. - power))
+        return tmp
+
+
+cdef inline double cgradient_half_tweedie_identity(
+    double y_true,
+    double raw_prediction,
+    double power
+) nogil:
+    if power == 0.:
+        return raw_prediction - y_true
+    elif power == 1.:
+        return 1. - y_true / raw_prediction
+    elif power == 2.:
+        return (raw_prediction - y_true) / (raw_prediction * raw_prediction)
+    else:
+        return pow(raw_prediction, -power) * (raw_prediction - y_true)
+
+
+cdef inline double_pair closs_grad_half_tweedie_identity(
+    double y_true,
+    double raw_prediction,
+    double power
+) nogil:
+    cdef double_pair lg
+    cdef double tmp
+    if power == 0.:
+        lg.val2 = raw_prediction - y_true  # gradient
+        lg.val1 = 0.5 * lg.val2 * lg.val2  # loss
+    elif power == 1.:
+        if y_true == 0:
+            lg.val1 = raw_prediction
+        else:
+            lg.val1 = (y_true * log(y_true/raw_prediction)  # loss
+                       + raw_prediction - y_true)
+        lg.val2 = 1. - y_true / raw_prediction  # gradient
+    elif power == 2.:
+        lg.val1 = log(raw_prediction/y_true) + y_true/raw_prediction - 1.  # loss
+        tmp = raw_prediction * raw_prediction
+        lg.val2 = (raw_prediction - y_true) / tmp  # gradient
+    else:
+        tmp = pow(raw_prediction, 1. - power)
+        lg.val1 = (raw_prediction * tmp / (2. - power)  # loss
+                   - y_true * tmp / (1. - power))
+        if y_true > 0:
+            lg.val1 += (pow(y_true, 2. - power)
+                        / ((1. - power) * (2. - power)))
+        lg.val2 = tmp * (1. - y_true / raw_prediction)  # gradient
+    return lg
+
+
+cdef inline double_pair cgrad_hess_half_tweedie_identity(
+    double y_true,
+    double raw_prediction,
+    double power
+) nogil:
+    cdef double_pair gh
+    cdef double tmp
+    if power == 0.:
+        gh.val1 = raw_prediction - y_true  # gradient
+        gh.val2 = 1.  # hessian
+    elif power == 1.:
+        gh.val1 = 1. - y_true / raw_prediction  # gradient
+        gh.val2 = y_true / (raw_prediction * raw_prediction)  # hessian
+    elif power == 2.:
+        tmp = raw_prediction * raw_prediction
+        gh.val1 = (raw_prediction - y_true) / tmp  # gradient
+        gh.val2 = (-1. + 2. * y_true / raw_prediction) / tmp  # hessian
+    else:
+        tmp = pow(raw_prediction, -power)
+        gh.val1 = tmp * (raw_prediction - y_true)  # gradient
+        gh.val2 = tmp * ((1. - power) + power * y_true / raw_prediction)  # hessian
+    return gh
+
+
 # Half Binomial deviance with logit-link, aka log-loss or binary cross entropy
 cdef inline double closs_half_binomial(
     double y_true,
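
As a sanity check on the closed forms above, the following plain-Python sketch (reference code written for this note, not the generated Cython) compares the analytic gradient of the general-power branch against a central finite difference:

def loss_identity(y, mu, p):
    # General-power branch of closs_half_tweedie_identity (p not in {0, 1, 2});
    # the constant y-term is kept so that the minimum of the loss is 0.
    out = mu ** (2.0 - p) / (2.0 - p) - y * mu ** (1.0 - p) / (1.0 - p)
    if y > 0:
        out += y ** (2.0 - p) / ((1.0 - p) * (2.0 - p))
    return out

def gradient_identity(y, mu, p):
    # Matches cgradient_half_tweedie_identity: mu**(-p) * (mu - y).
    return mu ** (-p) * (mu - y)

y, mu, p, eps = 1.3, 0.7, 1.5, 1e-6
fd = (loss_identity(y, mu + eps, p) - loss_identity(y, mu - eps, p)) / (2 * eps)
print(fd, gradient_identity(y, mu, p))  # the two values should agree to ~1e-8
print(loss_identity(1.3, 1.3, 1.5))     # a perfect prediction scores ~0.0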

sklearn/_loss/glm_distribution.py

+4
@@ -4,6 +4,10 @@
 
 # Author: Christian Lorentzen <[email protected]>
 # License: BSD 3 clause
+#
+# TODO(1.3): remove file
+# This is only used for backward compatibility in _GeneralizedLinearRegressor
+# for the deprecated family attribute.
 
 from abc import ABCMeta, abstractmethod
 from collections import namedtuple

sklearn/_loss/link.py

+2 −2

@@ -1,7 +1,7 @@
 """
 Module contains classes for invertible (and differentiable) link functions.
 """
-# Author: Christian Lorentzen <lorentzen.ch@googlemail.com>
+# Author: Christian Lorentzen <lorentzen.ch@gmail.com>
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass

@@ -23,7 +23,7 @@ def __post_init__(self):
         """Check that low <= high"""
         if self.low > self.high:
             raise ValueError(
-                f"On must have low <= high; got low={self.low}, high={self.high}."
+                f"One must have low <= high; got low={self.low}, high={self.high}."
             )
 
     def includes(self, x):

sklearn/_loss/loss.py

+47
@@ -25,6 +25,7 @@
     CyHalfPoissonLoss,
     CyHalfGammaLoss,
     CyHalfTweedieLoss,
+    CyHalfTweedieLossIdentity,
     CyHalfBinomialLoss,
     CyHalfMultinomialLoss,
 )

@@ -770,6 +771,52 @@ def constant_to_optimal_zero(self, y_true, sample_weight=None):
         return term
 
 
+class HalfTweedieLossIdentity(BaseLoss):
+    """Half Tweedie deviance loss with identity link, for regression.
+
+    Domain:
+    y_true in real numbers for power <= 0
+    y_true in non-negative real numbers for 0 < power < 2
+    y_true in positive real numbers for 2 <= power
+    y_pred in positive real numbers for power != 0
+    y_pred in real numbers for power = 0
+    power in real numbers
+
+    Link:
+    y_pred = raw_prediction
+
+    For a given sample x_i, half Tweedie deviance loss with p=power is defined
+    as::
+
+        loss(x_i) = max(y_true_i, 0)**(2-p) / (1-p) / (2-p)
+                    - y_true_i * raw_prediction_i**(1-p) / (1-p)
+                    + raw_prediction_i**(2-p) / (2-p)
+
+    Note that the minimum value of this loss is 0.
+
+    Note furthermore that although no Tweedie distribution exists for
+    0 < power < 1, it still gives a strictly consistent scoring function for
+    the expectation.
+    """
+
+    def __init__(self, sample_weight=None, power=1.5):
+        super().__init__(
+            closs=CyHalfTweedieLossIdentity(power=float(power)),
+            link=IdentityLink(),
+        )
+        if self.closs.power <= 0:
+            self.interval_y_true = Interval(-np.inf, np.inf, False, False)
+        elif self.closs.power < 2:
+            self.interval_y_true = Interval(0, np.inf, True, False)
+        else:
+            self.interval_y_true = Interval(0, np.inf, False, False)
+
+        if self.closs.power == 0:
+            self.interval_y_pred = Interval(-np.inf, np.inf, False, False)
+        else:
+            self.interval_y_pred = Interval(0, np.inf, False, False)
+
+
 class HalfBinomialLoss(BaseLoss):
     """Half Binomial deviance loss with logit link, for binary classification.
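
A short usage sketch of the new class (sklearn._loss is a private module, so this is not a stable interface; the loss() call signature is assumed from the BaseLoss API):

import numpy as np
from sklearn._loss import HalfTweedieLossIdentity

loss = HalfTweedieLossIdentity(power=1.5)
y_true = np.array([0.5, 1.0, 2.5])
raw_prediction = np.array([0.7, 1.0, 2.0])  # identity link: y_pred == raw_prediction

print(loss.loss(y_true=y_true, raw_prediction=raw_prediction))
# Since constant terms are not dropped, a perfect fit scores exactly 0:
print(loss.loss(y_true=y_true, raw_prediction=y_true.copy()))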

sklearn/_loss/tests/test_glm_distribution.py

+2
@@ -1,6 +1,8 @@
 # Authors: Christian Lorentzen <[email protected]>
 #
 # License: BSD 3 clause
+#
+# TODO(1.3): remove file
 import numpy as np
 from numpy.testing import (
     assert_allclose,

sklearn/_loss/tests/test_link.py

+1 −1

@@ -16,7 +16,7 @@
 def test_interval_raises():
     """Test that interval with low > high raises ValueError."""
     with pytest.raises(
-        ValueError, match="On must have low <= high; got low=1, high=0."
+        ValueError, match="One must have low <= high; got low=1, high=0."
     ):
         Interval(1, 0, False, False)