diff --git a/.gitignore b/.gitignore
index 2c3dd0c4794c1..0bb730f493a56 100644
--- a/.gitignore
+++ b/.gitignore
@@ -76,6 +76,7 @@ _configtest.o.d
.mypy_cache/
# files generated from a template
+sklearn/_loss/_loss.pyx
sklearn/utils/_seq_dataset.pyx
sklearn/utils/_seq_dataset.pxd
sklearn/utils/_weight_vector.pyx
diff --git a/setup.cfg b/setup.cfg
index 9eca7fad87b4b..9c6f37f7dd8cd 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -69,6 +69,7 @@ allow_redefinition = True
[check-manifest]
# ignore files missing in VCS
ignore =
+ sklearn/_loss/_loss.pyx
sklearn/linear_model/_sag_fast.pyx
sklearn/utils/_seq_dataset.pyx
sklearn/utils/_seq_dataset.pxd
diff --git a/sklearn/_loss/__init__.py b/sklearn/_loss/__init__.py
index e69de29bb2d1d..14548c62231a2 100644
--- a/sklearn/_loss/__init__.py
+++ b/sklearn/_loss/__init__.py
@@ -0,0 +1,27 @@
+"""
+The :mod:`sklearn._loss` module includes loss function classes suitable for
+fitting classification and regression tasks.
+"""
+
+from .loss import (
+ HalfSquaredError,
+ AbsoluteError,
+ PinballLoss,
+ HalfPoissonLoss,
+ HalfGammaLoss,
+ HalfTweedieLoss,
+ HalfBinomialLoss,
+ HalfMultinomialLoss,
+)
+
+
+__all__ = [
+ "HalfSquaredError",
+ "AbsoluteError",
+ "PinballLoss",
+ "HalfPoissonLoss",
+ "HalfGammaLoss",
+ "HalfTweedieLoss",
+ "HalfBinomialLoss",
+ "HalfMultinomialLoss",
+]
diff --git a/sklearn/_loss/_loss.pxd b/sklearn/_loss/_loss.pxd
new file mode 100644
index 0000000000000..7255243d331dc
--- /dev/null
+++ b/sklearn/_loss/_loss.pxd
@@ -0,0 +1,75 @@
+# cython: language_level=3
+
+import numpy as np
+cimport numpy as np
+
+np.import_array()
+
+
+# Fused types for y_true, y_pred, raw_prediction
+ctypedef fused Y_DTYPE_C:
+ np.npy_float64
+ np.npy_float32
+
+
+# Fused types for gradient and hessian
+ctypedef fused G_DTYPE_C:
+ np.npy_float64
+ np.npy_float32
+
+
+# Struct to return 2 doubles
+ctypedef struct double_pair:
+ double val1
+ double val2
+
+
+# C base class for loss functions
+cdef class CyLossFunction:
+ cdef double cy_loss(self, double y_true, double raw_prediction) nogil
+ cdef double cy_gradient(self, double y_true, double raw_prediction) nogil
+ cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil
+
+
+cdef class CyHalfSquaredError(CyLossFunction):
+ cdef double cy_loss(self, double y_true, double raw_prediction) nogil
+ cdef double cy_gradient(self, double y_true, double raw_prediction) nogil
+ cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil
+
+
+cdef class CyAbsoluteError(CyLossFunction):
+ cdef double cy_loss(self, double y_true, double raw_prediction) nogil
+ cdef double cy_gradient(self, double y_true, double raw_prediction) nogil
+ cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil
+
+
+cdef class CyPinballLoss(CyLossFunction):
+ cdef readonly double quantile # readonly makes it accessible from Python
+ cdef double cy_loss(self, double y_true, double raw_prediction) nogil
+ cdef double cy_gradient(self, double y_true, double raw_prediction) nogil
+ cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil
+
+
+cdef class CyHalfPoissonLoss(CyLossFunction):
+ cdef double cy_loss(self, double y_true, double raw_prediction) nogil
+ cdef double cy_gradient(self, double y_true, double raw_prediction) nogil
+ cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil
+
+
+cdef class CyHalfGammaLoss(CyLossFunction):
+ cdef double cy_loss(self, double y_true, double raw_prediction) nogil
+ cdef double cy_gradient(self, double y_true, double raw_prediction) nogil
+ cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil
+
+
+cdef class CyHalfTweedieLoss(CyLossFunction):
+ cdef readonly double power # readonly makes it accessible from Python
+ cdef double cy_loss(self, double y_true, double raw_prediction) nogil
+ cdef double cy_gradient(self, double y_true, double raw_prediction) nogil
+ cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil
+
+
+cdef class CyHalfBinomialLoss(CyLossFunction):
+ cdef double cy_loss(self, double y_true, double raw_prediction) nogil
+ cdef double cy_gradient(self, double y_true, double raw_prediction) nogil
+ cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil
diff --git a/sklearn/_loss/_loss.pyx.tp b/sklearn/_loss/_loss.pyx.tp
new file mode 100644
index 0000000000000..7c343c2881975
--- /dev/null
+++ b/sklearn/_loss/_loss.pyx.tp
@@ -0,0 +1,1211 @@
+{{py:
+
+"""
+Template file for easily generating loops over samples using Tempita
+(https://github.com/cython/cython/blob/master/Cython/Tempita/_tempita.py).
+
+Generated file: _loss.pyx
+
+Each loss class is generated from cdef functions operating on single samples.
+The keywords between double braces are substituted in setup.py.
+"""
+
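+# As an illustration of the substitution: with closs = "closs_half_squared_error"
+# and with_param = "", the templated call in the loss() loop below is rendered
+# in the generated _loss.pyx as
+#     loss_out[i] = closs_half_squared_error(y_true[i], raw_prediction[i])
+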
+doc_HalfSquaredError = (
+ """Half Squared Error with identity link.
+
+ Domain:
+ y_true and y_pred all real numbers
+
+ Link:
+ y_pred = raw_prediction
+ """
+)
+
+doc_AbsoluteError = (
+ """Absolute Error with identity link.
+
+ Domain:
+ y_true and y_pred all real numbers
+
+ Link:
+ y_pred = raw_prediction
+ """
+)
+
+doc_PinballLoss = (
+ """Quantile Loss aka Pinball Loss with identity link.
+
+ Domain:
+ y_true and y_pred all real numbers
+ quantile in (0, 1)
+
+ Link:
+ y_pred = raw_prediction
+
+    Note: 2 * CyPinballLoss(quantile=0.5) equals CyAbsoluteError().
+ """
+)
+
+doc_HalfPoissonLoss = (
+ """Half Poisson deviance loss with log-link.
+
+ Domain:
+ y_true in non-negative real numbers
+ y_pred in positive real numbers
+
+ Link:
+ y_pred = exp(raw_prediction)
+
+ Half Poisson deviance with log-link is
+ y_true * log(y_true/y_pred) + y_pred - y_true
+ = y_true * log(y_true) - y_true * raw_prediction
+ + exp(raw_prediction) - y_true
+
+ Dropping constant terms, this gives:
+ exp(raw_prediction) - y_true * raw_prediction
+ """
+)
+
+doc_HalfGammaLoss = (
+ """Half Gamma deviance loss with log-link.
+
+ Domain:
+ y_true and y_pred in positive real numbers
+
+ Link:
+ y_pred = exp(raw_prediction)
+
+ Half Gamma deviance with log-link is
+ log(y_pred/y_true) + y_true/y_pred - 1
+ = raw_prediction - log(y_true) + y_true * exp(-raw_prediction) - 1
+
+ Dropping constant terms, this gives:
+ raw_prediction + y_true * exp(-raw_prediction)
+ """
+)
+
+doc_HalfTweedieLoss = (
+ """Half Tweedie deviance loss with log-link.
+
+ Domain:
+ y_true in real numbers if p <= 0
+ y_true in non-negative real numbers if 0 < p < 2
+ y_true in positive real numbers if p >= 2
+ y_pred and power in positive real numbers
+
+ Link:
+ y_pred = exp(raw_prediction)
+
+ Half Tweedie deviance with log-link and p=power is
+ max(y_true, 0)**(2-p) / (1-p) / (2-p)
+ - y_true * y_pred**(1-p) / (1-p)
+ + y_pred**(2-p) / (2-p)
+ = max(y_true, 0)**(2-p) / (1-p) / (2-p)
+ - y_true * exp((1-p) * raw_prediction) / (1-p)
+ + exp((2-p) * raw_prediction) / (2-p)
+
+ Dropping constant terms, this gives:
+ exp((2-p) * raw_prediction) / (2-p)
+ - y_true * exp((1-p) * raw_prediction) / (1-p)
+
+ Notes:
+    - Poisson with p=1 and Gamma with p=2 have different terms dropped such
+      that CyHalfTweedieLoss is not continuous in p=power at p=1 and p=2.
+    - While the Tweedie distribution only exists for p<=0 or p>=1, the range
+      0 < p < 1 still gives a strictly convex loss function.
+    """
+)
+}}
+# The Python wrapper classes in loss.py add input checking like None ->
+# np.empty().
+#
+# Note: We require 1-dim ndarrays to be contiguous.
+# TODO: Use const memoryviews with fused types with Cython 3.0 where
+# appropriate (arguments marked by "# IN").
+
+cimport cython
+from cython.parallel import parallel, prange
+import numpy as np
+cimport numpy as np
+
+from libc.math cimport exp, fabs, log, log1p
+from libc.stdlib cimport malloc, free
+
+np.import_array()
+
+
+# -------------------------------------
+# Helper functions
+# -------------------------------------
+# Numerically stable version of log(1 + exp(x)) for double precision
+# See https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
+cdef inline double log1pexp(double x) nogil:
+ if x <= -37:
+ return exp(x)
+ elif x <= 18:
+ return log1p(exp(x))
+ elif x <= 33.3:
+ return x + exp(-x)
+ else:
+ return x
+
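+# For instance, in double precision log1pexp(800.) returns 800. exactly while a
+# naive log(1 + exp(800.)) would overflow, and log1pexp(-40.) returns exp(-40.)
+# although 1 + exp(-40.) would already round to 1.0.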
+
+cdef inline void sum_exp_minus_max(
+ const int i,
+ Y_DTYPE_C[:, :] raw_prediction, # IN
+ Y_DTYPE_C *p # OUT
+) nogil:
+    # Thread-local buffers are used to store the results of this function via p.
+ # The results are stored as follows:
+ # p[k] = exp(raw_prediction_i_k - max_value) for k = 0 to n_classes-1
+ # p[-2] = max(raw_prediction_i_k, k = 0 to n_classes-1)
+ # p[-1] = sum(p[k], k = 0 to n_classes-1) = sum of exponentials
+ # len(p) must be n_classes + 2
+ # Notes:
+ # - Using "by reference" arguments doesn't work well, therefore we use a
+ # longer p, see https://github.com/cython/cython/issues/1863
+ # - i needs to be passed (and stays constant) because otherwise Cython does
+ # not generate optimal code, see
+ # https://github.com/scikit-learn/scikit-learn/issues/17299
+ # - We do not normalize p by calculating p[k] = p[k] / sum_exps.
+ # This helps to save one loop over k.
+ cdef:
+ int k
+ int n_classes = raw_prediction.shape[1]
+ double max_value = raw_prediction[i, 0]
+ double sum_exps = 0
+ for k in range(1, n_classes):
+ # Compute max value of array for numerical stability
+ if max_value < raw_prediction[i, k]:
+ max_value = raw_prediction[i, k]
+
+ for k in range(n_classes):
+ p[k] = exp(raw_prediction[i, k] - max_value)
+ sum_exps += p[k]
+
+ p[n_classes] = max_value # same as p[-2]
+ p[n_classes + 1] = sum_exps # same as p[-1]
+
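+# Worked example: for raw_prediction[i, :] = [1., 2., 3.], max_value = 3. and
+# p[0:3] ~= [0.135, 0.368, 1.0] (= exp(raw - max)), hence p[3] = 3. and
+# p[4] = sum_exps ~= 1.503; softmax probabilities are then p[k] / sum_exps.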
+
+# -------------------------------------
+# Single point inline C functions
+# -------------------------------------
+# Half Squared Error
+cdef inline double closs_half_squared_error(
+ double y_true,
+ double raw_prediction
+) nogil:
+ return 0.5 * (raw_prediction - y_true) * (raw_prediction - y_true)
+
+
+cdef inline double cgradient_half_squared_error(
+ double y_true,
+ double raw_prediction
+) nogil:
+ return raw_prediction - y_true
+
+
+cdef inline double_pair cgrad_hess_half_squared_error(
+ double y_true,
+ double raw_prediction
+) nogil:
+ cdef double_pair gh
+ gh.val1 = raw_prediction - y_true # gradient
+ gh.val2 = 1. # hessian
+ return gh
+
+
+# Absolute Error
+cdef inline double closs_absolute_error(
+ double y_true,
+ double raw_prediction
+) nogil:
+ return fabs(raw_prediction - y_true)
+
+
+cdef inline double cgradient_absolute_error(
+ double y_true,
+ double raw_prediction
+) nogil:
+ return 1. if raw_prediction > y_true else -1.
+
+
+cdef inline double_pair cgrad_hess_absolute_error(
+ double y_true,
+ double raw_prediction
+) nogil:
+ cdef double_pair gh
+ # Note that exact hessian = 0 almost everywhere. Optimization routines like
+ # in HGBT, however, need a hessian > 0. Therefore, we assign 1.
+ gh.val1 = 1. if raw_prediction > y_true else -1. # gradient
+ gh.val2 = 1. # hessian
+ return gh
+
+
+# Quantile Loss / Pinball Loss
+cdef inline double closs_pinball_loss(
+ double y_true,
+ double raw_prediction,
+ double quantile
+) nogil:
+ return (quantile * (y_true - raw_prediction) if y_true >= raw_prediction
+ else (1. - quantile) * (raw_prediction - y_true))
+
+
+cdef inline double cgradient_pinball_loss(
+ double y_true,
+ double raw_prediction,
+ double quantile
+) nogil:
+    return -quantile if y_true >= raw_prediction else 1. - quantile
+
+
+cdef inline double_pair cgrad_hess_pinball_loss(
+ double y_true,
+ double raw_prediction,
+ double quantile
+) nogil:
+ cdef double_pair gh
+ # Note that exact hessian = 0 almost everywhere. Optimization routines like
+ # in HGBT, however, need a hessian > 0. Therefore, we assign 1.
+    gh.val1 = -quantile if y_true >= raw_prediction else 1. - quantile  # gradient
+ gh.val2 = 1. # hessian
+ return gh
+
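+# Worked example for quantile=0.8: with y_true=1 and raw_prediction=0
+# (under-prediction), the loss is 0.8 * 1 = 0.8 and the gradient is -0.8; with
+# raw_prediction=2 (over-prediction), the loss is 0.2 * 1 = 0.2 and the
+# gradient is +0.2.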
+
+# Half Poisson Deviance with Log-Link, dropping constant terms
+cdef inline double closs_half_poisson(
+ double y_true,
+ double raw_prediction
+) nogil:
+ return exp(raw_prediction) - y_true * raw_prediction
+
+
+cdef inline double cgradient_half_poisson(
+ double y_true,
+ double raw_prediction
+) nogil:
+ # y_pred - y_true
+ return exp(raw_prediction) - y_true
+
+
+cdef inline double_pair closs_grad_half_poisson(
+ double y_true,
+ double raw_prediction
+) nogil:
+ cdef double_pair lg
+ lg.val2 = exp(raw_prediction) # used as temporary
+ lg.val1 = lg.val2 - y_true * raw_prediction # loss
+ lg.val2 -= y_true # gradient
+ return lg
+
+
+cdef inline double_pair cgrad_hess_half_poisson(
+ double y_true,
+ double raw_prediction
+) nogil:
+ cdef double_pair gh
+ gh.val2 = exp(raw_prediction) # hessian
+ gh.val1 = gh.val2 - y_true # gradient
+ return gh
+
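+# Quick check: for y_true=2 and raw_prediction=0, i.e. y_pred=exp(0)=1, the
+# Poisson gradient is 1 - 2 = -1 and the hessian is exp(0) = 1 = y_pred.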
+
+# Half Gamma Deviance with Log-Link, dropping constant terms
+cdef inline double closs_half_gamma(
+ double y_true,
+ double raw_prediction
+) nogil:
+ return raw_prediction + y_true * exp(-raw_prediction)
+
+
+cdef inline double cgradient_half_gamma(
+ double y_true,
+ double raw_prediction
+) nogil:
+ return 1. - y_true * exp(-raw_prediction)
+
+
+cdef inline double_pair closs_grad_half_gamma(
+ double y_true,
+ double raw_prediction
+) nogil:
+ cdef double_pair lg
+ lg.val2 = exp(-raw_prediction) # used as temporary
+ lg.val1 = raw_prediction + y_true * lg.val2 # loss
+ lg.val2 = 1. - y_true * lg.val2 # gradient
+ return lg
+
+
+cdef inline double_pair cgrad_hess_half_gamma(
+ double y_true,
+ double raw_prediction
+) nogil:
+ cdef double_pair gh
+ gh.val2 = exp(-raw_prediction) # used as temporary
+ gh.val1 = 1. - y_true * gh.val2 # gradient
+ gh.val2 *= y_true # hessian
+ return gh
+
+
+# Half Tweedie Deviance with Log-Link, dropping constant terms
+# Note that by dropping constants this is no longer smooth in parameter power.
+cdef inline double closs_half_tweedie(
+ double y_true,
+ double raw_prediction,
+ double power
+) nogil:
+ if power == 0.:
+ return closs_half_squared_error(y_true, exp(raw_prediction))
+ elif power == 1.:
+ return closs_half_poisson(y_true, raw_prediction)
+ elif power == 2.:
+ return closs_half_gamma(y_true, raw_prediction)
+ else:
+ return (exp((2. - power) * raw_prediction) / (2. - power)
+ - y_true * exp((1. - power) * raw_prediction) / (1. - power))
+
+
+cdef inline double cgradient_half_tweedie(
+ double y_true,
+ double raw_prediction,
+ double power
+) nogil:
+ cdef double exp1
+ if power == 0.:
+ exp1 = exp(raw_prediction)
+ return exp1 * (exp1 - y_true)
+ elif power == 1.:
+ return cgradient_half_poisson(y_true, raw_prediction)
+ elif power == 2.:
+ return cgradient_half_gamma(y_true, raw_prediction)
+ else:
+ return (exp((2. - power) * raw_prediction)
+ - y_true * exp((1. - power) * raw_prediction))
+
+
+cdef inline double_pair closs_grad_half_tweedie(
+ double y_true,
+ double raw_prediction,
+ double power
+) nogil:
+ cdef double_pair lg
+ cdef double exp1, exp2
+ if power == 0.:
+ exp1 = exp(raw_prediction)
+ lg.val1 = closs_half_squared_error(y_true, exp1) # loss
+ lg.val2 = exp1 * (exp1 - y_true) # gradient
+ elif power == 1.:
+ return closs_grad_half_poisson(y_true, raw_prediction)
+ elif power == 2.:
+ return closs_grad_half_gamma(y_true, raw_prediction)
+ else:
+ exp1 = exp((1. - power) * raw_prediction)
+ exp2 = exp((2. - power) * raw_prediction)
+ lg.val1 = exp2 / (2. - power) - y_true * exp1 / (1. - power) # loss
+ lg.val2 = exp2 - y_true * exp1 # gradient
+ return lg
+
+
+cdef inline double_pair cgrad_hess_half_tweedie(
+ double y_true,
+ double raw_prediction,
+ double power
+) nogil:
+ cdef double_pair gh
+ cdef double exp1, exp2
+ if power == 0.:
+ exp1 = exp(raw_prediction)
+ gh.val1 = exp1 * (exp1 - y_true) # gradient
+ gh.val2 = exp1 * (2 * exp1 - y_true) # hessian
+ elif power == 1.:
+ return cgrad_hess_half_poisson(y_true, raw_prediction)
+ elif power == 2.:
+ return cgrad_hess_half_gamma(y_true, raw_prediction)
+ else:
+ exp1 = exp((1. - power) * raw_prediction)
+ exp2 = exp((2. - power) * raw_prediction)
+ gh.val1 = exp2 - y_true * exp1 # gradient
+ gh.val2 = (2. - power) * exp2 - (1. - power) * y_true * exp1 # hessian
+ return gh
+
+
+# Half Binomial deviance with logit-link, aka log-loss or binary cross entropy
+cdef inline double closs_half_binomial(
+ double y_true,
+ double raw_prediction
+) nogil:
+ # log1p(exp(raw_prediction)) - y_true * raw_prediction
+ return log1pexp(raw_prediction) - y_true * raw_prediction
+
+
+cdef inline double cgradient_half_binomial(
+ double y_true,
+ double raw_prediction
+) nogil:
+ # y_pred - y_true = expit(raw_prediction) - y_true
+ # Numerically more stable, see
+ # http://fa.bianp.net/blog/2019/evaluate_logistic/
+ # if raw_prediction < 0:
+ # exp_tmp = exp(raw_prediction)
+ # return ((1 - y_true) * exp_tmp - y_true) / (1 + exp_tmp)
+ # else:
+ # exp_tmp = exp(-raw_prediction)
+ # return ((1 - y_true) - y_true * exp_tmp) / (1 + exp_tmp)
+ # Note that optimal speed would be achieved, at the cost of precision, by
+ # return expit(raw_prediction) - y_true
+    # i.e. no "if else" and our own inline implementation of expit instead of
+    #     from scipy.special.cython_special cimport expit
+    # The case distinction raw_prediction < 0 in the stable implementation
+    # does not provide significantly better precision, therefore we go
+    # without it.
+ cdef double exp_tmp
+ exp_tmp = exp(-raw_prediction)
+ return ((1 - y_true) - y_true * exp_tmp) / (1 + exp_tmp)
+
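+# Quick check against expit: for raw_prediction=0, exp_tmp=1 and the gradient
+# is ((1 - y_true) - y_true) / 2 = 0.5 - y_true, i.e. expit(0) - y_true.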
+
+cdef inline double_pair closs_grad_half_binomial(
+ double y_true,
+ double raw_prediction
+) nogil:
+ cdef double_pair lg
+ if raw_prediction <= 0:
+ lg.val2 = exp(raw_prediction) # used as temporary
+ if raw_prediction <= -37:
+ lg.val1 = lg.val2 - y_true * raw_prediction # loss
+ else:
+ lg.val1 = log1p(lg.val2) - y_true * raw_prediction # loss
+ lg.val2 = ((1 - y_true) * lg.val2 - y_true) / (1 + lg.val2) # gradient
+ else:
+ lg.val2 = exp(-raw_prediction) # used as temporary
+ if raw_prediction <= 18:
+ # log1p(exp(x)) = log(1 + exp(x)) = x + log1p(exp(-x))
+ lg.val1 = log1p(lg.val2) + (1 - y_true) * raw_prediction # loss
+ else:
+ lg.val1 = lg.val2 + (1 - y_true) * raw_prediction # loss
+ lg.val2 = ((1 - y_true) - y_true * lg.val2) / (1 + lg.val2) # gradient
+ return lg
+
+
+cdef inline double_pair cgrad_hess_half_binomial(
+ double y_true,
+ double raw_prediction
+) nogil:
+ # with y_pred = expit(raw)
+ # hessian = y_pred * (1 - y_pred) = exp(raw) / (1 + exp(raw))**2
+ # = exp(-raw) / (1 + exp(-raw))**2
+ cdef double_pair gh
+ gh.val2 = exp(-raw_prediction) # used as temporary
+ gh.val1 = ((1 - y_true) - y_true * gh.val2) / (1 + gh.val2) # gradient
+ gh.val2 = gh.val2 / (1 + gh.val2)**2 # hessian
+ return gh
+
+
+# ---------------------------------------------------
+# Extension Types for Loss Functions of 1-dim targets
+# ---------------------------------------------------
+cdef class CyLossFunction:
+ """Base class for convex loss functions."""
+
+ cdef double cy_loss(self, double y_true, double raw_prediction) nogil:
+ """Compute the loss for a single sample.
+
+ Parameters
+ ----------
+ y_true : double
+ Observed, true target value.
+ raw_prediction : double
+ Raw prediction value (in link space).
+
+ Returns
+ -------
+ double
+ The loss evaluated at `y_true` and `raw_prediction`.
+ """
+ pass
+
+ cdef double cy_gradient(self, double y_true, double raw_prediction) nogil:
+ """Compute gradient of loss w.r.t. raw_prediction for a single sample.
+
+ Parameters
+ ----------
+ y_true : double
+ Observed, true target value.
+ raw_prediction : double
+ Raw prediction value (in link space).
+
+ Returns
+ -------
+ double
+ The derivative of the loss function w.r.t. `raw_prediction`.
+ """
+ pass
+
+ cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil:
+ """Compute gradient and hessian.
+
+ Gradient and hessian of loss w.r.t. raw_prediction for a single sample.
+
+ This is usually diagonal in raw_prediction_i and raw_prediction_j.
+ Therefore, we return the diagonal element i=j.
+
+ For a loss with a non-canonical link, this might implement the diagonal
+ of the Fisher matrix (=expected hessian) instead of the hessian.
+
+ Parameters
+ ----------
+ y_true : double
+ Observed, true target value.
+ raw_prediction : double
+ Raw prediction value (in link space).
+
+ Returns
+ -------
+ double_pair
+ Gradient and hessian of the loss function w.r.t. `raw_prediction`.
+ """
+ pass
+
+ # Note: With Cython 3.0, fused types can be used together with const:
+ # const Y_DTYPE_C double[::1] y_true
+ # See release notes 3.0.0 alpha1
+ # https://cython.readthedocs.io/en/latest/src/changes.html#alpha-1-2020-04-12
+ def loss(
+ self,
+ Y_DTYPE_C[::1] y_true, # IN
+ Y_DTYPE_C[::1] raw_prediction, # IN
+ Y_DTYPE_C[::1] sample_weight, # IN
+ G_DTYPE_C[::1] loss_out, # OUT
+ int n_threads=1
+ ):
+ """Compute the pointwise loss value for each input.
+
+ Parameters
+ ----------
+ y_true : array of shape (n_samples,)
+ Observed, true target values.
+ raw_prediction : array of shape (n_samples,)
+ Raw prediction values (in link space).
+ sample_weight : array of shape (n_samples,) or None
+ Sample weights.
+ loss_out : array of shape (n_samples,)
+ A location into which the result is stored.
+ n_threads : int
+ Number of threads used by OpenMP (if any).
+
+ Returns
+ -------
+ loss : array of shape (n_samples,)
+ Element-wise loss function.
+ """
+ pass
+
+ def gradient(
+ self,
+ Y_DTYPE_C[::1] y_true, # IN
+ Y_DTYPE_C[::1] raw_prediction, # IN
+ Y_DTYPE_C[::1] sample_weight, # IN
+ G_DTYPE_C[::1] gradient_out, # OUT
+ int n_threads=1
+ ):
+ """Compute gradient of loss w.r.t raw_prediction for each input.
+
+ Parameters
+ ----------
+ y_true : array of shape (n_samples,)
+ Observed, true target values.
+ raw_prediction : array of shape (n_samples,)
+ Raw prediction values (in link space).
+ sample_weight : array of shape (n_samples,) or None
+ Sample weights.
+ gradient_out : array of shape (n_samples,)
+ A location into which the result is stored.
+ n_threads : int
+ Number of threads used by OpenMP (if any).
+
+ Returns
+ -------
+ gradient : array of shape (n_samples,)
+ Element-wise gradients.
+ """
+ pass
+
+ def loss_gradient(
+ self,
+ Y_DTYPE_C[::1] y_true, # IN
+ Y_DTYPE_C[::1] raw_prediction, # IN
+ Y_DTYPE_C[::1] sample_weight, # IN
+ G_DTYPE_C[::1] loss_out, # OUT
+ G_DTYPE_C[::1] gradient_out, # OUT
+ int n_threads=1
+ ):
+ """Compute loss and gradient of loss w.r.t raw_prediction.
+
+ Parameters
+ ----------
+ y_true : array of shape (n_samples,)
+ Observed, true target values.
+ raw_prediction : array of shape (n_samples,)
+ Raw prediction values (in link space).
+ sample_weight : array of shape (n_samples,) or None
+ Sample weights.
+ loss_out : array of shape (n_samples,) or None
+ A location into which the element-wise loss is stored.
+ gradient_out : array of shape (n_samples,)
+ A location into which the gradient is stored.
+ n_threads : int
+ Number of threads used by OpenMP (if any).
+
+ Returns
+ -------
+ loss : array of shape (n_samples,)
+ Element-wise loss function.
+
+ gradient : array of shape (n_samples,)
+ Element-wise gradients.
+ """
+ self.loss(y_true, raw_prediction, sample_weight, loss_out, n_threads)
+ self.gradient(y_true, raw_prediction, sample_weight, gradient_out, n_threads)
+ return np.asarray(loss_out), np.asarray(gradient_out)
+
+ def gradient_hessian(
+ self,
+ Y_DTYPE_C[::1] y_true, # IN
+ Y_DTYPE_C[::1] raw_prediction, # IN
+ Y_DTYPE_C[::1] sample_weight, # IN
+ G_DTYPE_C[::1] gradient_out, # OUT
+ G_DTYPE_C[::1] hessian_out, # OUT
+ int n_threads=1
+ ):
+ """Compute gradient and hessian of loss w.r.t raw_prediction.
+
+ Parameters
+ ----------
+ y_true : array of shape (n_samples,)
+ Observed, true target values.
+ raw_prediction : array of shape (n_samples,)
+ Raw prediction values (in link space).
+ sample_weight : array of shape (n_samples,) or None
+ Sample weights.
+ gradient_out : array of shape (n_samples,)
+ A location into which the gradient is stored.
+ hessian_out : array of shape (n_samples,)
+ A location into which the hessian is stored.
+ n_threads : int
+ Number of threads used by OpenMP (if any).
+
+ Returns
+ -------
+ gradient : array of shape (n_samples,)
+ Element-wise gradients.
+
+ hessian : array of shape (n_samples,)
+ Element-wise hessians.
+ """
+ pass
+
+
+{{for name, docstring, param, closs, closs_grad, cgrad, cgrad_hess, in class_list}}
+{{py:
+if param is None:
+ with_param = ""
+else:
+ with_param = ", self." + param
+}}
+
+cdef class {{name}}(CyLossFunction):
+ """{{docstring}}"""
+
+ {{if param is not None}}
+ def __init__(self, {{param}}):
+ self.{{param}} = {{param}}
+ {{endif}}
+
+ cdef inline double cy_loss(self, double y_true, double raw_prediction) nogil:
+ return {{closs}}(y_true, raw_prediction{{with_param}})
+
+ cdef inline double cy_gradient(self, double y_true, double raw_prediction) nogil:
+ return {{cgrad}}(y_true, raw_prediction{{with_param}})
+
+ cdef inline double_pair cy_grad_hess(self, double y_true, double raw_prediction) nogil:
+ return {{cgrad_hess}}(y_true, raw_prediction{{with_param}})
+
+ def loss(
+ self,
+ Y_DTYPE_C[::1] y_true, # IN
+ Y_DTYPE_C[::1] raw_prediction, # IN
+ Y_DTYPE_C[::1] sample_weight, # IN
+ G_DTYPE_C[::1] loss_out, # OUT
+ int n_threads=1
+ ):
+ cdef:
+ int i
+ int n_samples = y_true.shape[0]
+
+ if sample_weight is None:
+ for i in prange(
+ n_samples, schedule='static', nogil=True, num_threads=n_threads
+ ):
+ loss_out[i] = {{closs}}(y_true[i], raw_prediction[i]{{with_param}})
+ else:
+ for i in prange(
+ n_samples, schedule='static', nogil=True, num_threads=n_threads
+ ):
+ loss_out[i] = sample_weight[i] * {{closs}}(y_true[i], raw_prediction[i]{{with_param}})
+
+ return np.asarray(loss_out)
+
+ {{if closs_grad is not None}}
+ def loss_gradient(
+ self,
+ Y_DTYPE_C[::1] y_true, # IN
+ Y_DTYPE_C[::1] raw_prediction, # IN
+ Y_DTYPE_C[::1] sample_weight, # IN
+ G_DTYPE_C[::1] loss_out, # OUT
+ G_DTYPE_C[::1] gradient_out, # OUT
+ int n_threads=1
+ ):
+ cdef:
+ int i
+ int n_samples = y_true.shape[0]
+ double_pair dbl2
+
+ if sample_weight is None:
+ for i in prange(
+ n_samples, schedule='static', nogil=True, num_threads=n_threads
+ ):
+ dbl2 = {{closs_grad}}(y_true[i], raw_prediction[i]{{with_param}})
+ loss_out[i] = dbl2.val1
+ gradient_out[i] = dbl2.val2
+ else:
+ for i in prange(
+ n_samples, schedule='static', nogil=True, num_threads=n_threads
+ ):
+ dbl2 = {{closs_grad}}(y_true[i], raw_prediction[i]{{with_param}})
+ loss_out[i] = sample_weight[i] * dbl2.val1
+ gradient_out[i] = sample_weight[i] * dbl2.val2
+
+ return np.asarray(loss_out), np.asarray(gradient_out)
+ {{endif}}
+
+ def gradient(
+ self,
+ Y_DTYPE_C[::1] y_true, # IN
+ Y_DTYPE_C[::1] raw_prediction, # IN
+ Y_DTYPE_C[::1] sample_weight, # IN
+ G_DTYPE_C[::1] gradient_out, # OUT
+ int n_threads=1
+ ):
+ cdef:
+ int i
+ int n_samples = y_true.shape[0]
+
+ if sample_weight is None:
+ for i in prange(
+ n_samples, schedule='static', nogil=True, num_threads=n_threads
+ ):
+ gradient_out[i] = {{cgrad}}(y_true[i], raw_prediction[i]{{with_param}})
+ else:
+ for i in prange(
+ n_samples, schedule='static', nogil=True, num_threads=n_threads
+ ):
+ gradient_out[i] = sample_weight[i] * {{cgrad}}(y_true[i], raw_prediction[i]{{with_param}})
+
+ return np.asarray(gradient_out)
+
+ def gradient_hessian(
+ self,
+ Y_DTYPE_C[::1] y_true, # IN
+ Y_DTYPE_C[::1] raw_prediction, # IN
+ Y_DTYPE_C[::1] sample_weight, # IN
+ G_DTYPE_C[::1] gradient_out, # OUT
+ G_DTYPE_C[::1] hessian_out, # OUT
+ int n_threads=1
+ ):
+ cdef:
+ int i
+ int n_samples = y_true.shape[0]
+ double_pair dbl2
+
+ if sample_weight is None:
+ for i in prange(
+ n_samples, schedule='static', nogil=True, num_threads=n_threads
+ ):
+ dbl2 = {{cgrad_hess}}(y_true[i], raw_prediction[i]{{with_param}})
+ gradient_out[i] = dbl2.val1
+ hessian_out[i] = dbl2.val2
+ else:
+ for i in prange(
+ n_samples, schedule='static', nogil=True, num_threads=n_threads
+ ):
+ dbl2 = {{cgrad_hess}}(y_true[i], raw_prediction[i]{{with_param}})
+ gradient_out[i] = sample_weight[i] * dbl2.val1
+ hessian_out[i] = sample_weight[i] * dbl2.val2
+
+ return np.asarray(gradient_out), np.asarray(hessian_out)
+
+{{endfor}}
+
+
+# The multinomial deviance loss is also known as categorical cross-entropy or
+# multinomial log-likelihood
+cdef class CyHalfMultinomialLoss(CyLossFunction):
+ """Half Multinomial deviance loss with multinomial logit link.
+
+ Domain:
+ y_true in {0, 1, 2, 3, .., n_classes - 1}
+ y_pred in (0, 1)**n_classes, i.e. interval with boundaries excluded
+
+ Link:
+ y_pred = softmax(raw_prediction)
+
+ Note: Label encoding is built-in, i.e. {0, 1, 2, 3, .., n_classes - 1} is
+ mapped to (y_true == k) for k = 0 .. n_classes - 1 which is either 0 or 1.
+ """
+
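+    # Example of the built-in label encoding: with n_classes=3 and
+    # y_true[i] = 2.0, the indicator (y_true[i] == k) used below evaluates to
+    # (0, 0, 1) for k = 0, 1, 2.
+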
+    # Note that we do not assume memory alignment/contiguity of 2d arrays.
+    # There seems to be little benefit in doing so. Benchmarks proving the
+ # opposite are welcome.
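+    # The per-sample loss computed below equals
+    #     logsumexp(raw_prediction[i, :]) - raw_prediction[i, int(y_true[i])]
+    # evaluated with the max-subtraction trick via sum_exp_minus_max, i.e. as
+    #     log(sum_exps) + max_value - raw_prediction[i, int(y_true[i])].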
+ def loss(
+ self,
+ Y_DTYPE_C[::1] y_true, # IN
+ Y_DTYPE_C[:, :] raw_prediction, # IN
+ Y_DTYPE_C[::1] sample_weight, # IN
+ G_DTYPE_C[::1] loss_out, # OUT
+ int n_threads=1
+ ):
+ cdef:
+ int i, k
+ int n_samples = y_true.shape[0]
+ int n_classes = raw_prediction.shape[1]
+ Y_DTYPE_C max_value, sum_exps
+ Y_DTYPE_C* p # temporary buffer
+
+ # We assume n_samples > n_classes. In this case having the inner loop
+ # over n_classes is a good default.
+ # TODO: If every memoryview is contiguous and raw_prediction is
+ # f-contiguous, can we write a better algo (loops) to improve
+ # performance?
+ if sample_weight is None:
+ # inner loop over n_classes
+ with nogil, parallel(num_threads=n_threads):
+ # Define private buffer variables as each thread might use its
+ # own.
+ p = malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+
+ for i in prange(n_samples, schedule='static'):
+ sum_exp_minus_max(i, raw_prediction, p)
+ max_value = p[n_classes] # p[-2]
+ sum_exps = p[n_classes + 1] # p[-1]
+ loss_out[i] = log(sum_exps) + max_value
+
+ for k in range(n_classes):
+ # label decode y_true
+ if y_true[i] == k:
+ loss_out[i] -= raw_prediction[i, k]
+
+ free(p)
+ else:
+ with nogil, parallel(num_threads=n_threads):
+ p = malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+
+ for i in prange(n_samples, schedule='static'):
+ sum_exp_minus_max(i, raw_prediction, p)
+ max_value = p[n_classes] # p[-2]
+ sum_exps = p[n_classes + 1] # p[-1]
+ loss_out[i] = log(sum_exps) + max_value
+
+ for k in range(n_classes):
+ # label decode y_true
+ if y_true[i] == k:
+ loss_out[i] -= raw_prediction[i, k]
+
+ loss_out[i] *= sample_weight[i]
+
+ free(p)
+
+ return np.asarray(loss_out)
+
+ def loss_gradient(
+ self,
+ Y_DTYPE_C[::1] y_true, # IN
+ Y_DTYPE_C[:, :] raw_prediction, # IN
+ Y_DTYPE_C[::1] sample_weight, # IN
+ G_DTYPE_C[::1] loss_out, # OUT
+ G_DTYPE_C[:, :] gradient_out, # OUT
+ int n_threads=1
+ ):
+ cdef:
+ int i, k
+ int n_samples = y_true.shape[0]
+ int n_classes = raw_prediction.shape[1]
+ Y_DTYPE_C max_value, sum_exps
+ Y_DTYPE_C* p # temporary buffer
+
+ if sample_weight is None:
+ # inner loop over n_classes
+ with nogil, parallel(num_threads=n_threads):
+ # Define private buffer variables as each thread might use its
+ # own.
+ p = malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+
+ for i in prange(n_samples, schedule='static'):
+ sum_exp_minus_max(i, raw_prediction, p)
+ max_value = p[n_classes] # p[-2]
+ sum_exps = p[n_classes + 1] # p[-1]
+ loss_out[i] = log(sum_exps) + max_value
+
+ for k in range(n_classes):
+ # label decode y_true
+                        if y_true[i] == k:
+ loss_out[i] -= raw_prediction[i, k]
+ p[k] /= sum_exps # p_k = y_pred_k = prob of class k
+ # gradient_k = p_k - (y_true == k)
+ gradient_out[i, k] = p[k] - (y_true[i] == k)
+
+ free(p)
+ else:
+ with nogil, parallel(num_threads=n_threads):
+ p = malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+
+ for i in prange(n_samples, schedule='static'):
+ sum_exp_minus_max(i, raw_prediction, p)
+ max_value = p[n_classes] # p[-2]
+ sum_exps = p[n_classes + 1] # p[-1]
+ loss_out[i] = log(sum_exps) + max_value
+
+ for k in range(n_classes):
+ # label decode y_true
+                        if y_true[i] == k:
+ loss_out[i] -= raw_prediction[i, k]
+ p[k] /= sum_exps # p_k = y_pred_k = prob of class k
+ # gradient_k = (p_k - (y_true == k)) * sw
+ gradient_out[i, k] = (p[k] - (y_true[i] == k)) * sample_weight[i]
+
+ loss_out[i] *= sample_weight[i]
+
+ free(p)
+
+ return np.asarray(loss_out), np.asarray(gradient_out)
+
+ def gradient(
+ self,
+ Y_DTYPE_C[::1] y_true, # IN
+ Y_DTYPE_C[:, :] raw_prediction, # IN
+ Y_DTYPE_C[::1] sample_weight, # IN
+ G_DTYPE_C[:, :] gradient_out, # OUT
+ int n_threads=1
+ ):
+ cdef:
+ int i, k
+ int n_samples = y_true.shape[0]
+ int n_classes = raw_prediction.shape[1]
+ Y_DTYPE_C sum_exps
+ Y_DTYPE_C* p # temporary buffer
+
+ if sample_weight is None:
+ # inner loop over n_classes
+ with nogil, parallel(num_threads=n_threads):
+ # Define private buffer variables as each thread might use its
+ # own.
+ p = malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+
+ for i in prange(n_samples, schedule='static'):
+ sum_exp_minus_max(i, raw_prediction, p)
+ sum_exps = p[n_classes + 1] # p[-1]
+
+ for k in range(n_classes):
+ p[k] /= sum_exps # p_k = y_pred_k = prob of class k
+ # gradient_k = y_pred_k - (y_true == k)
+ gradient_out[i, k] = p[k] - (y_true[i] == k)
+
+ free(p)
+ else:
+ with nogil, parallel(num_threads=n_threads):
+ p = malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+
+ for i in prange(n_samples, schedule='static'):
+ sum_exp_minus_max(i, raw_prediction, p)
+ sum_exps = p[n_classes + 1] # p[-1]
+
+ for k in range(n_classes):
+ p[k] /= sum_exps # p_k = y_pred_k = prob of class k
+ # gradient_k = (p_k - (y_true == k)) * sw
+ gradient_out[i, k] = (p[k] - (y_true[i] == k)) * sample_weight[i]
+
+ free(p)
+
+ return np.asarray(gradient_out)
+
+ def gradient_hessian(
+ self,
+ Y_DTYPE_C[::1] y_true, # IN
+ Y_DTYPE_C[:, :] raw_prediction, # IN
+ Y_DTYPE_C[::1] sample_weight, # IN
+ G_DTYPE_C[:, :] gradient_out, # OUT
+ G_DTYPE_C[:, :] hessian_out, # OUT
+ int n_threads=1
+ ):
+ cdef:
+ int i, k
+ int n_samples = y_true.shape[0]
+ int n_classes = raw_prediction.shape[1]
+ Y_DTYPE_C sum_exps
+ Y_DTYPE_C* p # temporary buffer
+
+ if sample_weight is None:
+ # inner loop over n_classes
+ with nogil, parallel(num_threads=n_threads):
+ # Define private buffer variables as each thread might use its
+ # own.
+ p = malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+
+ for i in prange(n_samples, schedule='static'):
+ sum_exp_minus_max(i, raw_prediction, p)
+ sum_exps = p[n_classes + 1] # p[-1]
+
+ for k in range(n_classes):
+ p[k] /= sum_exps # p_k = y_pred_k = prob of class k
+ # hessian_k = p_k * (1 - p_k)
+ # gradient_k = p_k - (y_true == k)
+ gradient_out[i, k] = p[k] - (y_true[i] == k)
+ hessian_out[i, k] = p[k] * (1. - p[k])
+
+ free(p)
+ else:
+ with nogil, parallel(num_threads=n_threads):
+ p = malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+
+ for i in prange(n_samples, schedule='static'):
+ sum_exp_minus_max(i, raw_prediction, p)
+ sum_exps = p[n_classes + 1] # p[-1]
+
+ for k in range(n_classes):
+ p[k] /= sum_exps # p_k = y_pred_k = prob of class k
+ # gradient_k = (p_k - (y_true == k)) * sw
+ # hessian_k = p_k * (1 - p_k) * sw
+ gradient_out[i, k] = (p[k] - (y_true[i] == k)) * sample_weight[i]
+ hessian_out[i, k] = (p[k] * (1. - p[k])) * sample_weight[i]
+
+ free(p)
+
+ return np.asarray(gradient_out), np.asarray(hessian_out)
+
+
+ # This method simplifies the implementation of hessp in linear models,
+ # i.e. the matrix-vector product of the full hessian, not only of the
+ # diagonal (in the classes) approximation as implemented above.
+ def gradient_proba(
+ self,
+ Y_DTYPE_C[::1] y_true, # IN
+ Y_DTYPE_C[:, :] raw_prediction, # IN
+ Y_DTYPE_C[::1] sample_weight, # IN
+ G_DTYPE_C[:, :] gradient_out, # OUT
+ G_DTYPE_C[:, :] proba_out, # OUT
+ int n_threads=1
+ ):
+ cdef:
+ int i, k
+ int n_samples = y_true.shape[0]
+ int n_classes = raw_prediction.shape[1]
+ Y_DTYPE_C sum_exps
+ Y_DTYPE_C* p # temporary buffer
+
+ if sample_weight is None:
+ # inner loop over n_classes
+ with nogil, parallel(num_threads=n_threads):
+ # Define private buffer variables as each thread might use its
+ # own.
+ p = malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+
+ for i in prange(n_samples, schedule='static'):
+ sum_exp_minus_max(i, raw_prediction, p)
+ sum_exps = p[n_classes + 1] # p[-1]
+
+ for k in range(n_classes):
+ proba_out[i, k] = p[k] / sum_exps # y_pred_k = prob of class k
+ # gradient_k = y_pred_k - (y_true == k)
+ gradient_out[i, k] = proba_out[i, k] - (y_true[i] == k)
+
+ free(p)
+ else:
+ with nogil, parallel(num_threads=n_threads):
+ p = malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+
+ for i in prange(n_samples, schedule='static'):
+ sum_exp_minus_max(i, raw_prediction, p)
+ sum_exps = p[n_classes + 1] # p[-1]
+
+ for k in range(n_classes):
+ proba_out[i, k] = p[k] / sum_exps # y_pred_k = prob of class k
+ # gradient_k = (p_k - (y_true == k)) * sw
+ gradient_out[i, k] = (proba_out[i, k] - (y_true[i] == k)) * sample_weight[i]
+
+ free(p)
+
+ return np.asarray(gradient_out), np.asarray(proba_out)
diff --git a/sklearn/_loss/link.py b/sklearn/_loss/link.py
new file mode 100644
index 0000000000000..18ad5901d1f3c
--- /dev/null
+++ b/sklearn/_loss/link.py
@@ -0,0 +1,261 @@
+"""
+This module contains classes for invertible (and differentiable) link functions.
+"""
+# Author: Christian Lorentzen
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+import numpy as np
+from scipy.special import expit, logit
+from scipy.stats import gmean
+from ..utils.extmath import softmax
+
+
+@dataclass
+class Interval:
+ low: float
+ high: float
+ low_inclusive: bool
+ high_inclusive: bool
+
+ def __post_init__(self):
+ """Check that low <= high"""
+ if self.low > self.high:
+ raise ValueError(
+                f"One must have low <= high; got low={self.low}, high={self.high}."
+ )
+
+ def includes(self, x):
+ """Test whether all values of x are in interval range.
+
+ Parameters
+ ----------
+ x : ndarray
+ Array whose elements are tested to be in interval range.
+
+ Returns
+ -------
+ result : bool
+ """
+ if self.low_inclusive:
+ low = np.greater_equal(x, self.low)
+ else:
+ low = np.greater(x, self.low)
+
+ if not np.all(low):
+ return False
+
+ if self.high_inclusive:
+ high = np.less_equal(x, self.high)
+ else:
+ high = np.less(x, self.high)
+
+ # Note: np.all returns numpy.bool_
+ return bool(np.all(high))
+
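+# For instance, Interval(0, 1, False, False).includes(np.array([0.2, 0.5]))
+# returns True, while it returns False for np.array([0.0, 0.5]) because the
+# lower boundary is excluded.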
+
+def _inclusive_low_high(interval, dtype=np.float64):
+ """Generate values low and high to be within the interval range.
+
+ This is used in tests only.
+
+ Returns
+ -------
+ low, high : tuple
+ The returned values low and high lie within the interval.
+ """
+ eps = 10 * np.finfo(dtype).eps
+ if interval.low == -np.inf:
+ low = -1e10
+ elif interval.low < 0:
+ low = interval.low * (1 - eps) + eps
+ else:
+ low = interval.low * (1 + eps) + eps
+
+ if interval.high == np.inf:
+ high = 1e10
+ elif interval.high < 0:
+ high = interval.high * (1 + eps) - eps
+ else:
+ high = interval.high * (1 - eps) - eps
+
+ return low, high
+
+
+class BaseLink(ABC):
+ """Abstract base class for differentiable, invertible link functions.
+
+ Convention:
+ - link function g: raw_prediction = g(y_pred)
+ - inverse link h: y_pred = h(raw_prediction)
+
+ For (generalized) linear models, `raw_prediction = X @ coef` is the so
+ called linear predictor, and `y_pred = h(raw_prediction)` is the predicted
+ conditional (on X) expected value of the target `y_true`.
+
+ The methods are not implemented as staticmethods in case a link function needs
+ parameters.
+ """
+
+ is_multiclass = False # used for testing only
+
+ # Usually, raw_prediction may be any real number and y_pred is an open
+ # interval.
+ # interval_raw_prediction = Interval(-np.inf, np.inf, False, False)
+ interval_y_pred = Interval(-np.inf, np.inf, False, False)
+
+ @abstractmethod
+ def link(self, y_pred, out=None):
+ """Compute the link function g(y_pred).
+
+ The link function maps (predicted) target values to raw predictions,
+ i.e. `g(y_pred) = raw_prediction`.
+
+ Parameters
+ ----------
+ y_pred : array
+ Predicted target values.
+ out : array
+ A location into which the result is stored. If provided, it must
+ have a shape that the inputs broadcast to. If not provided or None,
+ a freshly-allocated array is returned.
+
+ Returns
+ -------
+ out : array
+ Output array, element-wise link function.
+ """
+
+ @abstractmethod
+ def inverse(self, raw_prediction, out=None):
+ """Compute the inverse link function h(raw_prediction).
+
+ The inverse link function maps raw predictions to predicted target
+ values, i.e. `h(raw_prediction) = y_pred`.
+
+ Parameters
+ ----------
+ raw_prediction : array
+ Raw prediction values (in link space).
+ out : array
+ A location into which the result is stored. If provided, it must
+ have a shape that the inputs broadcast to. If not provided or None,
+ a freshly-allocated array is returned.
+
+ Returns
+ -------
+ out : array
+ Output array, element-wise inverse link function.
+ """
+
+
+class IdentityLink(BaseLink):
+ """The identity link function g(x)=x."""
+
+ def link(self, y_pred, out=None):
+ if out is not None:
+ np.copyto(out, y_pred)
+ return out
+ else:
+ return y_pred
+
+ inverse = link
+
+
+class LogLink(BaseLink):
+ """The log link function g(x)=log(x)."""
+
+ interval_y_pred = Interval(0, np.inf, False, False)
+
+ def link(self, y_pred, out=None):
+ return np.log(y_pred, out=out)
+
+ def inverse(self, raw_prediction, out=None):
+ return np.exp(raw_prediction, out=out)
+
+
+class LogitLink(BaseLink):
+ """The logit link function g(x)=logit(x)."""
+
+ interval_y_pred = Interval(0, 1, False, False)
+
+ def link(self, y_pred, out=None):
+ return logit(y_pred, out=out)
+
+ def inverse(self, raw_prediction, out=None):
+ return expit(raw_prediction, out=out)
+
+
+class MultinomialLogit(BaseLink):
+ """The symmetric multinomial logit function.
+
+ Convention:
+ - y_pred.shape = raw_prediction.shape = (n_samples, n_classes)
+
+ Notes:
+ - The inverse link h is the softmax function.
+ - The sum is over the second axis, i.e. axis=1 (n_classes).
+
+    We have to choose additional constraints in order to make
+
+ y_pred[k] = exp(raw_pred[k]) / sum(exp(raw_pred[k]), k=0..n_classes-1)
+
+ for n_classes classes identifiable and invertible.
+    We choose the symmetric side constraint where the geometric mean response
+    is set as the reference category, see [2]:
+
+ The symmetric multinomial logit link function for a single data point is
+ then defined as
+
+ raw_prediction[k] = g(y_pred[k]) = log(y_pred[k]/gmean(y_pred))
+ = log(y_pred[k]) - mean(log(y_pred)).
+
+ Note that this is equivalent to the definition in [1] and implies mean
+ centered raw predictions:
+
+ sum(raw_prediction[k], k=0..n_classes-1) = 0.
+
+ For linear models with raw_prediction = X @ coef, this corresponds to
+ sum(coef[k], k=0..n_classes-1) = 0, i.e. the sum over classes for every
+ feature is zero.
+
+    References
+ ---------
+ .. [1] Friedman, Jerome; Hastie, Trevor; Tibshirani, Robert. "Additive
+ logistic regression: a statistical view of boosting" Ann. Statist.
+ 28 (2000), no. 2, 337--407. doi:10.1214/aos/1016218223.
+ https://projecteuclid.org/euclid.aos/1016218223
+
+ .. [2] Zahid, Faisal Maqbool and Gerhard Tutz. "Ridge estimation for
+ multinomial logit models with symmetric side constraints."
+ Computational Statistics 28 (2013): 1017-1034.
+ http://epub.ub.uni-muenchen.de/11001/1/tr067.pdf
+ """
+
+ is_multiclass = True
+ interval_y_pred = Interval(0, 1, False, False)
+
+ def symmetrize_raw_prediction(self, raw_prediction):
+ return raw_prediction - np.mean(raw_prediction, axis=1)[:, np.newaxis]
+
+ def link(self, y_pred, out=None):
+ # geometric mean as reference category
+ gm = gmean(y_pred, axis=1)
+ return np.log(y_pred / gm[:, np.newaxis], out=out)
+
+ def inverse(self, raw_prediction, out=None):
+ if out is None:
+ return softmax(raw_prediction, copy=True)
+ else:
+ np.copyto(out, raw_prediction)
+ softmax(out, copy=False)
+ return out
+
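+# Worked example (rounded): for y_pred = [[0.2, 0.3, 0.5]], the row-wise gmean
+# is about 0.311, so link(y_pred) is approximately [[-0.44, -0.04, 0.48]],
+# which sums to zero; inverse() (the softmax) maps it back to [[0.2, 0.3, 0.5]].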
+
+_LINKS = {
+ "identity": IdentityLink,
+ "log": LogLink,
+ "logit": LogitLink,
+ "multinomial_logit": MultinomialLogit,
+}
diff --git a/sklearn/_loss/loss.py b/sklearn/_loss/loss.py
new file mode 100644
index 0000000000000..a394bd9de06c3
--- /dev/null
+++ b/sklearn/_loss/loss.py
@@ -0,0 +1,924 @@
+"""
+This module contains loss classes suitable for fitting.
+
+It is not part of the public API.
+Specific losses are used for regression, binary classification or multiclass
+classification.
+"""
+# Goals:
+# - Provide a common private module for loss functions/classes.
+# - To be used in:
+# - LogisticRegression
+# - PoissonRegressor, GammaRegressor, TweedieRegressor
+# - HistGradientBoostingRegressor, HistGradientBoostingClassifier
+# - GradientBoostingRegressor, GradientBoostingClassifier
+# - SGDRegressor, SGDClassifier
+# - Replace link module of GLMs.
+
+import numpy as np
+from scipy.special import xlogy
+from ._loss import (
+ CyHalfSquaredError,
+ CyAbsoluteError,
+ CyPinballLoss,
+ CyHalfPoissonLoss,
+ CyHalfGammaLoss,
+ CyHalfTweedieLoss,
+ CyHalfBinomialLoss,
+ CyHalfMultinomialLoss,
+)
+from .link import (
+ Interval,
+ IdentityLink,
+ LogLink,
+ LogitLink,
+ MultinomialLogit,
+)
+from ..utils._readonly_array_wrapper import ReadonlyArrayWrapper
+from ..utils.stats import _weighted_percentile
+
+
+# Note: The shape of raw_prediction for multiclass classification is
+# - GradientBoostingClassifier: (n_samples, n_classes)
+# - HistGradientBoostingClassifier: (n_classes, n_samples)
+#
+# Note: Instead of inheritance like
+#
+# class BaseLoss(BaseLink, CyLossFunction):
+# ...
+#
+# # Note: Naturally, we would inherit in the following order
+# # class HalfSquaredError(IdentityLink, CyHalfSquaredError, BaseLoss)
+# # But because of https://github.com/cython/cython/issues/4350 we set BaseLoss as
+# # the last one. This, of course, changes the MRO.
+# class HalfSquaredError(IdentityLink, CyHalfSquaredError, BaseLoss):
+#
+# we use composition. This improves maintainability by avoiding the
+# above-mentioned Cython edge case and makes it easier to understand which
+# method calls which code.
+class BaseLoss:
+ """Base class for a loss function of 1-dimensional targets.
+
+ Conventions:
+
+ - y_true.shape = sample_weight.shape = (n_samples,)
+ - y_pred.shape = raw_prediction.shape = (n_samples,)
+ - If is_multiclass is true (multiclass classification), then
+ y_pred.shape = raw_prediction.shape = (n_samples, n_classes)
+ Note that this corresponds to the return value of decision_function.
+
+ y_true, y_pred, sample_weight and raw_prediction must either be all float64
+ or all float32.
+ gradient and hessian must be either both float64 or both float32.
+
+ Note that y_pred = link.inverse(raw_prediction).
+
+ Specific loss classes can inherit specific link classes to satisfy
+ BaseLink's abstractmethods.
+
+ Parameters
+ ----------
+ sample_weight : {None, ndarray}
+ If sample_weight is None, the hessian might be constant.
+ n_classes : {None, int}
+ The number of classes for classification, else None.
+
+ Attributes
+ ----------
+ closs: CyLossFunction
+ link : BaseLink
+ interval_y_true : Interval
+ Valid interval for y_true
+ interval_y_pred : Interval
+ Valid Interval for y_pred
+ differentiable : bool
+ Indicates whether or not loss function is differentiable in
+ raw_prediction everywhere.
+ need_update_leaves_values : bool
+        Indicates whether decision trees in gradient boosting need to update
+        leaf values after having been fit to the (negative) gradients.
+ approx_hessian : bool
+        Indicates whether the hessian is approximated or exact. If
+        approximated, it should be larger than or equal to the exact one.
+ constant_hessian : bool
+ Indicates whether the hessian is one for this loss.
+ is_multiclass : bool
+ Indicates whether n_classes > 2 is allowed.
+ """
+
+ # For decision trees:
+    # This variable indicates whether the loss requires the leaf values to
+ # be updated once the tree has been trained. The trees are trained to
+ # predict a Newton-Raphson step (see grower._finalize_leaf()). But for
+ # some losses (e.g. least absolute deviation) we need to adjust the tree
+ # values to account for the "line search" of the gradient descent
+ # procedure. See the original paper Greedy Function Approximation: A
+ # Gradient Boosting Machine by Friedman
+ # (https://statweb.stanford.edu/~jhf/ftp/trebst.pdf) for the theory.
+ need_update_leaves_values = False
+ differentiable = True
+ is_multiclass = False
+
+ def __init__(self, closs, link, n_classes=1):
+ self.closs = closs
+ self.link = link
+ self.approx_hessian = False
+ self.constant_hessian = False
+ self.n_classes = n_classes
+ self.interval_y_true = Interval(-np.inf, np.inf, False, False)
+ self.interval_y_pred = self.link.interval_y_pred
+
+ def in_y_true_range(self, y):
+ """Return True if y is in the valid range of y_true.
+
+ Parameters
+ ----------
+ y : ndarray
+ """
+ return self.interval_y_true.includes(y)
+
+ def in_y_pred_range(self, y):
+ """Return True if y is in the valid range of y_pred.
+
+ Parameters
+ ----------
+ y : ndarray
+ """
+ return self.interval_y_pred.includes(y)
+
+ def loss(
+ self,
+ y_true,
+ raw_prediction,
+ sample_weight=None,
+ loss_out=None,
+ n_threads=1,
+ ):
+ """Compute the pointwise loss value for each input.
+
+ Parameters
+ ----------
+ y_true : C-contiguous array of shape (n_samples,)
+ Observed, true target values.
+ raw_prediction : C-contiguous array of shape (n_samples,) or array of \
+ shape (n_samples, n_classes)
+ Raw prediction values (in link space).
+ sample_weight : None or C-contiguous array of shape (n_samples,)
+ Sample weights.
+ loss_out : None or C-contiguous array of shape (n_samples,)
+ A location into which the result is stored. If None, a new array
+ might be created.
+ n_threads : int, default=1
+ Might use openmp thread parallelism.
+
+ Returns
+ -------
+ loss : array of shape (n_samples,)
+ Element-wise loss function.
+ """
+ if loss_out is None:
+ loss_out = np.empty_like(y_true)
+ # Be graceful to shape (n_samples, 1) -> (n_samples,)
+ if raw_prediction.ndim == 2 and raw_prediction.shape[1] == 1:
+ raw_prediction = raw_prediction.squeeze(1)
+
+ y_true = ReadonlyArrayWrapper(y_true)
+ raw_prediction = ReadonlyArrayWrapper(raw_prediction)
+ if sample_weight is not None:
+ sample_weight = ReadonlyArrayWrapper(sample_weight)
+ return self.closs.loss(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ loss_out=loss_out,
+ n_threads=n_threads,
+ )
+
+ def loss_gradient(
+ self,
+ y_true,
+ raw_prediction,
+ sample_weight=None,
+ loss_out=None,
+ gradient_out=None,
+ n_threads=1,
+ ):
+ """Compute loss and gradient w.r.t. raw_prediction for each input.
+
+ Parameters
+ ----------
+ y_true : C-contiguous array of shape (n_samples,)
+ Observed, true target values.
+ raw_prediction : C-contiguous array of shape (n_samples,) or array of \
+ shape (n_samples, n_classes)
+ Raw prediction values (in link space).
+ sample_weight : None or C-contiguous array of shape (n_samples,)
+ Sample weights.
+ loss_out : None or C-contiguous array of shape (n_samples,)
+ A location into which the loss is stored. If None, a new array
+ might be created.
+ gradient_out : None or C-contiguous array of shape (n_samples,) or array \
+ of shape (n_samples, n_classes)
+ A location into which the gradient is stored. If None, a new array
+ might be created.
+ n_threads : int, default=1
+ Might use openmp thread parallelism.
+
+ Returns
+ -------
+ loss : array of shape (n_samples,)
+ Element-wise loss function.
+
+ gradient : array of shape (n_samples,) or (n_samples, n_classes)
+ Element-wise gradients.
+ """
+ if loss_out is None:
+ if gradient_out is None:
+ loss_out = np.empty_like(y_true)
+ gradient_out = np.empty_like(raw_prediction)
+ else:
+ loss_out = np.empty_like(y_true, dtype=gradient_out.dtype)
+ elif gradient_out is None:
+ gradient_out = np.empty_like(raw_prediction, dtype=loss_out.dtype)
+
+ # Be graceful to shape (n_samples, 1) -> (n_samples,)
+ if raw_prediction.ndim == 2 and raw_prediction.shape[1] == 1:
+ raw_prediction = raw_prediction.squeeze(1)
+ if gradient_out.ndim == 2 and gradient_out.shape[1] == 1:
+ gradient_out = gradient_out.squeeze(1)
+
+ y_true = ReadonlyArrayWrapper(y_true)
+ raw_prediction = ReadonlyArrayWrapper(raw_prediction)
+ if sample_weight is not None:
+ sample_weight = ReadonlyArrayWrapper(sample_weight)
+ return self.closs.loss_gradient(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ loss_out=loss_out,
+ gradient_out=gradient_out,
+ n_threads=n_threads,
+ )
+
+ def gradient(
+ self,
+ y_true,
+ raw_prediction,
+ sample_weight=None,
+ gradient_out=None,
+ n_threads=1,
+ ):
+ """Compute gradient of loss w.r.t raw_prediction for each input.
+
+ Parameters
+ ----------
+ y_true : C-contiguous array of shape (n_samples,)
+ Observed, true target values.
+ raw_prediction : C-contiguous array of shape (n_samples,) or array of \
+ shape (n_samples, n_classes)
+ Raw prediction values (in link space).
+ sample_weight : None or C-contiguous array of shape (n_samples,)
+ Sample weights.
+ gradient_out : None or C-contiguous array of shape (n_samples,) or array \
+ of shape (n_samples, n_classes)
+ A location into which the result is stored. If None, a new array
+ might be created.
+ n_threads : int, default=1
+ Might use openmp thread parallelism.
+
+ Returns
+ -------
+ gradient : array of shape (n_samples,) or (n_samples, n_classes)
+ Element-wise gradients.
+ """
+ if gradient_out is None:
+ gradient_out = np.empty_like(raw_prediction)
+
+ # Be graceful to shape (n_samples, 1) -> (n_samples,)
+ if raw_prediction.ndim == 2 and raw_prediction.shape[1] == 1:
+ raw_prediction = raw_prediction.squeeze(1)
+ if gradient_out.ndim == 2 and gradient_out.shape[1] == 1:
+ gradient_out = gradient_out.squeeze(1)
+
+ y_true = ReadonlyArrayWrapper(y_true)
+ raw_prediction = ReadonlyArrayWrapper(raw_prediction)
+ if sample_weight is not None:
+ sample_weight = ReadonlyArrayWrapper(sample_weight)
+ return self.closs.gradient(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ gradient_out=gradient_out,
+ n_threads=n_threads,
+ )
+
+ def gradient_hessian(
+ self,
+ y_true,
+ raw_prediction,
+ sample_weight=None,
+ gradient_out=None,
+ hessian_out=None,
+ n_threads=1,
+ ):
+ """Compute gradient and hessian of loss w.r.t raw_prediction.
+
+ Parameters
+ ----------
+ y_true : C-contiguous array of shape (n_samples,)
+ Observed, true target values.
+ raw_prediction : C-contiguous array of shape (n_samples,) or array of \
+ shape (n_samples, n_classes)
+ Raw prediction values (in link space).
+ sample_weight : None or C-contiguous array of shape (n_samples,)
+ Sample weights.
+ gradient_out : None or C-contiguous array of shape (n_samples,) or array \
+ of shape (n_samples, n_classes)
+ A location into which the gradient is stored. If None, a new array
+ might be created.
+ hessian_out : None or C-contiguous array of shape (n_samples,) or array \
+ of shape (n_samples, n_classes)
+ A location into which the hessian is stored. If None, a new array
+ might be created.
+ n_threads : int, default=1
+ Might use openmp thread parallelism.
+
+ Returns
+ -------
+ gradient : arrays of shape (n_samples,) or (n_samples, n_classes)
+ Element-wise gradients.
+
+ hessian : arrays of shape (n_samples,) or (n_samples, n_classes)
+ Element-wise hessians.
+ """
+ if gradient_out is None:
+ if hessian_out is None:
+ gradient_out = np.empty_like(raw_prediction)
+ hessian_out = np.empty_like(raw_prediction)
+ else:
+ gradient_out = np.empty_like(hessian_out)
+ elif hessian_out is None:
+ hessian_out = np.empty_like(gradient_out)
+
+ # Be graceful to shape (n_samples, 1) -> (n_samples,)
+ if raw_prediction.ndim == 2 and raw_prediction.shape[1] == 1:
+ raw_prediction = raw_prediction.squeeze(1)
+ if gradient_out.ndim == 2 and gradient_out.shape[1] == 1:
+ gradient_out = gradient_out.squeeze(1)
+ if hessian_out.ndim == 2 and hessian_out.shape[1] == 1:
+ hessian_out = hessian_out.squeeze(1)
+
+ y_true = ReadonlyArrayWrapper(y_true)
+ raw_prediction = ReadonlyArrayWrapper(raw_prediction)
+ if sample_weight is not None:
+ sample_weight = ReadonlyArrayWrapper(sample_weight)
+ return self.closs.gradient_hessian(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ gradient_out=gradient_out,
+ hessian_out=hessian_out,
+ n_threads=n_threads,
+ )
+
+ def __call__(self, y_true, raw_prediction, sample_weight=None, n_threads=1):
+ """Compute the weighted average loss.
+
+ Parameters
+ ----------
+ y_true : C-contiguous array of shape (n_samples,)
+ Observed, true target values.
+ raw_prediction : C-contiguous array of shape (n_samples,) or array of \
+ shape (n_samples, n_classes)
+ Raw prediction values (in link space).
+ sample_weight : None or C-contiguous array of shape (n_samples,)
+ Sample weights.
+ n_threads : int, default=1
+ Might use openmp thread parallelism.
+
+ Returns
+ -------
+ loss : float
+ Mean or averaged loss function.
+ """
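+        # sample_weight is applied exactly once, as the `weights` argument of
+        # np.average below, and is therefore not passed on to self.loss.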
+ return np.average(
+ self.loss(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=None,
+ loss_out=None,
+ n_threads=n_threads,
+ ),
+ weights=sample_weight,
+ )
+
+ def fit_intercept_only(self, y_true, sample_weight=None):
+ """Compute raw_prediction of an intercept-only model.
+
+ This can be used as initial estimates of predictions, i.e. before the
+ first iteration in fit.
+
+ Parameters
+ ----------
+ y_true : array-like of shape (n_samples,)
+ Observed, true target values.
+ sample_weight : None or array of shape (n_samples,)
+ Sample weights.
+
+ Returns
+ -------
+ raw_prediction : float or (n_classes,)
+ Raw predictions of an intercept-only model.
+ """
+ # As default, take weighted average of the target over the samples
+ # axis=0 and then transform into link-scale (raw_prediction).
+ y_pred = np.average(y_true, weights=sample_weight, axis=0)
+ eps = 10 * np.finfo(y_pred.dtype).eps
+
+ if self.interval_y_pred.low == -np.inf:
+ a_min = None
+ elif self.interval_y_pred.low_inclusive:
+ a_min = self.interval_y_pred.low
+ else:
+ a_min = self.interval_y_pred.low + eps
+
+ if self.interval_y_pred.high == np.inf:
+ a_max = None
+ elif self.interval_y_pred.high_inclusive:
+ a_max = self.interval_y_pred.high
+ else:
+ a_max = self.interval_y_pred.high - eps
+
+ if a_min is None and a_max is None:
+ return self.link.link(y_pred)
+ else:
+ return self.link.link(np.clip(y_pred, a_min, a_max))
+
+ def constant_to_optimal_zero(self, y_true, sample_weight=None):
+ """Calculate term dropped in loss.
+
+ With this term added, the loss of perfect predictions is zero.
+ """
+ return np.zeros_like(y_true)
+
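+# Illustrative sketch of the array API shared by the concrete losses below
+# (values chosen only for demonstration); preallocated ``*_out`` arrays are
+# filled in place and also returned:
+#
+#     loss = HalfSquaredError()
+#     y_true = np.array([1.0, 0.0])
+#     raw = np.array([3.0, 0.5])
+#     grad_out = np.empty_like(raw)
+#     loss.gradient(y_true=y_true, raw_prediction=raw, gradient_out=grad_out)
+#     # grad_out now holds raw - y_true == array([2. , 0.5])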
+
+# Note: Naturally, we would inherit in the following order
+# class HalfSquaredError(IdentityLink, CyHalfSquaredError, BaseLoss)
+# But because of https://github.com/cython/cython/issues/4350 we
+# set BaseLoss as the last one. This, of course, changes the MRO.
+class HalfSquaredError(BaseLoss):
+ """Half squared error with identity link, for regression.
+
+ Domain:
+ y_true and y_pred all real numbers
+
+ Link:
+ y_pred = raw_prediction
+
+ For a given sample x_i, half squared error is defined as::
+
+ loss(x_i) = 0.5 * (y_true_i - raw_prediction_i)**2
+
+ The factor of 0.5 simplifies the computation of gradients and results in a
+ unit hessian (and is consistent with what is done in LightGBM). It is also
+ half the Normal distribution deviance.
+ """
+
+ def __init__(self, sample_weight=None):
+ super().__init__(closs=CyHalfSquaredError(), link=IdentityLink())
+ self.constant_hessian = sample_weight is None
+
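+# Illustrative sketch (values chosen only for demonstration): for y_true=1 and
+# raw_prediction=3, the per-sample loss is 0.5 * (1 - 3)**2 = 2.0; calling the
+# loss object returns the (weighted) average over samples:
+#
+#     loss = HalfSquaredError()
+#     loss.loss(y_true=np.array([1.0]), raw_prediction=np.array([3.0]))
+#     # -> array([2.])
+#     loss(y_true=np.array([1.0]), raw_prediction=np.array([3.0]))
+#     # -> 2.0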
+
+class AbsoluteError(BaseLoss):
+ """Absolute error with identity link, for regression.
+
+ Domain:
+ y_true and y_pred all real numbers
+
+ Link:
+ y_pred = raw_prediction
+
+ For a given sample x_i, the absolute error is defined as::
+
+ loss(x_i) = |y_true_i - raw_prediction_i|
+ """
+
+ differentiable = False
+ need_update_leaves_values = True
+
+ def __init__(self, sample_weight=None):
+ super().__init__(closs=CyAbsoluteError(), link=IdentityLink())
+ self.approx_hessian = True
+ self.constant_hessian = sample_weight is None
+
+ def fit_intercept_only(self, y_true, sample_weight=None):
+ """Compute raw_prediction of an intercept-only model.
+
+ This is the weighted median of the target, i.e. over the samples
+ axis=0.
+ """
+ if sample_weight is None:
+ return np.median(y_true, axis=0)
+ else:
+ return _weighted_percentile(y_true, sample_weight, 50)
+
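+# Illustrative sketch (values chosen only for demonstration): the absolute
+# error of y_true=1 and raw_prediction=3 is |1 - 3| = 2, and the optimal
+# intercept-only prediction is the (weighted) median of the target:
+#
+#     loss = AbsoluteError()
+#     loss.loss(y_true=np.array([1.0]), raw_prediction=np.array([3.0]))
+#     # -> array([2.])
+#     loss.fit_intercept_only(y_true=np.array([1.0, 2.0, 10.0]))
+#     # -> 2.0, i.e. np.median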
+
+class PinballLoss(BaseLoss):
+ """Quantile loss aka pinball loss, for regression.
+
+ Domain:
+ y_true and y_pred all real numbers
+ quantile in (0, 1)
+
+ Link:
+ y_pred = raw_prediction
+
+ For a given sample x_i, the pinball loss is defined as::
+
+ loss(x_i) = rho_{quantile}(y_true_i - raw_prediction_i)
+
+ rho_{quantile}(u) = u * (quantile - 1_{u<0})
+ = -u * (1 - quantile) if u < 0
+ = u * quantile if u >= 0
+
+ Note: 2 * PinballLoss(quantile=0.5) equals AbsoluteError().
+
+ Additional Attributes
+ ---------------------
+ quantile : float
+ The quantile to be estimated. Must be in range (0, 1).
+ """
+
+ differentiable = False
+ need_update_leaves_values = True
+
+ def __init__(self, sample_weight=None, quantile=0.5):
+ if quantile <= 0 or quantile >= 1:
+ raise ValueError(
+ "PinballLoss aka quantile loss only accepts "
+ f"0 < quantile < 1; {quantile} was given."
+ )
+ super().__init__(
+ closs=CyPinballLoss(quantile=float(quantile)),
+ link=IdentityLink(),
+ )
+ self.approx_hessian = True
+ self.constant_hessian = sample_weight is None
+
+ def fit_intercept_only(self, y_true, sample_weight=None):
+ """Compute raw_prediction of an intercept-only model.
+
+ This is the weighted quantile of the target at level self.closs.quantile,
+ i.e. over the samples axis=0.
+ """
+ if sample_weight is None:
+ return np.percentile(y_true, 100 * self.closs.quantile, axis=0)
+ else:
+ return _weighted_percentile(
+ y_true, sample_weight, 100 * self.closs.quantile
+ )
+
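+# Illustrative sketch (values chosen only for demonstration): with
+# quantile=0.25, an under-prediction u = y_true - raw_prediction = 4 costs
+# u * quantile = 1.0 while an over-prediction u = -4 costs
+# -u * (1 - quantile) = 3.0:
+#
+#     loss = PinballLoss(quantile=0.25)
+#     loss.loss(y_true=np.array([5.0, 1.0]), raw_prediction=np.array([1.0, 5.0]))
+#     # -> array([1., 3.])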
+
+class HalfPoissonLoss(BaseLoss):
+ """Half Poisson deviance loss with log-link, for regression.
+
+ Domain:
+ y_true in non-negative real numbers
+ y_pred in positive real numbers
+
+ Link:
+ y_pred = exp(raw_prediction)
+
+ For a given sample x_i, half the Poisson deviance is defined as::
+
+ loss(x_i) = y_true_i * log(y_true_i/exp(raw_prediction_i))
+ - y_true_i + exp(raw_prediction_i)
+
+ Half the Poisson deviance is actually the negative log-likelihood up to
+ constant terms (not involving raw_prediction) and simplifies the
+ computation of the gradients.
+ We also skip the constant term `y_true_i * log(y_true_i) - y_true_i`.
+ """
+
+ def __init__(self, sample_weight=None):
+ super().__init__(closs=CyHalfPoissonLoss(), link=LogLink())
+ self.interval_y_true = Interval(0, np.inf, True, False)
+
+ def constant_to_optimal_zero(self, y_true, sample_weight=None):
+ term = xlogy(y_true, y_true) - y_true
+ if sample_weight is not None:
+ term *= sample_weight
+ return term
+
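+# Illustrative sketch (values chosen only for demonstration): for y_true=2 and
+# the perfect raw_prediction=log(2), the implemented loss is
+# exp(log(2)) - 2 * log(2) = 2 - 2 * log(2) ~= 0.6137, and adding the dropped
+# constant term 2 * log(2) - 2 from constant_to_optimal_zero gives zero:
+#
+#     loss = HalfPoissonLoss()
+#     loss.loss(y_true=np.array([2.0]), raw_prediction=np.log([2.0]))
+#     # -> array([0.61370564])
+#     loss.constant_to_optimal_zero(y_true=np.array([2.0]))
+#     # -> array([-0.61370564])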
+
+class HalfGammaLoss(BaseLoss):
+ """Half Gamma deviance loss with log-link, for regression.
+
+ Domain:
+ y_true and y_pred in positive real numbers
+
+ Link:
+ y_pred = exp(raw_prediction)
+
+ For a given sample x_i, half Gamma deviance loss is defined as::
+
+ loss(x_i) = log(exp(raw_prediction_i)/y_true_i)
+ + y_true_i/exp(raw_prediction_i) - 1
+
+ Half the Gamma deviance is actually proportional to the negative log-
+ likelihood up to constant terms (not involving raw_prediction) and
+ simplifies the computation of the gradients.
+ We also skip the constant term `-log(y_true_i) - 1`.
+ """
+
+ def __init__(self, sample_weight=None):
+ super().__init__(closs=CyHalfGammaLoss(), link=LogLink())
+ self.interval_y_true = Interval(0, np.inf, False, False)
+
+ def constant_to_optimal_zero(self, y_true, sample_weight=None):
+ term = -np.log(y_true) - 1
+ if sample_weight is not None:
+ term *= sample_weight
+ return term
+
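+# Illustrative sketch (values chosen only for demonstration): for y_true=2 and
+# the perfect raw_prediction=log(2), the implemented loss is
+# log(2) + 2 / exp(log(2)) = log(2) + 1, which cancels exactly with the
+# dropped constant term -log(2) - 1 from constant_to_optimal_zero:
+#
+#     loss = HalfGammaLoss()
+#     loss.loss(y_true=np.array([2.0]), raw_prediction=np.log([2.0]))
+#     # -> array([1.69314718])
+#     loss.constant_to_optimal_zero(y_true=np.array([2.0]))
+#     # -> array([-1.69314718])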
+
+class HalfTweedieLoss(BaseLoss):
+ """Half Tweedie deviance loss with log-link, for regression.
+
+ Domain:
+ y_true in real numbers for power <= 0
+ y_true in non-negative real numbers for 0 < power < 2
+ y_true in positive real numbers for 2 <= power
+ y_pred in positive real numbers
+ power in real numbers
+
+ Link:
+ y_pred = exp(raw_prediction)
+
+ For a given sample x_i, half Tweedie deviance loss with p=power is defined
+ as::
+
+ loss(x_i) = max(y_true_i, 0)**(2-p) / (1-p) / (2-p)
+ - y_true_i * exp(raw_prediction_i)**(1-p) / (1-p)
+ + exp(raw_prediction_i)**(2-p) / (2-p)
+
+ Taking the limits for p=0, 1, 2 gives HalfSquaredError with a log link,
+ HalfPoissonLoss and HalfGammaLoss.
+
+ We also skip constant terms, but those are different for p=0, 1, 2.
+ Therefore, the loss is not continuous in `power`.
+
+ Note furthermore that although no Tweedie distribution exists for
+ 0 < power < 1, it still gives a strictly consistent scoring function for
+ the expectation.
+ """
+
+ def __init__(self, sample_weight=None, power=1.5):
+ super().__init__(
+ closs=CyHalfTweedieLoss(power=float(power)),
+ link=LogLink(),
+ )
+ if self.closs.power <= 0:
+ self.interval_y_true = Interval(-np.inf, np.inf, False, False)
+ elif self.closs.power < 2:
+ self.interval_y_true = Interval(0, np.inf, True, False)
+ else:
+ self.interval_y_true = Interval(0, np.inf, False, False)
+
+ def constant_to_optimal_zero(self, y_true, sample_weight=None):
+ if self.closs.power == 0:
+ return HalfSquaredError().constant_to_optimal_zero(
+ y_true=y_true, sample_weight=sample_weight
+ )
+ elif self.closs.power == 1:
+ return HalfPoissonLoss().constant_to_optimal_zero(
+ y_true=y_true, sample_weight=sample_weight
+ )
+ elif self.closs.power == 2:
+ return HalfGammaLoss().constant_to_optimal_zero(
+ y_true=y_true, sample_weight=sample_weight
+ )
+ else:
+ p = self.closs.power
+ term = np.power(np.maximum(y_true, 0), 2 - p) / (1 - p) / (2 - p)
+ if sample_weight is not None:
+ term *= sample_weight
+ return term
+
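+# Illustrative sketch (values chosen only for demonstration): with power=3,
+# y_true=2 and raw_prediction=log(4), the implemented loss (y-only constant
+# term dropped) is -2 * 4**(1 - 3) / (1 - 3) + 4**(2 - 3) / (2 - 3)
+# = 1/16 - 1/4 = -3/16:
+#
+#     loss = HalfTweedieLoss(power=3)
+#     loss.loss(y_true=np.array([2.0]), raw_prediction=np.log([4.0]))
+#     # -> array([-0.1875])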
+
+class HalfBinomialLoss(BaseLoss):
+ """Half Binomial deviance loss with logit link, for binary classification.
+
+ This is also known as binary cross-entropy, log-loss and logistic loss.
+
+ Domain:
+ y_true in [0, 1], i.e. regression on the unit interval
+ y_pred in (0, 1), i.e. boundaries excluded
+
+ Link:
+ y_pred = expit(raw_prediction)
+
+ For a given sample x_i, half Binomial deviance is defined as the negative
+ log-likelihood of the Binomial/Bernoulli distribution and can be expressed
+ as::
+
+ loss(x_i) = log(1 + exp(raw_pred_i)) - y_true_i * raw_pred_i
+
+ See The Elements of Statistical Learning, by Hastie, Tibshirani, Friedman,
+ section 4.4.1 (about logistic regression).
+
+ Note that the formulation works for classification, y = {0, 1}, as well as
+ logistic regression, y = [0, 1].
+ If you add `constant_to_optimal_zero` to the loss, you get half the
+ Bernoulli/binomial deviance.
+ """
+
+ def __init__(self, sample_weight=None):
+ super().__init__(
+ closs=CyHalfBinomialLoss(),
+ link=LogitLink(),
+ n_classes=2,
+ )
+ self.interval_y_true = Interval(0, 1, True, True)
+
+ def constant_to_optimal_zero(self, y_true, sample_weight=None):
+ # This is non-zero only if y_true is neither 0 nor 1.
+ term = xlogy(y_true, y_true) + xlogy(1 - y_true, 1 - y_true)
+ if sample_weight is not None:
+ term *= sample_weight
+ return term
+
+ def predict_proba(self, raw_prediction):
+ """Predict probabilities.
+
+ Parameters
+ ----------
+ raw_prediction : array of shape (n_samples,) or (n_samples, 1)
+ Raw prediction values (in link space).
+
+ Returns
+ -------
+ proba : array of shape (n_samples, 2)
+ Element-wise class probabilities.
+ """
+ # Be graceful to shape (n_samples, 1) -> (n_samples,)
+ if raw_prediction.ndim == 2 and raw_prediction.shape[1] == 1:
+ raw_prediction = raw_prediction.squeeze(1)
+ proba = np.empty((raw_prediction.shape[0], 2), dtype=raw_prediction.dtype)
+ proba[:, 1] = self.link.inverse(raw_prediction)
+ proba[:, 0] = 1 - proba[:, 1]
+ return proba
+
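+# Illustrative sketch (values chosen only for demonstration): a raw prediction
+# of 0 corresponds to expit(0) = 0.5, i.e. equal class probabilities, and the
+# per-sample loss for y_true=1 is log(1 + exp(0)) - 1 * 0 = log(2):
+#
+#     loss = HalfBinomialLoss()
+#     loss.predict_proba(raw_prediction=np.array([0.0]))
+#     # -> array([[0.5, 0.5]])
+#     loss.loss(y_true=np.array([1.0]), raw_prediction=np.array([0.0]))
+#     # -> array([0.69314718])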
+
+class HalfMultinomialLoss(BaseLoss):
+ """Categorical cross-entropy loss, for multiclass classification.
+
+ Domain:
+ y_true in {0, 1, 2, 3, .., n_classes - 1}
+ y_pred has n_classes elements, each element in (0, 1)
+
+ Link:
+ y_pred = softmax(raw_prediction)
+
+ Note: We assume y_true to be already label encoded. The inverse link is
+ softmax. But the full link function is the symmetric multinomial logit
+ function.
+
+ For a given sample x_i, the categorical cross-entropy loss is defined as
+ the negative log-likelihood of the multinomial distribution; it
+ generalizes the binary cross-entropy to more than 2 classes::
+
+ loss_i = log(sum(exp(raw_pred_{i, k}), k=0..n_classes-1))
+ - sum(y_true_{i, k} * raw_pred_{i, k}, k=0..n_classes-1)
+
+ See [1].
+
+ Note that for the hessian, we calculate only the diagonal part in the
+ classes: If the full hessian for classes k and l and sample i is H_i_k_l,
+ we calculate H_i_k_k, i.e. k=l.
+
+ References
+ ----------
+ .. [1] Simon, Noah, J. Friedman and T. Hastie.
+ "A Blockwise Descent Algorithm for Group-penalized Multiresponse and
+ Multinomial Regression."
+ https://arxiv.org/pdf/1311.6529.pdf
+ """
+
+ is_multiclass = True
+
+ def __init__(self, sample_weight=None, n_classes=3):
+ super().__init__(
+ closs=CyHalfMultinomialLoss(),
+ link=MultinomialLogit(),
+ n_classes=n_classes,
+ )
+ self.interval_y_true = Interval(0, np.inf, True, False)
+ self.interval_y_pred = Interval(0, 1, False, False)
+
+ def in_y_true_range(self, y):
+ """Return True if y is in the valid range of y_true.
+
+ Parameters
+ ----------
+ y : ndarray
+ """
+ return self.interval_y_true.includes(y) and np.all(y.astype(int) == y)
+
+ def fit_intercept_only(self, y_true, sample_weight=None):
+ """Compute raw_prediction of an intercept-only model.
+
+ This is the softmax of the weighted average of the target, i.e. over
+ the samples axis=0.
+ """
+ out = np.zeros(self.n_classes, dtype=y_true.dtype)
+ eps = np.finfo(y_true.dtype).eps
+ for k in range(self.n_classes):
+ out[k] = np.average(y_true == k, weights=sample_weight, axis=0)
+ out[k] = np.clip(out[k], eps, 1 - eps)
+ return self.link.link(out[None, :]).reshape(-1)
+
+ def predict_proba(self, raw_prediction):
+ """Predict probabilities.
+
+ Parameters
+ ----------
+ raw_prediction : array of shape (n_samples, n_classes)
+ Raw prediction values (in link space).
+
+ Returns
+ -------
+ proba : array of shape (n_samples, n_classes)
+ Element-wise class probabilities.
+ """
+ return self.link.inverse(raw_prediction)
+
+ def gradient_proba(
+ self,
+ y_true,
+ raw_prediction,
+ sample_weight=None,
+ gradient_out=None,
+ proba_out=None,
+ n_threads=1,
+ ):
+ """Compute gradient and class probabilities fow raw_prediction.
+
+ Parameters
+ ----------
+ y_true : C-contiguous array of shape (n_samples,)
+ Observed, true target values.
+ raw_prediction : array of shape (n_samples, n_classes)
+ Raw prediction values (in link space).
+ sample_weight : None or C-contiguous array of shape (n_samples,)
+ Sample weights.
+ gradient_out : None or array of shape (n_samples, n_classes)
+ A location into which the gradient is stored. If None, a new array
+ might be created.
+ proba_out : None or array of shape (n_samples, n_classes)
+ A location into which the class probabilities are stored. If None,
+ a new array might be created.
+ n_threads : int, default=1
+ Might use openmp thread parallelism.
+
+ Returns
+ -------
+ gradient : array of shape (n_samples, n_classes)
+ Element-wise gradients.
+
+ proba : array of shape (n_samples, n_classes)
+ Element-wise class probabilities.
+ """
+ if gradient_out is None:
+ if proba_out is None:
+ gradient_out = np.empty_like(raw_prediction)
+ proba_out = np.empty_like(raw_prediction)
+ else:
+ gradient_out = np.empty_like(proba_out)
+ elif proba_out is None:
+ proba_out = np.empty_like(gradient_out)
+
+ y_true = ReadonlyArrayWrapper(y_true)
+ raw_prediction = ReadonlyArrayWrapper(raw_prediction)
+ if sample_weight is not None:
+ sample_weight = ReadonlyArrayWrapper(sample_weight)
+ return self.closs.gradient_proba(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ gradient_out=gradient_out,
+ proba_out=proba_out,
+ n_threads=n_threads,
+ )
+
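+# Illustrative sketch (values chosen only for demonstration): with equal raw
+# predictions for all classes, the softmax gives uniform probabilities:
+#
+#     loss = HalfMultinomialLoss(n_classes=3)
+#     loss.predict_proba(raw_prediction=np.zeros((1, 3)))
+#     # -> array([[0.33333333, 0.33333333, 0.33333333]])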
+
+_LOSSES = {
+ "squared_error": HalfSquaredError,
+ "absolute_error": AbsoluteError,
+ "pinball_loss": PinballLoss,
+ "poisson_loss": HalfPoissonLoss,
+ "gamma_loss": HalfGammaLoss,
+ "tweedie_loss": HalfTweedieLoss,
+ "binomial_loss": HalfBinomialLoss,
+ "multinomial_loss": HalfMultinomialLoss,
+}
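+
+# Losses can be instantiated from their string name via this mapping, e.g. as
+# done in the tests (illustrative):
+#
+#     loss = _LOSSES["poisson_loss"](sample_weight=None)
+#     # isinstance(loss, HalfPoissonLoss) -> True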
diff --git a/sklearn/_loss/setup.py b/sklearn/_loss/setup.py
new file mode 100644
index 0000000000000..2a2d2b5f13b8a
--- /dev/null
+++ b/sklearn/_loss/setup.py
@@ -0,0 +1,25 @@
+import numpy
+from numpy.distutils.misc_util import Configuration
+from sklearn._build_utils import gen_from_templates
+
+
+def configuration(parent_package="", top_path=None):
+ config = Configuration("_loss", parent_package, top_path)
+
+ # generate _loss.pyx from template
+ templates = ["sklearn/_loss/_loss.pyx.tp"]
+ gen_from_templates(templates)
+
+ config.add_extension(
+ "_loss",
+ sources=["_loss.pyx"],
+ include_dirs=[numpy.get_include()],
+ # define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")],
+ )
+ return config
+
+
+if __name__ == "__main__":
+ from numpy.distutils.core import setup
+
+ setup(**configuration().todict())
diff --git a/sklearn/_loss/tests/test_link.py b/sklearn/_loss/tests/test_link.py
new file mode 100644
index 0000000000000..b363a45109989
--- /dev/null
+++ b/sklearn/_loss/tests/test_link.py
@@ -0,0 +1,108 @@
+import numpy as np
+from numpy.testing import assert_allclose, assert_array_equal
+import pytest
+
+from sklearn._loss.link import (
+ _LINKS,
+ _inclusive_low_high,
+ MultinomialLogit,
+ Interval,
+)
+
+
+LINK_FUNCTIONS = list(_LINKS.values())
+
+
+def test_interval_raises():
+ """Test that interval with low > high raises ValueError."""
+ with pytest.raises(
+ ValueError, match="On must have low <= high; got low=1, high=0."
+ ):
+ Interval(1, 0, False, False)
+
+
+@pytest.mark.parametrize(
+ "interval",
+ [
+ Interval(0, 1, False, False),
+ Interval(0, 1, False, True),
+ Interval(0, 1, True, False),
+ Interval(0, 1, True, True),
+ Interval(-np.inf, np.inf, False, False),
+ Interval(-np.inf, np.inf, False, True),
+ Interval(-np.inf, np.inf, True, False),
+ Interval(-np.inf, np.inf, True, True),
+ Interval(-10, -1, False, False),
+ Interval(-10, -1, False, True),
+ Interval(-10, -1, True, False),
+ Interval(-10, -1, True, True),
+ ],
+)
+def test_is_in_range(interval):
+ # make sure low and high are always within the interval, used for linspace
+ low, high = _inclusive_low_high(interval)
+
+ x = np.linspace(low, high, num=10)
+ assert interval.includes(x)
+
+ # x contains lower bound
+ assert interval.includes(np.r_[x, interval.low]) == interval.low_inclusive
+
+ # x contains upper bound
+ assert interval.includes(np.r_[x, interval.high]) == interval.high_inclusive
+
+ # x contains upper and lower bound
+ assert interval.includes(np.r_[x, interval.low, interval.high]) == (
+ interval.low_inclusive and interval.high_inclusive
+ )
+
+
+@pytest.mark.parametrize("link", LINK_FUNCTIONS)
+def test_link_inverse_identity(link):
+ # Test that link of inverse gives identity.
+ rng = np.random.RandomState(42)
+ link = link()
+ n_samples, n_classes = 100, None
+ if link.is_multiclass:
+ n_classes = 10
+ raw_prediction = rng.normal(loc=0, scale=10, size=(n_samples, n_classes))
+ if isinstance(link, MultinomialLogit):
+ raw_prediction = link.symmetrize_raw_prediction(raw_prediction)
+ else:
+ # So far, the valid interval of raw_prediction is (-inf, inf) and
+ # we do not need to distinguish.
+ raw_prediction = rng.normal(loc=0, scale=10, size=(n_samples))
+
+ assert_allclose(link.link(link.inverse(raw_prediction)), raw_prediction)
+ y_pred = link.inverse(raw_prediction)
+ assert_allclose(link.inverse(link.link(y_pred)), y_pred)
+
+
+@pytest.mark.parametrize("link", LINK_FUNCTIONS)
+def test_link_out_argument(link):
+ # Test that out argument gets assigned the result.
+ rng = np.random.RandomState(42)
+ link = link()
+ n_samples, n_classes = 100, None
+ if link.is_multiclass:
+ n_classes = 10
+ raw_prediction = rng.normal(loc=0, scale=10, size=(n_samples, n_classes))
+ if isinstance(link, MultinomialLogit):
+ raw_prediction = link.symmetrize_raw_prediction(raw_prediction)
+ else:
+ # So far, the valid interval of raw_prediction is (-inf, inf) and
+ # we do not need to distinguish.
+ raw_prediction = rng.normal(loc=0, scale=10, size=(n_samples))
+
+ y_pred = link.inverse(raw_prediction, out=None)
+ out = np.empty_like(raw_prediction)
+ y_pred_2 = link.inverse(raw_prediction, out=out)
+ assert_allclose(y_pred, out)
+ assert_array_equal(out, y_pred_2)
+ assert np.shares_memory(out, y_pred_2)
+
+ out = np.empty_like(y_pred)
+ raw_prediction_2 = link.link(y_pred, out=out)
+ assert_allclose(raw_prediction, out)
+ assert_array_equal(out, raw_prediction_2)
+ assert np.shares_memory(out, raw_prediction_2)
diff --git a/sklearn/_loss/tests/test_loss.py b/sklearn/_loss/tests/test_loss.py
new file mode 100644
index 0000000000000..2ad5633037c4a
--- /dev/null
+++ b/sklearn/_loss/tests/test_loss.py
@@ -0,0 +1,1009 @@
+import pickle
+
+import numpy as np
+from numpy.testing import assert_allclose, assert_array_equal
+import pytest
+from pytest import approx
+from scipy.optimize import (
+ minimize,
+ minimize_scalar,
+ newton,
+)
+from scipy.special import logsumexp
+
+from sklearn._loss.link import _inclusive_low_high, IdentityLink
+from sklearn._loss.loss import (
+ _LOSSES,
+ BaseLoss,
+ AbsoluteError,
+ HalfBinomialLoss,
+ HalfGammaLoss,
+ HalfMultinomialLoss,
+ HalfPoissonLoss,
+ HalfSquaredError,
+ HalfTweedieLoss,
+ PinballLoss,
+)
+from sklearn.utils import assert_all_finite
+from sklearn.utils._testing import create_memmap_backed_data, skip_if_32bit
+from sklearn.utils.fixes import sp_version, parse_version
+
+
+ALL_LOSSES = list(_LOSSES.values())
+
+LOSS_INSTANCES = [loss() for loss in ALL_LOSSES]
+# HalfTweedieLoss(power=1.5) is already there as default
+LOSS_INSTANCES += [
+ PinballLoss(quantile=0.25),
+ HalfTweedieLoss(power=-1.5),
+ HalfTweedieLoss(power=0),
+ HalfTweedieLoss(power=1),
+ HalfTweedieLoss(power=2),
+ HalfTweedieLoss(power=3.0),
+]
+
+
+def loss_instance_name(param):
+ if isinstance(param, BaseLoss):
+ loss = param
+ name = loss.__class__.__name__
+ if hasattr(loss, "quantile"):
+ name += f"(quantile={loss.closs.quantile})"
+ elif hasattr(loss, "power"):
+ name += f"(power={loss.closs.power})"
+ return name
+ else:
+ return str(param)
+
+
+def random_y_true_raw_prediction(
+ loss, n_samples, y_bound=(-100, 100), raw_bound=(-5, 5), seed=42
+):
+ """Random generate y_true and raw_prediction in valid range."""
+ rng = np.random.RandomState(seed)
+ if loss.is_multiclass:
+ raw_prediction = np.empty((n_samples, loss.n_classes))
+ raw_prediction.flat[:] = rng.uniform(
+ low=raw_bound[0],
+ high=raw_bound[1],
+ size=n_samples * loss.n_classes,
+ )
+ y_true = np.arange(n_samples).astype(float) % loss.n_classes
+ else:
+ raw_prediction = rng.uniform(
+ low=raw_bound[0], high=raw_bound[1], size=n_samples
+ )
+ # generate a y_true in valid range
+ low, high = _inclusive_low_high(loss.interval_y_true)
+ low = max(low, y_bound[0])
+ high = min(high, y_bound[1])
+ y_true = rng.uniform(low, high, size=n_samples)
+ # set some values at special boundaries
+ if loss.interval_y_true.low == 0 and loss.interval_y_true.low_inclusive:
+ y_true[:: (n_samples // 3)] = 0
+ if loss.interval_y_true.high == 1 and loss.interval_y_true.high_inclusive:
+ y_true[1 :: (n_samples // 3)] = 1
+
+ return y_true, raw_prediction
+
+
+def numerical_derivative(func, x, eps):
+ """Helper function for numerical (first) derivatives."""
+ # For numerical derivatives, see
+ # https://en.wikipedia.org/wiki/Numerical_differentiation
+ # https://en.wikipedia.org/wiki/Finite_difference_coefficient
+ # We use central finite differences of accuracy 4.
+ h = np.full_like(x, fill_value=eps)
+ f_minus_2h = func(x - 2 * h)
+ f_minus_1h = func(x - h)
+ f_plus_1h = func(x + h)
+ f_plus_2h = func(x + 2 * h)
+ return (-f_plus_2h + 8 * f_plus_1h - 8 * f_minus_1h + f_minus_2h) / (12.0 * eps)
+
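+# Quick illustrative check of the helper above (assumed values): the 4th-order
+# central difference is exact for low-degree polynomials, e.g. for f(x) = x**2
+# at x = 3 it recovers f'(3) = 2 * 3 = 6 up to floating point error:
+#
+#     numerical_derivative(lambda x: x ** 2, np.array([3.0]), eps=1e-6)
+#     # -> array([6.])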
+
+@pytest.mark.parametrize("loss", LOSS_INSTANCES, ids=loss_instance_name)
+def test_loss_boundary(loss):
+ """Test interval ranges of y_true and y_pred in losses."""
+ # make sure low and high are always within the interval, used for linspace
+ if loss.is_multiclass:
+ y_true = np.linspace(0, 9, num=10)
+ else:
+ low, high = _inclusive_low_high(loss.interval_y_true)
+ y_true = np.linspace(low, high, num=10)
+
+ # add boundaries if they are included
+ if loss.interval_y_true.low_inclusive:
+ y_true = np.r_[y_true, loss.interval_y_true.low]
+ if loss.interval_y_true.high_inclusive:
+ y_true = np.r_[y_true, loss.interval_y_true.high]
+
+ assert loss.in_y_true_range(y_true)
+
+ n = y_true.shape[0]
+ low, high = _inclusive_low_high(loss.interval_y_pred)
+ if loss.is_multiclass:
+ y_pred = np.empty((n, 3))
+ y_pred[:, 0] = np.linspace(low, high, num=n)
+ y_pred[:, 1] = 0.5 * (1 - y_pred[:, 0])
+ y_pred[:, 2] = 0.5 * (1 - y_pred[:, 0])
+ else:
+ y_pred = np.linspace(low, high, num=n)
+
+ assert loss.in_y_pred_range(y_pred)
+
+ # calculating losses should not fail
+ raw_prediction = loss.link.link(y_pred)
+ loss.loss(y_true=y_true, raw_prediction=raw_prediction)
+
+
+# Fixture to test valid value ranges.
+Y_COMMON_PARAMS = [
+ # (loss, [y success], [y fail])
+ (HalfSquaredError(), [-100, 0, 0.1, 100], [-np.inf, np.inf]),
+ (AbsoluteError(), [-100, 0, 0.1, 100], [-np.inf, np.inf]),
+ (PinballLoss(), [-100, 0, 0.1, 100], [-np.inf, np.inf]),
+ (HalfPoissonLoss(), [0.1, 100], [-np.inf, -3, -0.1, np.inf]),
+ (HalfGammaLoss(), [0.1, 100], [-np.inf, -3, -0.1, 0, np.inf]),
+ (HalfTweedieLoss(power=-3), [0.1, 100], [-np.inf, np.inf]),
+ (HalfTweedieLoss(power=0), [0.1, 100], [-np.inf, np.inf]),
+ (HalfTweedieLoss(power=1.5), [0.1, 100], [-np.inf, -3, -0.1, np.inf]),
+ (HalfTweedieLoss(power=2), [0.1, 100], [-np.inf, -3, -0.1, 0, np.inf]),
+ (HalfTweedieLoss(power=3), [0.1, 100], [-np.inf, -3, -0.1, 0, np.inf]),
+ (HalfBinomialLoss(), [0.1, 0.5, 0.9], [-np.inf, -1, 2, np.inf]),
+ (HalfMultinomialLoss(), [], [-np.inf, -1, 1.1, np.inf]),
+]
+# y_pred and y_true do not always have the same domain (valid value range).
+# Hence, we define extra sets of parameters for each of them.
+Y_TRUE_PARAMS = [ # type: ignore
+ # (loss, [y success], [y fail])
+ (HalfPoissonLoss(), [0], []),
+ (HalfTweedieLoss(power=-3), [-100, -0.1, 0], []),
+ (HalfTweedieLoss(power=0), [-100, 0], []),
+ (HalfTweedieLoss(power=1.5), [0], []),
+ (HalfBinomialLoss(), [0, 1], []),
+ (HalfMultinomialLoss(), [0.0, 1.0, 2], []),
+]
+Y_PRED_PARAMS = [
+ # (loss, [y success], [y fail])
+ (HalfPoissonLoss(), [], [0]),
+ (HalfTweedieLoss(power=-3), [], [-3, -0.1, 0]),
+ (HalfTweedieLoss(power=0), [], [-3, -0.1, 0]),
+ (HalfTweedieLoss(power=1.5), [], [0]),
+ (HalfBinomialLoss(), [], [0, 1]),
+ (HalfMultinomialLoss(), [0.1, 0.5], [0, 1]),
+]
+
+
+@pytest.mark.parametrize(
+ "loss, y_true_success, y_true_fail", Y_COMMON_PARAMS + Y_TRUE_PARAMS
+)
+def test_loss_boundary_y_true(loss, y_true_success, y_true_fail):
+ """Test boundaries of y_true for loss functions."""
+ for y in y_true_success:
+ assert loss.in_y_true_range(np.array([y]))
+ for y in y_true_fail:
+ assert not loss.in_y_true_range(np.array([y]))
+
+
+@pytest.mark.parametrize(
+ "loss, y_pred_success, y_pred_fail", Y_COMMON_PARAMS + Y_PRED_PARAMS # type: ignore
+)
+def test_loss_boundary_y_pred(loss, y_pred_success, y_pred_fail):
+ """Test boundaries of y_pred for loss functions."""
+ for y in y_pred_success:
+ assert loss.in_y_pred_range(np.array([y]))
+ for y in y_pred_fail:
+ assert not loss.in_y_pred_range(np.array([y]))
+
+
+@pytest.mark.parametrize(
+ "loss, y_true, raw_prediction, loss_true",
+ [
+ (HalfSquaredError(), 1.0, 5.0, 8),
+ (AbsoluteError(), 1.0, 5.0, 4),
+ (PinballLoss(quantile=0.5), 1.0, 5.0, 2),
+ (PinballLoss(quantile=0.25), 1.0, 5.0, 4 * (1 - 0.25)),
+ (PinballLoss(quantile=0.25), 5.0, 1.0, 4 * 0.25),
+ (HalfPoissonLoss(), 2.0, np.log(4), 4 - 2 * np.log(4)),
+ (HalfGammaLoss(), 2.0, np.log(4), np.log(4) + 2 / 4),
+ (HalfTweedieLoss(power=3), 2.0, np.log(4), -1 / 4 + 1 / 4 ** 2),
+ (HalfBinomialLoss(), 0.25, np.log(4), np.log(5) - 0.25 * np.log(4)),
+ (
+ HalfMultinomialLoss(n_classes=3),
+ 0.0,
+ [0.2, 0.5, 0.3],
+ logsumexp([0.2, 0.5, 0.3]) - 0.2,
+ ),
+ (
+ HalfMultinomialLoss(n_classes=3),
+ 1.0,
+ [0.2, 0.5, 0.3],
+ logsumexp([0.2, 0.5, 0.3]) - 0.5,
+ ),
+ (
+ HalfMultinomialLoss(n_classes=3),
+ 2.0,
+ [0.2, 0.5, 0.3],
+ logsumexp([0.2, 0.5, 0.3]) - 0.3,
+ ),
+ ],
+ ids=loss_instance_name,
+)
+def test_loss_on_specific_values(loss, y_true, raw_prediction, loss_true):
+ """Test losses at specific values."""
+ assert loss(
+ y_true=np.array([y_true]), raw_prediction=np.array([raw_prediction])
+ ) == approx(loss_true, rel=1e-11, abs=1e-12)
+
+
+@pytest.mark.parametrize("loss", ALL_LOSSES)
+@pytest.mark.parametrize("readonly_memmap", [False, True])
+@pytest.mark.parametrize("dtype_in", [np.float32, np.float64])
+@pytest.mark.parametrize("dtype_out", [np.float32, np.float64])
+@pytest.mark.parametrize("sample_weight", [None, 1])
+@pytest.mark.parametrize("out1", [None, 1])
+@pytest.mark.parametrize("out2", [None, 1])
+@pytest.mark.parametrize("n_threads", [1, 2])
+def test_loss_dtype(
+ loss, readonly_memmap, dtype_in, dtype_out, sample_weight, out1, out2, n_threads
+):
+ """Test acceptance of dtypes, readonly and writeable arrays in loss functions.
+
+ Check that the loss accepts input arrays that are either all float32 or all
+ float64, and output arrays that are either all float32 or all float64.
+
+ Also check that input arrays can be readonly, e.g. memory mapped.
+ """
+ loss = loss()
+ # generate a y_true and raw_prediction in valid range
+ n_samples = 5
+ y_true, raw_prediction = random_y_true_raw_prediction(
+ loss=loss,
+ n_samples=n_samples,
+ y_bound=(-100, 100),
+ raw_bound=(-10, 10),
+ seed=42,
+ )
+ y_true = y_true.astype(dtype_in)
+ raw_prediction = raw_prediction.astype(dtype_in)
+
+ if sample_weight is not None:
+ sample_weight = np.array([2.0] * n_samples, dtype=dtype_in)
+ if out1 is not None:
+ out1 = np.empty_like(y_true, dtype=dtype_out)
+ if out2 is not None:
+ out2 = np.empty_like(raw_prediction, dtype=dtype_out)
+
+ if readonly_memmap:
+ y_true = create_memmap_backed_data(y_true, aligned=True)
+ raw_prediction = create_memmap_backed_data(raw_prediction, aligned=True)
+ if sample_weight is not None:
+ sample_weight = create_memmap_backed_data(sample_weight, aligned=True)
+
+ loss.loss(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ loss_out=out1,
+ n_threads=n_threads,
+ )
+ loss.gradient(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ gradient_out=out2,
+ n_threads=n_threads,
+ )
+ loss.loss_gradient(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ loss_out=out1,
+ gradient_out=out2,
+ n_threads=n_threads,
+ )
+ if out1 is not None and loss.is_multiclass:
+ out1 = np.empty_like(raw_prediction, dtype=dtype_out)
+ loss.gradient_hessian(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ gradient_out=out1,
+ hessian_out=out2,
+ n_threads=n_threads,
+ )
+ loss(y_true=y_true, raw_prediction=raw_prediction, sample_weight=sample_weight)
+ loss.fit_intercept_only(y_true=y_true, sample_weight=sample_weight)
+ loss.constant_to_optimal_zero(y_true=y_true, sample_weight=sample_weight)
+ if hasattr(loss, "predict_proba"):
+ loss.predict_proba(raw_prediction=raw_prediction)
+ if hasattr(loss, "gradient_proba"):
+ loss.gradient_proba(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ gradient_out=out1,
+ proba_out=out2,
+ n_threads=n_threads,
+ )
+
+
+@pytest.mark.parametrize("loss", LOSS_INSTANCES, ids=loss_instance_name)
+@pytest.mark.parametrize("sample_weight", [None, "range"])
+def test_loss_same_as_C_functions(loss, sample_weight):
+ """Test that Python and Cython functions return same results."""
+ y_true, raw_prediction = random_y_true_raw_prediction(
+ loss=loss,
+ n_samples=20,
+ y_bound=(-100, 100),
+ raw_bound=(-10, 10),
+ seed=42,
+ )
+ if sample_weight == "range":
+ sample_weight = np.linspace(1, y_true.shape[0], num=y_true.shape[0])
+
+ out_l1 = np.empty_like(y_true)
+ out_l2 = np.empty_like(y_true)
+ out_g1 = np.empty_like(raw_prediction)
+ out_g2 = np.empty_like(raw_prediction)
+ out_h1 = np.empty_like(raw_prediction)
+ out_h2 = np.empty_like(raw_prediction)
+ assert_allclose(
+ loss.loss(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ loss_out=out_l1,
+ ),
+ loss.closs.loss(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ loss_out=out_l2,
+ ),
+ )
+ assert_allclose(
+ loss.gradient(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ gradient_out=out_g1,
+ ),
+ loss.closs.gradient(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ gradient_out=out_g2,
+ ),
+ )
+ loss.loss_gradient(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ loss_out=out_l1,
+ gradient_out=out_g1,
+ )
+ loss.closs.loss_gradient(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ loss_out=out_l2,
+ gradient_out=out_g2,
+ )
+ assert_allclose(out_l1, out_l2)
+ assert_allclose(out_g1, out_g2)
+ loss.gradient_hessian(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ gradient_out=out_g1,
+ hessian_out=out_h1,
+ )
+ loss.closs.gradient_hessian(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ gradient_out=out_g2,
+ hessian_out=out_h2,
+ )
+ assert_allclose(out_g1, out_g2)
+ assert_allclose(out_h1, out_h2)
+
+
+@pytest.mark.parametrize("loss", LOSS_INSTANCES, ids=loss_instance_name)
+@pytest.mark.parametrize("sample_weight", [None, "range"])
+def test_loss_gradients_are_the_same(loss, sample_weight):
+ """Test that loss and gradient are the same across different functions.
+
+ Also test that output arguments contain correct results.
+ """
+ y_true, raw_prediction = random_y_true_raw_prediction(
+ loss=loss,
+ n_samples=20,
+ y_bound=(-100, 100),
+ raw_bound=(-10, 10),
+ seed=42,
+ )
+ if sample_weight == "range":
+ sample_weight = np.linspace(1, y_true.shape[0], num=y_true.shape[0])
+
+ out_l1 = np.empty_like(y_true)
+ out_l2 = np.empty_like(y_true)
+ out_g1 = np.empty_like(raw_prediction)
+ out_g2 = np.empty_like(raw_prediction)
+ out_g3 = np.empty_like(raw_prediction)
+ out_h3 = np.empty_like(raw_prediction)
+
+ l1 = loss.loss(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ loss_out=out_l1,
+ )
+ g1 = loss.gradient(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ gradient_out=out_g1,
+ )
+ l2, g2 = loss.loss_gradient(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ loss_out=out_l2,
+ gradient_out=out_g2,
+ )
+ g3, h3 = loss.gradient_hessian(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ gradient_out=out_g3,
+ hessian_out=out_h3,
+ )
+ assert_allclose(l1, l2)
+ assert_array_equal(l1, out_l1)
+ assert np.shares_memory(l1, out_l1)
+ assert_array_equal(l2, out_l2)
+ assert np.shares_memory(l2, out_l2)
+ assert_allclose(g1, g2)
+ assert_allclose(g1, g3)
+ assert_array_equal(g1, out_g1)
+ assert np.shares_memory(g1, out_g1)
+ assert_array_equal(g2, out_g2)
+ assert np.shares_memory(g2, out_g2)
+ assert_array_equal(g3, out_g3)
+ assert np.shares_memory(g3, out_g3)
+
+ if hasattr(loss, "gradient_proba"):
+ assert loss.is_multiclass # only for HalfMultinomialLoss
+ out_g4 = np.empty_like(raw_prediction)
+ out_proba = np.empty_like(raw_prediction)
+ g4, proba = loss.gradient_proba(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ gradient_out=out_g4,
+ proba_out=out_proba,
+ )
+ assert_allclose(g1, out_g4)
+ assert_allclose(g1, g4)
+ assert_allclose(proba, out_proba)
+ assert_allclose(np.sum(proba, axis=1), 1, rtol=1e-11)
+
+
+@pytest.mark.parametrize("loss", LOSS_INSTANCES, ids=loss_instance_name)
+@pytest.mark.parametrize("sample_weight", ["ones", "random"])
+def test_sample_weight_multiplies(loss, sample_weight):
+ """Test sample weights in loss, gradients and hessians.
+
+ Make sure that passing sample weights to loss, gradient and hessian
+ computation methods is equivalent to multiplying by the weights.
+ """
+ n_samples = 100
+ y_true, raw_prediction = random_y_true_raw_prediction(
+ loss=loss,
+ n_samples=n_samples,
+ y_bound=(-100, 100),
+ raw_bound=(-5, 5),
+ seed=42,
+ )
+
+ if sample_weight == "ones":
+ sample_weight = np.ones(shape=n_samples, dtype=np.float64)
+ else:
+ rng = np.random.RandomState(42)
+ sample_weight = rng.normal(size=n_samples).astype(np.float64)
+
+ assert_allclose(
+ loss.loss(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ ),
+ sample_weight
+ * loss.loss(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=None,
+ ),
+ )
+
+ losses, gradient = loss.loss_gradient(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=None,
+ )
+ losses_sw, gradient_sw = loss.loss_gradient(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ )
+ assert_allclose(losses * sample_weight, losses_sw)
+ if not loss.is_multiclass:
+ assert_allclose(gradient * sample_weight, gradient_sw)
+ else:
+ assert_allclose(gradient * sample_weight[:, None], gradient_sw)
+
+ gradient, hessian = loss.gradient_hessian(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=None,
+ )
+ gradient_sw, hessian_sw = loss.gradient_hessian(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ )
+ if not loss.is_multiclass:
+ assert_allclose(gradient * sample_weight, gradient_sw)
+ assert_allclose(hessian * sample_weight, hessian_sw)
+ else:
+ assert_allclose(gradient * sample_weight[:, None], gradient_sw)
+ assert_allclose(hessian * sample_weight[:, None], hessian_sw)
+
+
+@pytest.mark.parametrize("loss", LOSS_INSTANCES, ids=loss_instance_name)
+def test_graceful_squeezing(loss):
+ """Test that reshaped raw_prediction gives same results."""
+ y_true, raw_prediction = random_y_true_raw_prediction(
+ loss=loss,
+ n_samples=20,
+ y_bound=(-100, 100),
+ raw_bound=(-10, 10),
+ seed=42,
+ )
+
+ if raw_prediction.ndim == 1:
+ raw_prediction_2d = raw_prediction[:, None]
+ assert_allclose(
+ loss.loss(y_true=y_true, raw_prediction=raw_prediction_2d),
+ loss.loss(y_true=y_true, raw_prediction=raw_prediction),
+ )
+ assert_allclose(
+ loss.loss_gradient(y_true=y_true, raw_prediction=raw_prediction_2d),
+ loss.loss_gradient(y_true=y_true, raw_prediction=raw_prediction),
+ )
+ assert_allclose(
+ loss.gradient(y_true=y_true, raw_prediction=raw_prediction_2d),
+ loss.gradient(y_true=y_true, raw_prediction=raw_prediction),
+ )
+ assert_allclose(
+ loss.gradient_hessian(y_true=y_true, raw_prediction=raw_prediction_2d),
+ loss.gradient_hessian(y_true=y_true, raw_prediction=raw_prediction),
+ )
+
+
+@pytest.mark.parametrize("loss", LOSS_INSTANCES, ids=loss_instance_name)
+@pytest.mark.parametrize("sample_weight", [None, "range"])
+def test_loss_of_perfect_prediction(loss, sample_weight):
+ """Test value of perfect predictions.
+
+ The loss of y_pred = y_true plus constant_to_optimal_zero should sum up to
+ zero.
+ """
+ if not loss.is_multiclass:
+ # Use small values such that exp(value) is not nan.
+ raw_prediction = np.array([-10, -0.1, 0, 0.1, 3, 10])
+ y_true = loss.link.inverse(raw_prediction)
+ else:
+ # HalfMultinomialLoss
+ y_true = np.arange(loss.n_classes).astype(float)
+ # raw_prediction with entries -exp(10), but +exp(10) on the diagonal
+ # this is close enough to np.inf which would produce nan
+ raw_prediction = np.full(
+ shape=(loss.n_classes, loss.n_classes),
+ fill_value=-np.exp(10),
+ dtype=float,
+ )
+ raw_prediction.flat[:: loss.n_classes + 1] = np.exp(10)
+
+ if sample_weight == "range":
+ sample_weight = np.linspace(1, y_true.shape[0], num=y_true.shape[0])
+
+ loss_value = loss.loss(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ )
+ constant_term = loss.constant_to_optimal_zero(
+ y_true=y_true, sample_weight=sample_weight
+ )
+ # Comparing loss_value + constant_term to zero would result in large
+ # round-off errors.
+ assert_allclose(loss_value, -constant_term, atol=1e-14, rtol=1e-15)
+
+
+@pytest.mark.parametrize("loss", LOSS_INSTANCES, ids=loss_instance_name)
+@pytest.mark.parametrize("sample_weight", [None, "range"])
+def test_gradients_hessians_numerically(loss, sample_weight):
+ """Test gradients and hessians with numerical derivatives.
+
+ Gradient should equal the numerical derivatives of the loss function.
+ Hessians should equal the numerical derivatives of gradients.
+ """
+ n_samples = 20
+ y_true, raw_prediction = random_y_true_raw_prediction(
+ loss=loss,
+ n_samples=n_samples,
+ y_bound=(-100, 100),
+ raw_bound=(-5, 5),
+ seed=42,
+ )
+
+ if sample_weight == "range":
+ sample_weight = np.linspace(1, y_true.shape[0], num=y_true.shape[0])
+
+ g, h = loss.gradient_hessian(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ )
+
+ assert g.shape == raw_prediction.shape
+ assert h.shape == raw_prediction.shape
+
+ if not loss.is_multiclass:
+
+ def loss_func(x):
+ return loss.loss(
+ y_true=y_true,
+ raw_prediction=x,
+ sample_weight=sample_weight,
+ )
+
+ g_numeric = numerical_derivative(loss_func, raw_prediction, eps=1e-6)
+ assert_allclose(g, g_numeric, rtol=5e-6, atol=1e-10)
+
+ def grad_func(x):
+ return loss.gradient(
+ y_true=y_true,
+ raw_prediction=x,
+ sample_weight=sample_weight,
+ )
+
+ h_numeric = numerical_derivative(grad_func, raw_prediction, eps=1e-6)
+ if loss.approx_hessian:
+ # TODO: What could we test if loss.approx_hessian?
+ pass
+ else:
+ assert_allclose(h, h_numeric, rtol=5e-6, atol=1e-10)
+ else:
+ # For multiclass loss, we should only change the predictions of the
+ # class for which the derivative is taken, e.g. offset[:, k] = eps
+ # for class k.
+ # As a softmax is computed, offsetting the whole array by a constant
+ # would have no effect on the probabilities, and thus on the loss.
+ for k in range(loss.n_classes):
+
+ def loss_func(x):
+ raw = raw_prediction.copy()
+ raw[:, k] = x
+ return loss.loss(
+ y_true=y_true,
+ raw_prediction=raw,
+ sample_weight=sample_weight,
+ )
+
+ g_numeric = numerical_derivative(loss_func, raw_prediction[:, k], eps=1e-5)
+ assert_allclose(g[:, k], g_numeric, rtol=5e-6, atol=1e-10)
+
+ def grad_func(x):
+ raw = raw_prediction.copy()
+ raw[:, k] = x
+ return loss.gradient(
+ y_true=y_true,
+ raw_prediction=raw,
+ sample_weight=sample_weight,
+ )[:, k]
+
+ h_numeric = numerical_derivative(grad_func, raw_prediction[:, k], eps=1e-6)
+ if loss.approx_hessian:
+ # TODO: What could we test if loss.approx_hessian?
+ pass
+ else:
+ assert_allclose(h[:, k], h_numeric, rtol=5e-6, atol=1e-10)
+
+
+@pytest.mark.parametrize(
+ "loss, x0, y_true",
+ [
+ ("squared_error", -2.0, 42),
+ ("squared_error", 117.0, 1.05),
+ ("squared_error", 0.0, 0.0),
+ # The argmin of binomial_loss for y_true=0 and y_true=1 is resp.
+ # -inf and +inf due to logit, cf. "complete separation". Therefore, we
+ # use 0 < y_true < 1.
+ ("binomial_loss", 0.3, 0.1),
+ ("binomial_loss", -12, 0.2),
+ ("binomial_loss", 30, 0.9),
+ ("poisson_loss", 12.0, 1.0),
+ ("poisson_loss", 0.0, 2.0),
+ ("poisson_loss", -22.0, 10.0),
+ ],
+)
+@pytest.mark.skipif(
+ sp_version == parse_version("1.2.0"),
+ reason="bug in scipy 1.2.0, see scipy issue #9608",
+)
+@skip_if_32bit
+def test_derivatives(loss, x0, y_true):
+ """Test that gradients are zero at the minimum of the loss.
+
+ We check this on a single value/sample using Halley's method with the
+ first and second order derivatives computed by the Loss instance.
+ Note that methods of Loss instances operate on arrays while the newton
+ root finder expects a scalar or a one-element array for this purpose.
+ """
+ loss = _LOSSES[loss](sample_weight=None)
+ y_true = np.array([y_true], dtype=np.float64)
+ x0 = np.array([x0], dtype=np.float64)
+
+ def func(x: np.ndarray) -> np.ndarray:
+ """Compute loss plus constant term.
+
+ The constant term is such that the minimum function value is zero,
+ which is required by the Newton method.
+ """
+ return loss.loss(
+ y_true=y_true, raw_prediction=x
+ ) + loss.constant_to_optimal_zero(y_true=y_true)
+
+ def fprime(x: np.ndarray) -> np.ndarray:
+ return loss.gradient(y_true=y_true, raw_prediction=x)
+
+ def fprime2(x: np.ndarray) -> np.ndarray:
+ return loss.gradient_hessian(y_true=y_true, raw_prediction=x)[1]
+
+ optimum = newton(
+ func,
+ x0=x0,
+ fprime=fprime,
+ fprime2=fprime2,
+ maxiter=100,
+ tol=5e-8,
+ )
+
+ # Need to ravel arrays because assert_allclose requires matching
+ # dimensions.
+ y_true = y_true.ravel()
+ optimum = optimum.ravel()
+ assert_allclose(loss.link.inverse(optimum), y_true)
+ assert_allclose(func(optimum), 0, atol=1e-14)
+ assert_allclose(loss.gradient(y_true=y_true, raw_prediction=optimum), 0, atol=5e-7)
+
+
+@pytest.mark.parametrize("loss", LOSS_INSTANCES, ids=loss_instance_name)
+@pytest.mark.parametrize("sample_weight", [None, "range"])
+def test_loss_intercept_only(loss, sample_weight):
+ """Test that fit_intercept_only returns the argmin of the loss.
+
+ Also test that the gradient is zero at the minimum.
+ """
+ n_samples = 50
+ if not loss.is_multiclass:
+ y_true = loss.link.inverse(np.linspace(-4, 4, num=n_samples))
+ else:
+ y_true = np.arange(n_samples).astype(float) % loss.n_classes
+ y_true[::5] = 0 # exceedance of class 0
+
+ if sample_weight == "range":
+ sample_weight = np.linspace(0.1, 2, num=n_samples)
+
+ a = loss.fit_intercept_only(y_true=y_true, sample_weight=sample_weight)
+
+ # find minimum by optimization
+ def fun(x):
+ if not loss.is_multiclass:
+ raw_prediction = np.full(shape=(n_samples), fill_value=x)
+ else:
+ raw_prediction = np.ascontiguousarray(
+ np.broadcast_to(x, shape=(n_samples, loss.n_classes))
+ )
+ return loss(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=sample_weight,
+ )
+
+ if not loss.is_multiclass:
+ opt = minimize_scalar(fun, tol=1e-7, options={"maxiter": 100})
+ grad = loss.gradient(
+ y_true=y_true,
+ raw_prediction=np.full_like(y_true, a),
+ sample_weight=sample_weight,
+ )
+ assert a.shape == tuple() # scalar
+ assert a.dtype == y_true.dtype
+ assert_all_finite(a)
+ assert a == approx(opt.x, rel=1e-7)
+ assert grad.sum() == approx(0, abs=1e-12)
+ else:
+ # The constraint corresponds to sum(raw_prediction) = 0. Without it, we would
+ # need to apply loss.symmetrize_raw_prediction to opt.x before comparing.
+ # TODO: With scipy >= 1.1.0, one could use
+ # LinearConstraint(np.ones((1, loss.n_classes)), 0, 0)
+ opt = minimize(
+ fun,
+ np.zeros((loss.n_classes)),
+ tol=1e-13,
+ options={"maxiter": 100},
+ method="SLSQP",
+ constraints={
+ "type": "eq",
+ "fun": lambda x: np.ones((1, loss.n_classes)) @ x,
+ },
+ )
+ grad = loss.gradient(
+ y_true=y_true,
+ raw_prediction=np.tile(a, (n_samples, 1)),
+ sample_weight=sample_weight,
+ )
+ assert a.dtype == y_true.dtype
+ assert_all_finite(a)
+ assert_allclose(a, opt.x, rtol=5e-6, atol=1e-12)
+ assert_allclose(grad.sum(axis=0), 0, atol=1e-12)
+
+
+@pytest.mark.parametrize(
+ "loss, func, random_dist",
+ [
+ (HalfSquaredError(), np.mean, "normal"),
+ (AbsoluteError(), np.median, "normal"),
+ (PinballLoss(quantile=0.25), lambda x: np.percentile(x, q=25), "normal"),
+ (HalfPoissonLoss(), np.mean, "poisson"),
+ (HalfGammaLoss(), np.mean, "exponential"),
+ (HalfTweedieLoss(), np.mean, "exponential"),
+ (HalfBinomialLoss(), np.mean, "binomial"),
+ ],
+)
+def test_specific_fit_intercept_only(loss, func, random_dist):
+ """Test that fit_intercept_only returns the correct functional.
+
+ We test the functional for specific, meaningful distributions, e.g.
+ squared error estimates the expectation of a probability distribution.
+ """
+ rng = np.random.RandomState(0)
+ if random_dist == "binomial":
+ y_train = rng.binomial(1, 0.5, size=100)
+ else:
+ y_train = getattr(rng, random_dist)(size=100)
+ baseline_prediction = loss.fit_intercept_only(y_true=y_train)
+ # Make sure baseline prediction is the expected functional=func, e.g. mean
+ # or median.
+ assert_all_finite(baseline_prediction)
+ assert baseline_prediction == approx(loss.link.link(func(y_train)))
+ assert loss.link.inverse(baseline_prediction) == approx(func(y_train))
+ if isinstance(loss.link, IdentityLink):
+ assert_allclose(loss.link.inverse(baseline_prediction), baseline_prediction)
+
+ # Test baseline at boundary
+ if loss.interval_y_true.low_inclusive:
+ y_train.fill(loss.interval_y_true.low)
+ baseline_prediction = loss.fit_intercept_only(y_true=y_train)
+ assert_all_finite(baseline_prediction)
+ if loss.interval_y_true.high_inclusive:
+ y_train.fill(loss.interval_y_true.high)
+ baseline_prediction = loss.fit_intercept_only(y_true=y_train)
+ assert_all_finite(baseline_prediction)
+
+
+def test_multinomial_loss_fit_intercept_only():
+ """Test that fit_intercept_only returns the mean functional for CCE."""
+ rng = np.random.RandomState(0)
+ n_classes = 4
+ loss = HalfMultinomialLoss(n_classes=n_classes)
+ # Same logic as test_specific_fit_intercept_only. Here inverse link
+ # function = softmax and link function = log - symmetry term.
+ y_train = rng.randint(0, n_classes + 1, size=100).astype(np.float64)
+ baseline_prediction = loss.fit_intercept_only(y_true=y_train)
+ assert baseline_prediction.shape == (n_classes,)
+ p = np.zeros(n_classes, dtype=y_train.dtype)
+ for k in range(n_classes):
+ p[k] = (y_train == k).mean()
+ assert_allclose(baseline_prediction, np.log(p) - np.mean(np.log(p)))
+ assert_allclose(baseline_prediction[None, :], loss.link.link(p[None, :]))
+
+ for y_train in (np.zeros(shape=10), np.ones(shape=10)):
+ y_train = y_train.astype(np.float64)
+ baseline_prediction = loss.fit_intercept_only(y_true=y_train)
+ assert baseline_prediction.dtype == y_train.dtype
+ assert_all_finite(baseline_prediction)
+
+
+def test_binomial_and_multinomial_loss():
+ """Test that multinomial loss with n_classes = 2 is the same as binomial loss."""
+ rng = np.random.RandomState(0)
+ n_samples = 20
+ binom = HalfBinomialLoss()
+ multinom = HalfMultinomialLoss(n_classes=2)
+ y_train = rng.randint(0, 2, size=n_samples).astype(np.float64)
+ raw_prediction = rng.normal(size=n_samples)
+ raw_multinom = np.empty((n_samples, 2))
+ raw_multinom[:, 0] = -0.5 * raw_prediction
+ raw_multinom[:, 1] = 0.5 * raw_prediction
+ assert_allclose(
+ binom.loss(y_true=y_train, raw_prediction=raw_prediction),
+ multinom.loss(y_true=y_train, raw_prediction=raw_multinom),
+ )
+
+
+@pytest.mark.parametrize("loss", LOSS_INSTANCES, ids=loss_instance_name)
+def test_predict_proba(loss):
+ """Test that predict_proba and gradient_proba work as expected."""
+ n_samples = 20
+ y_true, raw_prediction = random_y_true_raw_prediction(
+ loss=loss,
+ n_samples=n_samples,
+ y_bound=(-100, 100),
+ raw_bound=(-5, 5),
+ seed=42,
+ )
+
+ if hasattr(loss, "predict_proba"):
+ proba = loss.predict_proba(raw_prediction)
+ assert proba.shape == (n_samples, loss.n_classes)
+ assert np.sum(proba, axis=1) == approx(1, rel=1e-11)
+
+ if hasattr(loss, "gradient_proba"):
+ for grad, proba in (
+ (None, None),
+ (None, np.empty_like(raw_prediction)),
+ (np.empty_like(raw_prediction), None),
+ (np.empty_like(raw_prediction), np.empty_like(raw_prediction)),
+ ):
+ grad, proba = loss.gradient_proba(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=None,
+ gradient_out=grad,
+ proba_out=proba,
+ )
+ assert proba.shape == (n_samples, loss.n_classes)
+ assert np.sum(proba, axis=1) == approx(1, rel=1e-11)
+ assert_allclose(
+ grad,
+ loss.gradient(
+ y_true=y_true,
+ raw_prediction=raw_prediction,
+ sample_weight=None,
+ gradient_out=None,
+ ),
+ )
+
+
+@pytest.mark.parametrize("loss", LOSS_INSTANCES, ids=loss_instance_name)
+def test_loss_pickle(loss):
+ """Test that losses can be pickled."""
+ n_samples = 20
+ y_true, raw_prediction = random_y_true_raw_prediction(
+ loss=loss,
+ n_samples=n_samples,
+ y_bound=(-100, 100),
+ raw_bound=(-5, 5),
+ seed=42,
+ )
+ pickled_loss = pickle.dumps(loss)
+ unpickled_loss = pickle.loads(pickled_loss)
+ assert loss(y_true=y_true, raw_prediction=raw_prediction) == approx(
+ unpickled_loss(y_true=y_true, raw_prediction=raw_prediction)
+ )
diff --git a/sklearn/setup.py b/sklearn/setup.py
index f9d549c094ec2..874bdbbcbed43 100644
--- a/sklearn/setup.py
+++ b/sklearn/setup.py
@@ -48,12 +48,12 @@ def configuration(parent_package="", top_path=None):
config.add_subpackage("experimental/tests")
config.add_subpackage("ensemble/_hist_gradient_boosting")
config.add_subpackage("ensemble/_hist_gradient_boosting/tests")
- config.add_subpackage("_loss/")
- config.add_subpackage("_loss/tests")
config.add_subpackage("externals")
config.add_subpackage("externals/_packaging")
# submodules which have their own setup.py
+ config.add_subpackage("_loss")
+ config.add_subpackage("_loss/tests")
config.add_subpackage("cluster")
config.add_subpackage("datasets")
config.add_subpackage("decomposition")