From f6c3bdaa8a201fe4b90f7237d10d8ed7b90c6d26 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Thu, 21 Dec 2023 19:07:22 +0100 Subject: [PATCH 01/12] MNT replace Cython loss functions in SGD part 1 --- sklearn/linear_model/_sgd_fast.pxd | 16 +++---- sklearn/linear_model/_sgd_fast.pyx.tp | 66 +++++++++++++-------------- 2 files changed, 41 insertions(+), 41 deletions(-) diff --git a/sklearn/linear_model/_sgd_fast.pxd b/sklearn/linear_model/_sgd_fast.pxd index 7ae704eee18db..243ce2e13a35e 100644 --- a/sklearn/linear_model/_sgd_fast.pxd +++ b/sklearn/linear_model/_sgd_fast.pxd @@ -7,20 +7,20 @@ cdef class LossFunction: cdef class Regression(LossFunction): - cdef double loss(self, double p, double y) noexcept nogil - cdef double dloss(self, double p, double y) noexcept nogil + cdef double loss(self, double y, double p) noexcept nogil + cdef double dloss(self, double y, double p) noexcept nogil cdef class Classification(LossFunction): - cdef double loss(self, double p, double y) noexcept nogil - cdef double dloss(self, double p, double y) noexcept nogil + cdef double loss(self, double y, double p) noexcept nogil + cdef double dloss(self, double y, double p) noexcept nogil cdef class Log(Classification): - cdef double loss(self, double p, double y) noexcept nogil - cdef double dloss(self, double p, double y) noexcept nogil + cdef double loss(self, double y, double p) noexcept nogil + cdef double dloss(self, double y, double p) noexcept nogil cdef class SquaredLoss(Regression): - cdef double loss(self, double p, double y) noexcept nogil - cdef double dloss(self, double p, double y) noexcept nogil + cdef double loss(self, double y, double p) noexcept nogil + cdef double dloss(self, double y, double p) noexcept nogil diff --git a/sklearn/linear_model/_sgd_fast.pyx.tp b/sklearn/linear_model/_sgd_fast.pyx.tp index bcd2bd7e5576e..b92d983a1b4b8 100644 --- a/sklearn/linear_model/_sgd_fast.pyx.tp +++ b/sklearn/linear_model/_sgd_fast.pyx.tp @@ -77,15 +77,15 @@ cdef extern from *: cdef class LossFunction: """Base class for convex loss functions""" - cdef double loss(self, double p, double y) noexcept nogil: + cdef double loss(self, double y, double p) noexcept nogil: """Evaluate the loss function. Parameters ---------- - p : double - The prediction, `p = w^T x + intercept`. y : double The true value (aka target). + p : double + The prediction, `p = w^T x + intercept`. Returns ------- @@ -111,7 +111,7 @@ cdef class LossFunction: double The derivative of the loss function with regards to `p`. """ - return self.dloss(p, y) + return self.dloss(y, p) def py_loss(self, double p, double y): """Python version of `loss` for testing. @@ -130,18 +130,18 @@ cdef class LossFunction: double The loss evaluated at `p` and `y`. """ - return self.loss(p, y) + return self.loss(y, p) - cdef double dloss(self, double p, double y) noexcept nogil: + cdef double dloss(self, double y, double p) noexcept nogil: """Evaluate the derivative of the loss function with respect to the prediction `p`. Parameters ---------- - p : double - The prediction, `p = w^T x`. y : double The true value (aka target). + p : double + The prediction, `p = w^T x`. Returns ------- @@ -154,20 +154,20 @@ cdef class LossFunction: cdef class Regression(LossFunction): """Base class for loss functions for regression""" - cdef double loss(self, double p, double y) noexcept nogil: + cdef double loss(self, double y, double p) noexcept nogil: return 0. 
- cdef double dloss(self, double p, double y) noexcept nogil: + cdef double dloss(self, double y, double p) noexcept nogil: return 0. cdef class Classification(LossFunction): """Base class for loss functions for classification""" - cdef double loss(self, double p, double y) noexcept nogil: + cdef double loss(self, double y, double p) noexcept nogil: return 0. - cdef double dloss(self, double p, double y) noexcept nogil: + cdef double dloss(self, double y, double p) noexcept nogil: return 0. @@ -179,7 +179,7 @@ cdef class ModifiedHuber(Classification): See T. Zhang 'Solving Large Scale Linear Prediction Problems Using Stochastic Gradient Descent', ICML'04. """ - cdef double loss(self, double p, double y) noexcept nogil: + cdef double loss(self, double y, double p) noexcept nogil: cdef double z = p * y if z >= 1.0: return 0.0 @@ -188,7 +188,7 @@ cdef class ModifiedHuber(Classification): else: return -4.0 * z - cdef double dloss(self, double p, double y) noexcept nogil: + cdef double dloss(self, double y, double p) noexcept nogil: cdef double z = p * y if z >= 1.0: return 0.0 @@ -217,13 +217,13 @@ cdef class Hinge(Classification): def __init__(self, double threshold=1.0): self.threshold = threshold - cdef double loss(self, double p, double y) noexcept nogil: + cdef double loss(self, double y, double p) noexcept nogil: cdef double z = p * y if z <= self.threshold: return self.threshold - z return 0.0 - cdef double dloss(self, double p, double y) noexcept nogil: + cdef double dloss(self, double y, double p) noexcept nogil: cdef double z = p * y if z <= self.threshold: return -y @@ -249,13 +249,13 @@ cdef class SquaredHinge(Classification): def __init__(self, double threshold=1.0): self.threshold = threshold - cdef double loss(self, double p, double y) noexcept nogil: + cdef double loss(self, double y, double p) noexcept nogil: cdef double z = self.threshold - p * y if z > 0: return z * z return 0.0 - cdef double dloss(self, double p, double y) noexcept nogil: + cdef double dloss(self, double y, double p) noexcept nogil: cdef double z = self.threshold - p * y if z > 0: return -2 * y * z @@ -268,7 +268,7 @@ cdef class SquaredHinge(Classification): cdef class Log(Classification): """Logistic regression loss for binary classification with y in {-1, 1}""" - cdef double loss(self, double p, double y) noexcept nogil: + cdef double loss(self, double y, double p) noexcept nogil: cdef double z = p * y # approximately equal and saves the computation of the log if z > 18: @@ -277,7 +277,7 @@ cdef class Log(Classification): return -z return log(1.0 + exp(-z)) - cdef double dloss(self, double p, double y) noexcept nogil: + cdef double dloss(self, double y, double p) noexcept nogil: cdef double z = p * y # approximately equal and saves the computation of the log if z > 18.0: @@ -292,10 +292,10 @@ cdef class Log(Classification): cdef class SquaredLoss(Regression): """Squared loss traditional used in linear regression.""" - cdef double loss(self, double p, double y) noexcept nogil: + cdef double loss(self, double y, double p) noexcept nogil: return 0.5 * (p - y) * (p - y) - cdef double dloss(self, double p, double y) noexcept nogil: + cdef double dloss(self, double y, double p) noexcept nogil: return p - y def __reduce__(self): @@ -316,7 +316,7 @@ cdef class Huber(Regression): def __init__(self, double c): self.c = c - cdef double loss(self, double p, double y) noexcept nogil: + cdef double loss(self, double y, double p) noexcept nogil: cdef double r = p - y cdef double abs_r = fabs(r) if abs_r <= self.c: @@ 
-324,7 +324,7 @@ cdef class Huber(Regression): else: return self.c * abs_r - (0.5 * self.c * self.c) - cdef double dloss(self, double p, double y) noexcept nogil: + cdef double dloss(self, double y, double p) noexcept nogil: cdef double r = p - y cdef double abs_r = fabs(r) if abs_r <= self.c: @@ -349,11 +349,11 @@ cdef class EpsilonInsensitive(Regression): def __init__(self, double epsilon): self.epsilon = epsilon - cdef double loss(self, double p, double y) noexcept nogil: + cdef double loss(self, double y, double p) noexcept nogil: cdef double ret = fabs(y - p) - self.epsilon return ret if ret > 0 else 0 - cdef double dloss(self, double p, double y) noexcept nogil: + cdef double dloss(self, double y, double p) noexcept nogil: if y - p > self.epsilon: return -1 elif p - y > self.epsilon: @@ -376,11 +376,11 @@ cdef class SquaredEpsilonInsensitive(Regression): def __init__(self, double epsilon): self.epsilon = epsilon - cdef double loss(self, double p, double y) noexcept nogil: + cdef double loss(self, double y, double p) noexcept nogil: cdef double ret = fabs(y - p) - self.epsilon return ret * ret if ret > 0 else 0 - cdef double dloss(self, double p, double y) noexcept nogil: + cdef double dloss(self, double y, double p) noexcept nogil: cdef double z z = y - p if z > self.epsilon: @@ -569,7 +569,7 @@ def _plain_sgd{{name_suffix}}( if learning_rate == OPTIMAL: typw = np.sqrt(1.0 / np.sqrt(alpha)) # computing eta0, the initial learning rate - initial_eta0 = typw / max(1.0, loss.dloss(-typw, 1.0)) + initial_eta0 = typw / max(1.0, loss.dloss(1.0, -typw)) # initialize t such that eta at first sample equals eta0 optimal_init = 1.0 / (initial_eta0 * alpha) @@ -598,7 +598,7 @@ def _plain_sgd{{name_suffix}}( eta = eta0 / pow(t, power_t) if verbose or not early_stopping: - sumloss += loss.loss(p, y) + sumloss += loss.loss(y, p) if y > 0.0: class_weight = weight_pos @@ -609,12 +609,12 @@ def _plain_sgd{{name_suffix}}( update = sqnorm(x_data_ptr, x_ind_ptr, xnnz) if update == 0: continue - update = min(C, loss.loss(p, y) / update) + update = min(C, loss.loss(y, p) / update) elif learning_rate == PA2: update = sqnorm(x_data_ptr, x_ind_ptr, xnnz) - update = loss.loss(p, y) / (update + 0.5 / C) + update = loss.loss(y, p) / (update + 0.5 / C) else: - dloss = loss.dloss(p, y) + dloss = loss.dloss(y, p) # clip dloss with large values to avoid numerical # instabilities if dloss < -MAX_DLOSS: From df4bdca428af98faa7813b38242957facb4043aa Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Thu, 21 Dec 2023 19:25:12 +0100 Subject: [PATCH 02/12] MNT change argument order in SAG --- sklearn/linear_model/_sag_fast.pyx.tp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sklearn/linear_model/_sag_fast.pyx.tp b/sklearn/linear_model/_sag_fast.pyx.tp index 97bf3020d6602..9bfeed559bc13 100644 --- a/sklearn/linear_model/_sag_fast.pyx.tp +++ b/sklearn/linear_model/_sag_fast.pyx.tp @@ -85,7 +85,7 @@ cdef {{c_type}} _logsumexp{{name_suffix}}({{c_type}}* arr, int n_classes) noexce {{for name_suffix, c_type, np_type in dtypes}} cdef class MultinomialLogLoss{{name_suffix}}: - cdef {{c_type}} _loss(self, {{c_type}}* prediction, {{c_type}} y, int n_classes, + cdef {{c_type}} _loss(self, {{c_type}} y, {{c_type}}* prediction, int n_classes, {{c_type}} sample_weight) noexcept nogil: r"""Multinomial Logistic regression loss. 
@@ -100,12 +100,12 @@ cdef class MultinomialLogLoss{{name_suffix}}: Parameters ---------- - prediction : pointer to a np.ndarray[{{c_type}}] of shape (n_classes,) - Prediction of the multinomial classifier, for current sample. - y : {{c_type}}, between 0 and n_classes - 1 Indice of the correct class for current sample (i.e. label encoded). + prediction : pointer to a np.ndarray[{{c_type}}] of shape (n_classes,) + Prediction of the multinomial classifier, for current sample. + n_classes : integer Total number of classes. @@ -129,7 +129,7 @@ cdef class MultinomialLogLoss{{name_suffix}}: loss = (logsumexp_prediction - prediction[int(y)]) * sample_weight return loss - cdef void dloss(self, {{c_type}}* prediction, {{c_type}} y, int n_classes, + cdef void dloss(self, {{c_type}} y, {{c_type}}* prediction, int n_classes, {{c_type}} sample_weight, {{c_type}}* gradient_ptr) noexcept nogil: r"""Multinomial Logistic regression gradient of the loss. @@ -414,9 +414,9 @@ def sag{{name_suffix}}( # compute the gradient for this sample, given the prediction if multinomial: - multiloss.dloss(&prediction[0], y, n_classes, sample_weight, &gradient[0]) + multiloss.dloss(y, &prediction[0], n_classes, sample_weight, &gradient[0]) else: - gradient[0] = loss.dloss(prediction[0], y) * sample_weight + gradient[0] = loss.dloss(y, prediction[0]) * sample_weight # L2 regularization by simply rescaling the weights wscale *= wscale_update @@ -835,10 +835,10 @@ def _multinomial_grad_loss_all_samples( ) # compute the gradient for this sample, given the prediction - multiloss.dloss(&prediction[0], y, n_classes, sample_weight, &gradient[0]) + multiloss.dloss(y, &prediction[0], n_classes, sample_weight, &gradient[0]) # compute the loss for this sample, given the prediction - sum_loss += multiloss._loss(&prediction[0], y, n_classes, sample_weight) + sum_loss += multiloss._loss(y, &prediction[0], n_classes, sample_weight) # update the sum of the gradient for j in range(xnnz): From 7ddbcd584335f7a754b0c58efbef9c72d1183fc7 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Fri, 22 Dec 2023 14:06:17 +0100 Subject: [PATCH 03/12] MNT inherit from CyLossFunction mainly name changes: - loss(..) -> cy_loss(..) - dloss(..) -> cy_gradient(..) --- sklearn/linear_model/_sag_fast.pyx.tp | 16 +++---- sklearn/linear_model/_sgd_fast.pxd | 25 +++++------ sklearn/linear_model/_sgd_fast.pyx.tp | 65 ++++++++++++++------------- 3 files changed, 52 insertions(+), 54 deletions(-) diff --git a/sklearn/linear_model/_sag_fast.pyx.tp b/sklearn/linear_model/_sag_fast.pyx.tp index 9bfeed559bc13..1274ffe90ddb2 100644 --- a/sklearn/linear_model/_sag_fast.pyx.tp +++ b/sklearn/linear_model/_sag_fast.pyx.tp @@ -30,7 +30,7 @@ import numpy as np from libc.math cimport fabs, exp, log from libc.time cimport time, time_t -from ._sgd_fast cimport LossFunction +from sklearn._loss._loss cimport CyLossFunction from ._sgd_fast cimport Log, SquaredLoss from ..utils._seq_dataset cimport SequentialDataset32, SequentialDataset64 @@ -85,7 +85,7 @@ cdef {{c_type}} _logsumexp{{name_suffix}}({{c_type}}* arr, int n_classes) noexce {{for name_suffix, c_type, np_type in dtypes}} cdef class MultinomialLogLoss{{name_suffix}}: - cdef {{c_type}} _loss(self, {{c_type}} y, {{c_type}}* prediction, int n_classes, + cdef {{c_type}} cy_loss(self, {{c_type}} y, {{c_type}}* prediction, int n_classes, {{c_type}} sample_weight) noexcept nogil: r"""Multinomial Logistic regression loss. 
@@ -129,7 +129,7 @@ cdef class MultinomialLogLoss{{name_suffix}}: loss = (logsumexp_prediction - prediction[int(y)]) * sample_weight return loss - cdef void dloss(self, {{c_type}} y, {{c_type}}* prediction, int n_classes, + cdef void cy_gradient(self, {{c_type}} y, {{c_type}}* prediction, int n_classes, {{c_type}} sample_weight, {{c_type}}* gradient_ptr) noexcept nogil: r"""Multinomial Logistic regression gradient of the loss. @@ -339,7 +339,7 @@ def sag{{name_suffix}}( cdef bint prox = beta > 0 and saga # Loss function to optimize - cdef LossFunction loss + cdef CyLossFunction loss # Whether the loss function is multinomial cdef bint multinomial = False # Multinomial loss function @@ -414,9 +414,9 @@ def sag{{name_suffix}}( # compute the gradient for this sample, given the prediction if multinomial: - multiloss.dloss(y, &prediction[0], n_classes, sample_weight, &gradient[0]) + multiloss.cy_gradient(y, &prediction[0], n_classes, sample_weight, &gradient[0]) else: - gradient[0] = loss.dloss(y, prediction[0]) * sample_weight + gradient[0] = loss.cy_gradient(y, prediction[0]) * sample_weight # L2 regularization by simply rescaling the weights wscale *= wscale_update @@ -835,10 +835,10 @@ def _multinomial_grad_loss_all_samples( ) # compute the gradient for this sample, given the prediction - multiloss.dloss(y, &prediction[0], n_classes, sample_weight, &gradient[0]) + multiloss.cy_gradient(y, &prediction[0], n_classes, sample_weight, &gradient[0]) # compute the loss for this sample, given the prediction - sum_loss += multiloss._loss(y, &prediction[0], n_classes, sample_weight) + sum_loss += multiloss.cy_loss(y, &prediction[0], n_classes, sample_weight) # update the sum of the gradient for j in range(xnnz): diff --git a/sklearn/linear_model/_sgd_fast.pxd b/sklearn/linear_model/_sgd_fast.pxd index 243ce2e13a35e..597827f381883 100644 --- a/sklearn/linear_model/_sgd_fast.pxd +++ b/sklearn/linear_model/_sgd_fast.pxd @@ -1,26 +1,23 @@ # License: BSD 3 clause """Helper to load LossFunction from sgd_fast.pyx to sag_fast.pyx""" -cdef class LossFunction: - cdef double loss(self, double p, double y) noexcept nogil - cdef double dloss(self, double p, double y) noexcept nogil +from sklearn._loss._loss cimport CyLossFunction +cdef class Regression(CyLossFunction): + cdef double cy_loss(self, double y, double p) noexcept nogil + cdef double cy_gradient(self, double y, double p) noexcept nogil -cdef class Regression(LossFunction): - cdef double loss(self, double y, double p) noexcept nogil - cdef double dloss(self, double y, double p) noexcept nogil - -cdef class Classification(LossFunction): - cdef double loss(self, double y, double p) noexcept nogil - cdef double dloss(self, double y, double p) noexcept nogil +cdef class Classification(CyLossFunction): + cdef double cy_loss(self, double y, double p) noexcept nogil + cdef double cy_gradient(self, double y, double p) noexcept nogil cdef class Log(Classification): - cdef double loss(self, double y, double p) noexcept nogil - cdef double dloss(self, double y, double p) noexcept nogil + cdef double cy_loss(self, double y, double p) noexcept nogil + cdef double cy_gradient(self, double y, double p) noexcept nogil cdef class SquaredLoss(Regression): - cdef double loss(self, double y, double p) noexcept nogil - cdef double dloss(self, double y, double p) noexcept nogil + cdef double cy_loss(self, double y, double p) noexcept nogil + cdef double cy_gradient(self, double y, double p) noexcept nogil diff --git a/sklearn/linear_model/_sgd_fast.pyx.tp 
b/sklearn/linear_model/_sgd_fast.pyx.tp index b92d983a1b4b8..7795f40d5f17b 100644 --- a/sklearn/linear_model/_sgd_fast.pyx.tp +++ b/sklearn/linear_model/_sgd_fast.pyx.tp @@ -36,6 +36,7 @@ cdef extern from "_sgd_fast_helpers.h": bint skl_isfinite32(float) nogil bint skl_isfinite64(double) nogil +from .._loss._loss cimport CyLossFunction from ..utils._weight_vector cimport WeightVector32, WeightVector64 from ..utils._seq_dataset cimport SequentialDataset32, SequentialDataset64 @@ -111,7 +112,7 @@ cdef class LossFunction: double The derivative of the loss function with regards to `p`. """ - return self.dloss(y, p) + return self.cy_gradient(y, p) def py_loss(self, double p, double y): """Python version of `loss` for testing. @@ -132,7 +133,7 @@ cdef class LossFunction: """ return self.loss(y, p) - cdef double dloss(self, double y, double p) noexcept nogil: + cdef double cy_gradient(self, double y, double p) noexcept nogil: """Evaluate the derivative of the loss function with respect to the prediction `p`. @@ -151,23 +152,23 @@ cdef class LossFunction: return 0. -cdef class Regression(LossFunction): +cdef class Regression(CyLossFunction): """Base class for loss functions for regression""" - cdef double loss(self, double y, double p) noexcept nogil: + cdef double cy_loss(self, double y_true, double raw_prediction) noexcept nogil: return 0. - cdef double dloss(self, double y, double p) noexcept nogil: + cdef double cy_gradient(self, double y_true, double raw_prediction) noexcept nogil: return 0. -cdef class Classification(LossFunction): +cdef class Classification(CyLossFunction): """Base class for loss functions for classification""" - cdef double loss(self, double y, double p) noexcept nogil: + cdef double cy_loss(self, double y_true, double raw_prediction) noexcept nogil: return 0. - cdef double dloss(self, double y, double p) noexcept nogil: + cdef double cy_gradient(self, double y_true, double raw_prediction) noexcept nogil: return 0. @@ -179,7 +180,7 @@ cdef class ModifiedHuber(Classification): See T. Zhang 'Solving Large Scale Linear Prediction Problems Using Stochastic Gradient Descent', ICML'04. 
""" - cdef double loss(self, double y, double p) noexcept nogil: + cdef double cy_loss(self, double y, double p) noexcept nogil: cdef double z = p * y if z >= 1.0: return 0.0 @@ -188,7 +189,7 @@ cdef class ModifiedHuber(Classification): else: return -4.0 * z - cdef double dloss(self, double y, double p) noexcept nogil: + cdef double cy_gradient(self, double y, double p) noexcept nogil: cdef double z = p * y if z >= 1.0: return 0.0 @@ -217,13 +218,13 @@ cdef class Hinge(Classification): def __init__(self, double threshold=1.0): self.threshold = threshold - cdef double loss(self, double y, double p) noexcept nogil: + cdef double cy_loss(self, double y, double p) noexcept nogil: cdef double z = p * y if z <= self.threshold: return self.threshold - z return 0.0 - cdef double dloss(self, double y, double p) noexcept nogil: + cdef double cy_gradient(self, double y, double p) noexcept nogil: cdef double z = p * y if z <= self.threshold: return -y @@ -249,13 +250,13 @@ cdef class SquaredHinge(Classification): def __init__(self, double threshold=1.0): self.threshold = threshold - cdef double loss(self, double y, double p) noexcept nogil: + cdef double cy_loss(self, double y, double p) noexcept nogil: cdef double z = self.threshold - p * y if z > 0: return z * z return 0.0 - cdef double dloss(self, double y, double p) noexcept nogil: + cdef double cy_gradient(self, double y, double p) noexcept nogil: cdef double z = self.threshold - p * y if z > 0: return -2 * y * z @@ -268,7 +269,7 @@ cdef class SquaredHinge(Classification): cdef class Log(Classification): """Logistic regression loss for binary classification with y in {-1, 1}""" - cdef double loss(self, double y, double p) noexcept nogil: + cdef double cy_loss(self, double y, double p) noexcept nogil: cdef double z = p * y # approximately equal and saves the computation of the log if z > 18: @@ -277,7 +278,7 @@ cdef class Log(Classification): return -z return log(1.0 + exp(-z)) - cdef double dloss(self, double y, double p) noexcept nogil: + cdef double cy_gradient(self, double y, double p) noexcept nogil: cdef double z = p * y # approximately equal and saves the computation of the log if z > 18.0: @@ -292,10 +293,10 @@ cdef class Log(Classification): cdef class SquaredLoss(Regression): """Squared loss traditional used in linear regression.""" - cdef double loss(self, double y, double p) noexcept nogil: + cdef double cy_loss(self, double y, double p) noexcept nogil: return 0.5 * (p - y) * (p - y) - cdef double dloss(self, double y, double p) noexcept nogil: + cdef double cy_gradient(self, double y, double p) noexcept nogil: return p - y def __reduce__(self): @@ -316,7 +317,7 @@ cdef class Huber(Regression): def __init__(self, double c): self.c = c - cdef double loss(self, double y, double p) noexcept nogil: + cdef double cy_loss(self, double y, double p) noexcept nogil: cdef double r = p - y cdef double abs_r = fabs(r) if abs_r <= self.c: @@ -324,7 +325,7 @@ cdef class Huber(Regression): else: return self.c * abs_r - (0.5 * self.c * self.c) - cdef double dloss(self, double y, double p) noexcept nogil: + cdef double cy_gradient(self, double y, double p) noexcept nogil: cdef double r = p - y cdef double abs_r = fabs(r) if abs_r <= self.c: @@ -349,11 +350,11 @@ cdef class EpsilonInsensitive(Regression): def __init__(self, double epsilon): self.epsilon = epsilon - cdef double loss(self, double y, double p) noexcept nogil: + cdef double cy_loss(self, double y, double p) noexcept nogil: cdef double ret = fabs(y - p) - self.epsilon return ret if ret > 0 
else 0 - cdef double dloss(self, double y, double p) noexcept nogil: + cdef double cy_gradient(self, double y, double p) noexcept nogil: if y - p > self.epsilon: return -1 elif p - y > self.epsilon: @@ -376,11 +377,11 @@ cdef class SquaredEpsilonInsensitive(Regression): def __init__(self, double epsilon): self.epsilon = epsilon - cdef double loss(self, double y, double p) noexcept nogil: + cdef double cy_loss(self, double y, double p) noexcept nogil: cdef double ret = fabs(y - p) - self.epsilon return ret * ret if ret > 0 else 0 - cdef double dloss(self, double y, double p) noexcept nogil: + cdef double cy_gradient(self, double y, double p) noexcept nogil: cdef double z z = y - p if z > self.epsilon: @@ -400,7 +401,7 @@ def _plain_sgd{{name_suffix}}( double intercept, const {{c_type}}[::1] average_weights, double average_intercept, - LossFunction loss, + CyLossFunction loss, int penalty_type, double alpha, double C, @@ -439,8 +440,8 @@ def _plain_sgd{{name_suffix}}( is 0. average_intercept : double The average intercept for ASGD. Should be 0 if average is 0. - loss : LossFunction - A concrete ``LossFunction`` object. + loss : CyLossFunction + A concrete ``CyLossFunction`` object. penalty_type : int The penalty 2 for L2, 1 for L1, and 3 for Elastic-Net. alpha : float @@ -569,7 +570,7 @@ def _plain_sgd{{name_suffix}}( if learning_rate == OPTIMAL: typw = np.sqrt(1.0 / np.sqrt(alpha)) # computing eta0, the initial learning rate - initial_eta0 = typw / max(1.0, loss.dloss(1.0, -typw)) + initial_eta0 = typw / max(1.0, loss.cy_gradient(1.0, -typw)) # initialize t such that eta at first sample equals eta0 optimal_init = 1.0 / (initial_eta0 * alpha) @@ -598,7 +599,7 @@ def _plain_sgd{{name_suffix}}( eta = eta0 / pow(t, power_t) if verbose or not early_stopping: - sumloss += loss.loss(y, p) + sumloss += loss.cy_loss(y, p) if y > 0.0: class_weight = weight_pos @@ -609,12 +610,12 @@ def _plain_sgd{{name_suffix}}( update = sqnorm(x_data_ptr, x_ind_ptr, xnnz) if update == 0: continue - update = min(C, loss.loss(y, p) / update) + update = min(C, loss.cy_loss(y, p) / update) elif learning_rate == PA2: update = sqnorm(x_data_ptr, x_ind_ptr, xnnz) - update = loss.loss(y, p) / (update + 0.5 / C) + update = loss.cy_loss(y, p) / (update + 0.5 / C) else: - dloss = loss.dloss(y, p) + dloss = loss.cy_gradient(y, p) # clip dloss with large values to avoid numerical # instabilities if dloss < -MAX_DLOSS: From bc2be3c91ab3396fb55cb09e5884770b55a32ae7 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Fri, 22 Dec 2023 20:33:15 +0100 Subject: [PATCH 04/12] ENH replace Log, SquaredLoss and Huber with common losses --- sklearn/linear_model/__init__.py | 5 +- sklearn/linear_model/_logistic.py | 8 +-- sklearn/linear_model/_sag.py | 2 +- sklearn/linear_model/_sag_fast.pyx.tp | 13 ++-- sklearn/linear_model/_sgd_fast.pxd | 23 ------ sklearn/linear_model/_sgd_fast.pyx.tp | 73 -------------------- sklearn/linear_model/_stochastic_gradient.py | 26 ++++--- sklearn/linear_model/tests/test_sag.py | 6 +- 8 files changed, 29 insertions(+), 127 deletions(-) delete mode 100644 sklearn/linear_model/_sgd_fast.pxd diff --git a/sklearn/linear_model/__init__.py b/sklearn/linear_model/__init__.py index 45c99d4d36df1..a2d1a699dfc2b 100644 --- a/sklearn/linear_model/__init__.py +++ b/sklearn/linear_model/__init__.py @@ -43,7 +43,7 @@ from ._quantile import QuantileRegressor from ._ransac import RANSACRegressor from ._ridge import Ridge, RidgeClassifier, RidgeClassifierCV, RidgeCV, ridge_regression -from ._sgd_fast import Hinge, Huber, 
Log, ModifiedHuber, SquaredLoss +from ._sgd_fast import Hinge, ModifiedHuber from ._stochastic_gradient import SGDClassifier, SGDOneClassSVM, SGDRegressor from ._theil_sen import TheilSenRegressor @@ -53,7 +53,6 @@ "ElasticNet", "ElasticNetCV", "Hinge", - "Huber", "HuberRegressor", "Lars", "LarsCV", @@ -63,7 +62,6 @@ "LassoLarsCV", "LassoLarsIC", "LinearRegression", - "Log", "LogisticRegression", "LogisticRegressionCV", "ModifiedHuber", @@ -84,7 +82,6 @@ "SGDClassifier", "SGDRegressor", "SGDOneClassSVM", - "SquaredLoss", "TheilSenRegressor", "enet_path", "lars_path", diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 188204ce815ad..4ceca00dbaacb 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -318,14 +318,14 @@ def _logistic_regression_path( w0 = np.zeros(n_features + int(fit_intercept), dtype=X.dtype) mask = y == pos_class y_bin = np.ones(y.shape, dtype=X.dtype) - if solver in ["lbfgs", "newton-cg", "newton-cholesky"]: + if solver == "liblinear": + mask_classes = np.array([-1, 1]) + y_bin[~mask] = -1.0 + else: # HalfBinomialLoss, used for those solvers, represents y in [0, 1] instead # of in [-1, 1]. mask_classes = np.array([0, 1]) y_bin[~mask] = 0.0 - else: - mask_classes = np.array([-1, 1]) - y_bin[~mask] = -1.0 # for compute_class_weight if class_weight == "balanced": diff --git a/sklearn/linear_model/_sag.py b/sklearn/linear_model/_sag.py index 2626955ec2a7f..cca327c8879f6 100644 --- a/sklearn/linear_model/_sag.py +++ b/sklearn/linear_model/_sag.py @@ -128,7 +128,7 @@ def sag_solver( y : ndarray of shape (n_samples,) Target values. With loss='multinomial', y must be label encoded - (see preprocessing.LabelEncoder). + (see preprocessing.LabelEncoder). For loss='log' it must be in [0, 1]. sample_weight : array-like of shape (n_samples,), default=None Weights applied to individual samples (1. for unweighted). 
diff --git a/sklearn/linear_model/_sag_fast.pyx.tp b/sklearn/linear_model/_sag_fast.pyx.tp index 1274ffe90ddb2..2abb8e5f3b9ae 100644 --- a/sklearn/linear_model/_sag_fast.pyx.tp +++ b/sklearn/linear_model/_sag_fast.pyx.tp @@ -29,14 +29,11 @@ dtypes = [('64', 'double', 'np.float64'), import numpy as np from libc.math cimport fabs, exp, log from libc.time cimport time, time_t +from libc.stdio cimport printf -from sklearn._loss._loss cimport CyLossFunction -from ._sgd_fast cimport Log, SquaredLoss - +from .._loss._loss cimport CyLossFunction, CyHalfSquaredError, CyHalfBinomialLoss from ..utils._seq_dataset cimport SequentialDataset32, SequentialDataset64 -from libc.stdio cimport printf - {{for name_suffix, c_type, np_type in dtypes}} @@ -349,9 +346,9 @@ def sag{{name_suffix}}( multinomial = True multiloss = MultinomialLogLoss{{name_suffix}}() elif loss_function == "log": - loss = Log() + loss = CyHalfBinomialLoss() elif loss_function == "squared": - loss = SquaredLoss() + loss = CyHalfSquaredError() else: raise ValueError("Invalid loss parameter: got %s instead of " "one of ('log', 'squared', 'multinomial')" @@ -547,7 +544,7 @@ def sag{{name_suffix}}( (n_iter + 1, end_time - start_time)) break elif verbose: - printf('Epoch %d, change: %.8f\n', n_iter + 1, + printf('Epoch %d, change: %.8g\n', n_iter + 1, max_change / max_weight) n_iter += 1 # We do the error treatment here based on error code in status to avoid diff --git a/sklearn/linear_model/_sgd_fast.pxd b/sklearn/linear_model/_sgd_fast.pxd deleted file mode 100644 index 597827f381883..0000000000000 --- a/sklearn/linear_model/_sgd_fast.pxd +++ /dev/null @@ -1,23 +0,0 @@ -# License: BSD 3 clause -"""Helper to load LossFunction from sgd_fast.pyx to sag_fast.pyx""" - -from sklearn._loss._loss cimport CyLossFunction - -cdef class Regression(CyLossFunction): - cdef double cy_loss(self, double y, double p) noexcept nogil - cdef double cy_gradient(self, double y, double p) noexcept nogil - - -cdef class Classification(CyLossFunction): - cdef double cy_loss(self, double y, double p) noexcept nogil - cdef double cy_gradient(self, double y, double p) noexcept nogil - - -cdef class Log(Classification): - cdef double cy_loss(self, double y, double p) noexcept nogil - cdef double cy_gradient(self, double y, double p) noexcept nogil - - -cdef class SquaredLoss(Regression): - cdef double cy_loss(self, double y, double p) noexcept nogil - cdef double cy_gradient(self, double y, double p) noexcept nogil diff --git a/sklearn/linear_model/_sgd_fast.pyx.tp b/sklearn/linear_model/_sgd_fast.pyx.tp index 7795f40d5f17b..4e933b9be2857 100644 --- a/sklearn/linear_model/_sgd_fast.pyx.tp +++ b/sklearn/linear_model/_sgd_fast.pyx.tp @@ -266,79 +266,6 @@ cdef class SquaredHinge(Classification): return SquaredHinge, (self.threshold,) -cdef class Log(Classification): - """Logistic regression loss for binary classification with y in {-1, 1}""" - - cdef double cy_loss(self, double y, double p) noexcept nogil: - cdef double z = p * y - # approximately equal and saves the computation of the log - if z > 18: - return exp(-z) - if z < -18: - return -z - return log(1.0 + exp(-z)) - - cdef double cy_gradient(self, double y, double p) noexcept nogil: - cdef double z = p * y - # approximately equal and saves the computation of the log - if z > 18.0: - return exp(-z) * -y - if z < -18.0: - return -y - return -y / (exp(z) + 1.0) - - def __reduce__(self): - return Log, () - - -cdef class SquaredLoss(Regression): - """Squared loss traditional used in linear regression.""" - cdef double 
cy_loss(self, double y, double p) noexcept nogil: - return 0.5 * (p - y) * (p - y) - - cdef double cy_gradient(self, double y, double p) noexcept nogil: - return p - y - - def __reduce__(self): - return SquaredLoss, () - - -cdef class Huber(Regression): - """Huber regression loss - - Variant of the SquaredLoss that is robust to outliers (quadratic near zero, - linear in for large errors). - - https://en.wikipedia.org/wiki/Huber_Loss_Function - """ - - cdef double c - - def __init__(self, double c): - self.c = c - - cdef double cy_loss(self, double y, double p) noexcept nogil: - cdef double r = p - y - cdef double abs_r = fabs(r) - if abs_r <= self.c: - return 0.5 * r * r - else: - return self.c * abs_r - (0.5 * self.c * self.c) - - cdef double cy_gradient(self, double y, double p) noexcept nogil: - cdef double r = p - y - cdef double abs_r = fabs(r) - if abs_r <= self.c: - return r - elif r > 0.0: - return self.c - else: - return -self.c - - def __reduce__(self): - return Huber, (self.c,) - - cdef class EpsilonInsensitive(Regression): """Epsilon-Insensitive loss (used by SVR). diff --git a/sklearn/linear_model/_stochastic_gradient.py b/sklearn/linear_model/_stochastic_gradient.py index aeec7b5588add..d1c43800eb472 100644 --- a/sklearn/linear_model/_stochastic_gradient.py +++ b/sklearn/linear_model/_stochastic_gradient.py @@ -12,6 +12,7 @@ import numpy as np +from .._loss._loss import CyHalfBinomialLoss, CyHalfSquaredError, CyHuberLoss from ..base import ( BaseEstimator, OutlierMixin, @@ -33,12 +34,9 @@ from ._sgd_fast import ( EpsilonInsensitive, Hinge, - Huber, - Log, ModifiedHuber, SquaredEpsilonInsensitive, SquaredHinge, - SquaredLoss, _plain_sgd32, _plain_sgd64, ) @@ -334,13 +332,18 @@ def loss_function_(self): return self._loss_function_ -def _prepare_fit_binary(est, y, i, input_dtye): +def _prepare_fit_binary(est, y, i, input_dtye, label_encode=True): """Initialization for fit_binary. Returns y, coef, intercept, average_coef, average_intercept. 
""" y_i = np.ones(y.shape, dtype=input_dtye, order="C") - y_i[y != est.classes_[i]] = -1.0 + if label_encode: + # y in {0, 1} + y_i[y != est.classes_[i]] = 0.0 + else: + # y in {-1, +1} + y_i[y != est.classes_[i]] = -1.0 average_intercept = 0 average_coef = None @@ -433,8 +436,9 @@ def fit_binary( """ # if average is not true, average_coef, and average_intercept will be # unused + label_encode = isinstance(est._loss_function_, CyHalfBinomialLoss) y_i, coef, intercept, average_coef, average_intercept = _prepare_fit_binary( - est, y, i, input_dtye=X.dtype + est, y, i, input_dtye=X.dtype, label_encode=label_encode ) assert y_i.shape[0] == y.shape[0] == sample_weight.shape[0] @@ -510,10 +514,10 @@ class BaseSGDClassifier(LinearClassifierMixin, BaseSGD, metaclass=ABCMeta): "hinge": (Hinge, 1.0), "squared_hinge": (SquaredHinge, 1.0), "perceptron": (Hinge, 0.0), - "log_loss": (Log,), + "log_loss": (CyHalfBinomialLoss,), "modified_huber": (ModifiedHuber,), - "squared_error": (SquaredLoss,), - "huber": (Huber, DEFAULT_EPSILON), + "squared_error": (CyHalfSquaredError,), + "huber": (CyHuberLoss, DEFAULT_EPSILON), "epsilon_insensitive": (EpsilonInsensitive, DEFAULT_EPSILON), "squared_epsilon_insensitive": (SquaredEpsilonInsensitive, DEFAULT_EPSILON), } @@ -1379,8 +1383,8 @@ def _more_tags(self): class BaseSGDRegressor(RegressorMixin, BaseSGD): loss_functions = { - "squared_error": (SquaredLoss,), - "huber": (Huber, DEFAULT_EPSILON), + "squared_error": (CyHalfSquaredError,), + "huber": (CyHuberLoss, DEFAULT_EPSILON), "epsilon_insensitive": (EpsilonInsensitive, DEFAULT_EPSILON), "squared_epsilon_insensitive": (SquaredEpsilonInsensitive, DEFAULT_EPSILON), } diff --git a/sklearn/linear_model/tests/test_sag.py b/sklearn/linear_model/tests/test_sag.py index 96f8a79726833..b1f675a54196e 100644 --- a/sklearn/linear_model/tests/test_sag.py +++ b/sklearn/linear_model/tests/test_sag.py @@ -255,7 +255,7 @@ def get_step_size(X, alpha, fit_intercept, classification=True): def test_classifier_matching(): n_samples = 20 X, y = make_blobs(n_samples=n_samples, centers=2, random_state=0, cluster_std=0.1) - y[y == 0] = -1 + # y must be 0 or 1 alpha = 1.1 fit_intercept = True step_size = get_step_size(X, alpha, fit_intercept) @@ -278,7 +278,7 @@ def test_classifier_matching(): weights, intercept = sag_sparse( X, - y, + 2 * y - 1, # y must be -1 or +1 step_size, alpha, n_iter=n_iter, @@ -288,7 +288,7 @@ def test_classifier_matching(): ) weights2, intercept2 = sag( X, - y, + 2 * y - 1, # y must be -1 or +1 step_size, alpha, n_iter=n_iter, From 68f35c3e15c2d4fee2b14eedaf5eafe14baf2458 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sat, 23 Dec 2023 14:51:09 +0100 Subject: [PATCH 05/12] MNT remove SGD extension type LossFunction --- sklearn/_loss/tests/test_loss.py | 53 ++++++++++++------ sklearn/linear_model/_sgd_fast.pyx.tp | 75 ++++++-------------------- sklearn/linear_model/tests/test_sgd.py | 50 ----------------- 3 files changed, 53 insertions(+), 125 deletions(-) diff --git a/sklearn/_loss/tests/test_loss.py b/sklearn/_loss/tests/test_loss.py index c018bb7147ce9..b3cea8eb9f452 100644 --- a/sklearn/_loss/tests/test_loss.py +++ b/sklearn/_loss/tests/test_loss.py @@ -224,48 +224,69 @@ def test_loss_boundary_y_pred(loss, y_pred_success, y_pred_fail): @pytest.mark.parametrize( - "loss, y_true, raw_prediction, loss_true", + "loss, y_true, raw_prediction, loss_true, gradient_true", [ - (HalfSquaredError(), 1.0, 5.0, 8), - (AbsoluteError(), 1.0, 5.0, 4), - (PinballLoss(quantile=0.5), 1.0, 5.0, 2), - 
(PinballLoss(quantile=0.25), 1.0, 5.0, 4 * (1 - 0.25)), - (PinballLoss(quantile=0.25), 5.0, 1.0, 4 * 0.25), - (HuberLoss(quantile=0.5, delta=3), 1.0, 5.0, 3 * (4 - 3 / 2)), - (HuberLoss(quantile=0.5, delta=3), 1.0, 3.0, 0.5 * 2**2), - (HalfPoissonLoss(), 2.0, np.log(4), 4 - 2 * np.log(4)), - (HalfGammaLoss(), 2.0, np.log(4), np.log(4) + 2 / 4), - (HalfTweedieLoss(power=3), 2.0, np.log(4), -1 / 4 + 1 / 4**2), - (HalfTweedieLossIdentity(power=1), 2.0, 4.0, 2 - 2 * np.log(2)), - (HalfTweedieLossIdentity(power=2), 2.0, 4.0, np.log(2) - 1 / 2), - (HalfTweedieLossIdentity(power=3), 2.0, 4.0, -1 / 4 + 1 / 4**2 + 1 / 2 / 2), - (HalfBinomialLoss(), 0.25, np.log(4), np.log(5) - 0.25 * np.log(4)), + (HalfSquaredError(), 1.0, 5.0, 8, 4.0), + (AbsoluteError(), 1.0, 5.0, 4, 1.0), + (PinballLoss(quantile=0.5), 1.0, 5.0, 2, None), + (PinballLoss(quantile=0.25), 1.0, 5.0, 4 * (1 - 0.25), None), + (PinballLoss(quantile=0.25), 5.0, 1.0, 4 * 0.25, None), + (HuberLoss(quantile=0.5, delta=0.1), 0.0, 0.0, 0.0, 0.0), + (HuberLoss(quantile=0.5, delta=0.1), 0.0, 0.1, 0.005, 0.1), + (HuberLoss(quantile=0.5, delta=0.1), 0.1, 0.0, 0.005, -0.1), + (HuberLoss(quantile=0.5, delta=0.1), 4.0, 3.95, 0.00125, -0.05), + (HuberLoss(quantile=0.5, delta=0.1), 2.0, 5.0, 0.295, 0.1), + (HuberLoss(quantile=0.5, delta=0.1), 5.0, -1.0, 0.595, -0.1), + (HuberLoss(quantile=0.5, delta=3), 1.0, 5.0, 3 * (4 - 3 / 2), None), + (HuberLoss(quantile=0.5, delta=3), 1.0, 3.0, 0.5 * 2**2, None), + (HalfPoissonLoss(), 2.0, np.log(4), 4 - 2 * np.log(4), None), + (HalfGammaLoss(), 2.0, np.log(4), np.log(4) + 2 / 4, None), + (HalfTweedieLoss(power=3), 2.0, np.log(4), -1 / 4 + 1 / 4**2, None), + (HalfTweedieLossIdentity(power=1), 2.0, 4.0, 2 - 2 * np.log(2), None), + (HalfTweedieLossIdentity(power=2), 2.0, 4.0, np.log(2) - 1 / 2, None), + ( + HalfTweedieLossIdentity(power=3), + 2.0, + 4.0, + -1 / 4 + 1 / 4**2 + 1 / 2 / 2, + None, + ), + (HalfBinomialLoss(), 0.25, np.log(4), np.log(5) - 0.25 * np.log(4), None), ( HalfMultinomialLoss(n_classes=3), 0.0, [0.2, 0.5, 0.3], logsumexp([0.2, 0.5, 0.3]) - 0.2, + None, ), ( HalfMultinomialLoss(n_classes=3), 1.0, [0.2, 0.5, 0.3], logsumexp([0.2, 0.5, 0.3]) - 0.5, + None, ), ( HalfMultinomialLoss(n_classes=3), 2.0, [0.2, 0.5, 0.3], logsumexp([0.2, 0.5, 0.3]) - 0.3, + None, ), ], ids=loss_instance_name, ) -def test_loss_on_specific_values(loss, y_true, raw_prediction, loss_true): +def test_loss_on_specific_values( + loss, y_true, raw_prediction, loss_true, gradient_true +): """Test losses at specific values.""" assert loss( y_true=np.array([y_true]), raw_prediction=np.array([raw_prediction]) ) == approx(loss_true, rel=1e-11, abs=1e-12) + if gradient_true is not None: + assert loss.gradient( + y_true=np.array([y_true]), raw_prediction=np.array([raw_prediction]) + ) == approx(gradient_true, rel=1e-11, abs=1e-12) @pytest.mark.parametrize("loss", ALL_LOSSES) diff --git a/sklearn/linear_model/_sgd_fast.pyx.tp b/sklearn/linear_model/_sgd_fast.pyx.tp index 4e933b9be2857..797ea11b7d736 100644 --- a/sklearn/linear_model/_sgd_fast.pyx.tp +++ b/sklearn/linear_model/_sgd_fast.pyx.tp @@ -75,28 +75,30 @@ cdef extern from *: # Extension Types for Loss Functions # ---------------------------------------- -cdef class LossFunction: - """Base class for convex loss functions""" +cdef class Regression(CyLossFunction): + """Base class for loss functions for regression""" - cdef double loss(self, double y, double p) noexcept nogil: - """Evaluate the loss function. 
+ def py_loss(self, double p, double y): + """Python version of `loss` for testing only. + + Pytest needs a python function and can't use cdef functions. Parameters ---------- - y : double - The true value (aka target). p : double The prediction, `p = w^T x + intercept`. + y : double + The true value (aka target). Returns ------- double The loss evaluated at `p` and `y`. """ - return 0. + return self.cy_loss(y, p) def py_dloss(self, double p, double y): - """Python version of `dloss` for testing. + """Python version of `dloss` for testing only. Pytest needs a python function and can't use cdef functions. @@ -114,62 +116,17 @@ cdef class LossFunction: """ return self.cy_gradient(y, p) - def py_loss(self, double p, double y): - """Python version of `loss` for testing. - - Pytest needs a python function and can't use cdef functions. - - Parameters - ---------- - p : double - The prediction, `p = w^T x + intercept`. - y : double - The true value (aka target). - - Returns - ------- - double - The loss evaluated at `p` and `y`. - """ - return self.loss(y, p) - - cdef double cy_gradient(self, double y, double p) noexcept nogil: - """Evaluate the derivative of the loss function with respect to - the prediction `p`. - - Parameters - ---------- - y : double - The true value (aka target). - p : double - The prediction, `p = w^T x`. - - Returns - ------- - double - The derivative of the loss function with regards to `p`. - """ - return 0. - - -cdef class Regression(CyLossFunction): - """Base class for loss functions for regression""" - - cdef double cy_loss(self, double y_true, double raw_prediction) noexcept nogil: - return 0. - - cdef double cy_gradient(self, double y_true, double raw_prediction) noexcept nogil: - return 0. - cdef class Classification(CyLossFunction): """Base class for loss functions for classification""" - cdef double cy_loss(self, double y_true, double raw_prediction) noexcept nogil: - return 0. + def py_loss(self, double p, double y): + """Python version of `loss` for testing only.""" + return self.cy_loss(y, p) - cdef double cy_gradient(self, double y_true, double raw_prediction) noexcept nogil: - return 0. 
+ def py_dloss(self, double p, double y): + """Python version of `dloss` for testing only.""" + return self.cy_gradient(y, p) cdef class ModifiedHuber(Classification): diff --git a/sklearn/linear_model/tests/test_sgd.py b/sklearn/linear_model/tests/test_sgd.py index d1dd1ca960f86..e7a2e15202ef0 100644 --- a/sklearn/linear_model/tests/test_sgd.py +++ b/sklearn/linear_model/tests/test_sgd.py @@ -1897,56 +1897,6 @@ def test_gradient_squared_hinge(): _test_loss_common(loss, cases) -def test_loss_log(): - # Test Log (logistic loss) - loss = sgd_fast.Log() - cases = [ - # (p, y, expected_loss, expected_dloss) - (1.0, 1.0, np.log(1.0 + np.exp(-1.0)), -1.0 / (np.exp(1.0) + 1.0)), - (1.0, -1.0, np.log(1.0 + np.exp(1.0)), 1.0 / (np.exp(-1.0) + 1.0)), - (-1.0, -1.0, np.log(1.0 + np.exp(-1.0)), 1.0 / (np.exp(1.0) + 1.0)), - (-1.0, 1.0, np.log(1.0 + np.exp(1.0)), -1.0 / (np.exp(-1.0) + 1.0)), - (0.0, 1.0, np.log(2), -0.5), - (0.0, -1.0, np.log(2), 0.5), - (17.9, -1.0, 17.9, 1.0), - (-17.9, 1.0, 17.9, -1.0), - ] - _test_loss_common(loss, cases) - assert_almost_equal(loss.py_dloss(18.1, 1.0), np.exp(-18.1) * -1.0, 16) - assert_almost_equal(loss.py_loss(18.1, 1.0), np.exp(-18.1), 16) - assert_almost_equal(loss.py_dloss(-18.1, -1.0), np.exp(-18.1) * 1.0, 16) - assert_almost_equal(loss.py_loss(-18.1, 1.0), 18.1, 16) - - -def test_loss_squared_loss(): - # Test SquaredLoss - loss = sgd_fast.SquaredLoss() - cases = [ - # (p, y, expected_loss, expected_dloss) - (0.0, 0.0, 0.0, 0.0), - (1.0, 1.0, 0.0, 0.0), - (1.0, 0.0, 0.5, 1.0), - (0.5, -1.0, 1.125, 1.5), - (-2.5, 2.0, 10.125, -4.5), - ] - _test_loss_common(loss, cases) - - -def test_loss_huber(): - # Test Huber - loss = sgd_fast.Huber(0.1) - cases = [ - # (p, y, expected_loss, expected_dloss) - (0.0, 0.0, 0.0, 0.0), - (0.1, 0.0, 0.005, 0.1), - (0.0, 0.1, 0.005, -0.1), - (3.95, 4.0, 0.00125, -0.05), - (5.0, 2.0, 0.295, 0.1), - (-1.0, 5.0, 0.595, -0.1), - ] - _test_loss_common(loss, cases) - - def test_loss_modified_huber(): # (p, y, expected_loss, expected_dloss) loss = sgd_fast.ModifiedHuber() From aefb44f4bd651c98d6285018f62e15f8caf709fa Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Tue, 26 Dec 2023 12:50:37 +0100 Subject: [PATCH 06/12] MNT remove Hing and ModifiedHuber from __init__ --- sklearn/linear_model/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sklearn/linear_model/__init__.py b/sklearn/linear_model/__init__.py index a2d1a699dfc2b..6ea0c612e7090 100644 --- a/sklearn/linear_model/__init__.py +++ b/sklearn/linear_model/__init__.py @@ -43,7 +43,6 @@ from ._quantile import QuantileRegressor from ._ransac import RANSACRegressor from ._ridge import Ridge, RidgeClassifier, RidgeClassifierCV, RidgeCV, ridge_regression -from ._sgd_fast import Hinge, ModifiedHuber from ._stochastic_gradient import SGDClassifier, SGDOneClassSVM, SGDRegressor from ._theil_sen import TheilSenRegressor @@ -52,7 +51,6 @@ "BayesianRidge", "ElasticNet", "ElasticNetCV", - "Hinge", "HuberRegressor", "Lars", "LarsCV", @@ -64,7 +62,6 @@ "LinearRegression", "LogisticRegression", "LogisticRegressionCV", - "ModifiedHuber", "MultiTaskElasticNet", "MultiTaskElasticNetCV", "MultiTaskLasso", From 51954fef0977fe4409dd030a1b7f8735e7e6c99d Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sat, 30 Dec 2023 23:35:41 +0100 Subject: [PATCH 07/12] ENH use CyHalfMultinomialLoss in SAGA --- sklearn/_loss/_loss.pxd | 10 ++ sklearn/_loss/_loss.pyx.tp | 185 ++++++++++++++++++-------- sklearn/_loss/tests/test_loss.py | 30 +++++ sklearn/linear_model/_sag_fast.pyx.tp | 18 
++- 4 files changed, 187 insertions(+), 56 deletions(-) diff --git a/sklearn/_loss/_loss.pxd b/sklearn/_loss/_loss.pxd index f38cbe0badc96..ac01b122a0941 100644 --- a/sklearn/_loss/_loss.pxd +++ b/sklearn/_loss/_loss.pxd @@ -89,3 +89,13 @@ cdef class CyExponentialLoss(CyLossFunction): cdef double cy_loss(self, double y_true, double raw_prediction) noexcept nogil cdef double cy_gradient(self, double y_true, double raw_prediction) noexcept nogil cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) noexcept nogil + + +cdef class CyHalfMultinomialLoss(): + cdef void cy_gradient( + self, + const floating_in y_true, + const floating_in[::1] raw_prediction, + const floating_in sample_weight, + floating_out[::1] gradient_out, + ) noexcept nogil diff --git a/sklearn/_loss/_loss.pyx.tp b/sklearn/_loss/_loss.pyx.tp index 0ce653de84310..715551b5b570c 100644 --- a/sklearn/_loss/_loss.pyx.tp +++ b/sklearn/_loss/_loss.pyx.tp @@ -266,20 +266,21 @@ cdef inline double log1pexp(double x) noexcept nogil: return x -cdef inline void sum_exp_minus_max( +cdef inline double_pair sum_exp_minus_max( const int i, const floating_in[:, :] raw_prediction, # IN - floating_in *p # OUT + floating_out *p # OUT ) noexcept nogil: - # Thread local buffers are used to stores results of this function via p. + # Thread local buffers are used to store part of the results via p. # The results are stored as follows: # p[k] = exp(raw_prediction_i_k - max_value) for k = 0 to n_classes-1 - # p[-2] = max(raw_prediction_i_k, k = 0 to n_classes-1) - # p[-1] = sum(p[k], k = 0 to n_classes-1) = sum of exponentials - # len(p) must be n_classes + 2 + # return.val1 = max_value = max(raw_prediction_i_k, k = 0 to n_classes-1) + # return.val2 = sum_exps = sum(p[k], k = 0 to n_classes-1) = sum of exponentials + # len(p) must be n_classes # Notes: - # - Using "by reference" arguments doesn't work well, therefore we use a - # longer p, see https://github.com/cython/cython/issues/1863 + # - Using "by reference" arguments doesn't work well, see + # https://github.com/cython/cython/issues/1863 + # Therefore we return a double_pair and also store in p. # - i needs to be passed (and stays constant) because otherwise Cython does # not generate optimal code, see # https://github.com/scikit-learn/scikit-learn/issues/17299 @@ -288,19 +289,20 @@ cdef inline void sum_exp_minus_max( cdef: int k int n_classes = raw_prediction.shape[1] - double max_value = raw_prediction[i, 0] - double sum_exps = 0 + double_pair max_value_and_sum_exps # va1 = max_value, val2 = sum_exps + + max_value_and_sum_exps.val1 = raw_prediction[i, 0] + max_value_and_sum_exps.val2 = 0 for k in range(1, n_classes): # Compute max value of array for numerical stability - if max_value < raw_prediction[i, k]: - max_value = raw_prediction[i, k] + if max_value_and_sum_exps.val1 < raw_prediction[i, k]: + max_value_and_sum_exps.val1 = raw_prediction[i, k] for k in range(n_classes): - p[k] = exp(raw_prediction[i, k] - max_value) - sum_exps += p[k] + p[k] = exp(raw_prediction[i, k] - max_value_and_sum_exps.val1) + max_value_and_sum_exps.val2 += p[k] - p[n_classes] = max_value # same as p[-2] - p[n_classes + 1] = sum_exps # same as p[-1] + return max_value_and_sum_exps # ------------------------------------- @@ -1116,8 +1118,10 @@ cdef class {{name}}(CyLossFunction): # The multinomial deviance loss is also known as categorical cross-entropy or -# multinomial log-likelihood -cdef class CyHalfMultinomialLoss(CyLossFunction): +# multinomial log-likelihood. 
+# Here, we do not inherit from CyLossFunction as it's cy_gradient method deviates +# from the API. +cdef class CyHalfMultinomialLoss(): """Half Multinomial deviance loss with multinomial logit link. Domain: @@ -1131,6 +1135,78 @@ cdef class CyHalfMultinomialLoss(CyLossFunction): mapped to (y_true == k) for k = 0 .. n_classes - 1 which is either 0 or 1. """ + # Here we deviate from the CyLossFunction API. SAG/SAGA needs direct access to + # sample-wise gradients which we provide here. + cdef inline void cy_gradient( + self, + const floating_in y_true, + const floating_in[::1] raw_prediction, # IN + const floating_in sample_weight, + floating_out[::1] gradient_out, # OUT + ) noexcept nogil: + """Compute gradient of loss w.r.t. raw_prediction for a single sample. + + The gradient of the multinomial logistic loss with respect to a class k, + and for one sample is: + grad_k = - sw * (p[k] - (y==k)) + + where: + p[k] = proba[k] = exp(raw_prediction[k] - logsumexp(raw_prediction)) + sw = sample_weight + + Parameters + ---------- + y_true : double + Observed, true target value. + raw_prediction : array of shape (n_classes,) + Raw prediction values (in link space). + sample_weight : double + Sample weight. + gradient_out : array of shape (n_classs,) + A location into which the gradient is stored. + + Returns + ------- + gradient : double + The derivative of the loss function w.r.t. `raw_prediction`. + """ + cdef: + int k + int n_classes = raw_prediction.shape[0] + double_pair max_value_and_sum_exps + const floating_in[:, :] raw = raw_prediction[None, :] + + max_value_and_sum_exps = sum_exp_minus_max(0, raw, &gradient_out[0]) + for k in range(n_classes): + # gradient_out[k] = p_k = y_pred_k = prob of class k + gradient_out[k] /= max_value_and_sum_exps.val2 + # gradient_k = (p_k - (y_true == k)) * sw + gradient_out[k] = (gradient_out[k] - (y_true == k)) * sample_weight + + def _test_cy_gradient( + self, + const floating_in[::1] y_true, # IN + const floating_in[:, ::1] raw_prediction, # IN + const floating_in[::1] sample_weight, # IN + ): + """For testing only.""" + cdef: + int i, k + int n_samples = y_true.shape[0] + int n_classes = raw_prediction.shape[1] + floating_in [:, ::1] gradient_out + gradient = np.empty((n_samples, n_classes), dtype=np.float64) + gradient_out = gradient + + for i in range(n_samples): + self.cy_gradient( + y_true=y_true[i], + raw_prediction=raw_prediction[i, :], + sample_weight=1.0 if sample_weight is None else sample_weight[i], + gradient_out=gradient_out[i, :], + ) + return gradient + # Note that we do not assume memory alignment/contiguity of 2d arrays. # There seems to be little benefit in doing so. Benchmarks proofing the # opposite are welcome. @@ -1148,6 +1224,7 @@ cdef class CyHalfMultinomialLoss(CyLossFunction): int n_classes = raw_prediction.shape[1] floating_in max_value, sum_exps floating_in* p # temporary buffer + double_pair max_value_and_sum_exps # We assume n_samples > n_classes. In this case having the inner loop # over n_classes is a good default. @@ -1159,12 +1236,12 @@ cdef class CyHalfMultinomialLoss(CyLossFunction): with nogil, parallel(num_threads=n_threads): # Define private buffer variables as each thread might use its # own. 
- p = malloc(sizeof(floating_in) * (n_classes + 2)) + p = malloc(sizeof(floating_in) * (n_classes)) for i in prange(n_samples, schedule='static'): - sum_exp_minus_max(i, raw_prediction, p) - max_value = p[n_classes] # p[-2] - sum_exps = p[n_classes + 1] # p[-1] + max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p) + max_value = max_value_and_sum_exps.val1 + sum_exps = max_value_and_sum_exps.val2 loss_out[i] = log(sum_exps) + max_value for k in range(n_classes): @@ -1175,12 +1252,12 @@ cdef class CyHalfMultinomialLoss(CyLossFunction): free(p) else: with nogil, parallel(num_threads=n_threads): - p = malloc(sizeof(floating_in) * (n_classes + 2)) + p = malloc(sizeof(floating_in) * (n_classes)) for i in prange(n_samples, schedule='static'): - sum_exp_minus_max(i, raw_prediction, p) - max_value = p[n_classes] # p[-2] - sum_exps = p[n_classes + 1] # p[-1] + max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p) + max_value = max_value_and_sum_exps.val1 + sum_exps = max_value_and_sum_exps.val2 loss_out[i] = log(sum_exps) + max_value for k in range(n_classes): @@ -1207,18 +1284,19 @@ cdef class CyHalfMultinomialLoss(CyLossFunction): int n_classes = raw_prediction.shape[1] floating_in max_value, sum_exps floating_in* p # temporary buffer + double_pair max_value_and_sum_exps if sample_weight is None: # inner loop over n_classes with nogil, parallel(num_threads=n_threads): # Define private buffer variables as each thread might use its # own. - p = malloc(sizeof(floating_in) * (n_classes + 2)) + p = malloc(sizeof(floating_in) * (n_classes)) for i in prange(n_samples, schedule='static'): - sum_exp_minus_max(i, raw_prediction, p) - max_value = p[n_classes] # p[-2] - sum_exps = p[n_classes + 1] # p[-1] + max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p) + max_value = max_value_and_sum_exps.val1 + sum_exps = max_value_and_sum_exps.val2 loss_out[i] = log(sum_exps) + max_value for k in range(n_classes): @@ -1232,12 +1310,12 @@ cdef class CyHalfMultinomialLoss(CyLossFunction): free(p) else: with nogil, parallel(num_threads=n_threads): - p = malloc(sizeof(floating_in) * (n_classes + 2)) + p = malloc(sizeof(floating_in) * (n_classes)) for i in prange(n_samples, schedule='static'): - sum_exp_minus_max(i, raw_prediction, p) - max_value = p[n_classes] # p[-2] - sum_exps = p[n_classes + 1] # p[-1] + max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p) + max_value = max_value_and_sum_exps.val1 + sum_exps = max_value_and_sum_exps.val2 loss_out[i] = log(sum_exps) + max_value for k in range(n_classes): @@ -1266,17 +1344,18 @@ cdef class CyHalfMultinomialLoss(CyLossFunction): int n_classes = raw_prediction.shape[1] floating_in sum_exps floating_in* p # temporary buffer + double_pair max_value_and_sum_exps if sample_weight is None: # inner loop over n_classes with nogil, parallel(num_threads=n_threads): # Define private buffer variables as each thread might use its # own. 
- p = malloc(sizeof(floating_in) * (n_classes + 2)) + p = malloc(sizeof(floating_in) * (n_classes)) for i in prange(n_samples, schedule='static'): - sum_exp_minus_max(i, raw_prediction, p) - sum_exps = p[n_classes + 1] # p[-1] + max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p) + sum_exps = max_value_and_sum_exps.val2 for k in range(n_classes): p[k] /= sum_exps # p_k = y_pred_k = prob of class k @@ -1286,11 +1365,11 @@ cdef class CyHalfMultinomialLoss(CyLossFunction): free(p) else: with nogil, parallel(num_threads=n_threads): - p = malloc(sizeof(floating_in) * (n_classes + 2)) + p = malloc(sizeof(floating_in) * (n_classes)) for i in prange(n_samples, schedule='static'): - sum_exp_minus_max(i, raw_prediction, p) - sum_exps = p[n_classes + 1] # p[-1] + max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p) + sum_exps = max_value_and_sum_exps.val2 for k in range(n_classes): p[k] /= sum_exps # p_k = y_pred_k = prob of class k @@ -1314,17 +1393,18 @@ cdef class CyHalfMultinomialLoss(CyLossFunction): int n_classes = raw_prediction.shape[1] floating_in sum_exps floating_in* p # temporary buffer + double_pair max_value_and_sum_exps if sample_weight is None: # inner loop over n_classes with nogil, parallel(num_threads=n_threads): # Define private buffer variables as each thread might use its # own. - p = malloc(sizeof(floating_in) * (n_classes + 2)) + p = malloc(sizeof(floating_in) * (n_classes)) for i in prange(n_samples, schedule='static'): - sum_exp_minus_max(i, raw_prediction, p) - sum_exps = p[n_classes + 1] # p[-1] + max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p) + sum_exps = max_value_and_sum_exps.val2 for k in range(n_classes): p[k] /= sum_exps # p_k = y_pred_k = prob of class k @@ -1336,11 +1416,11 @@ cdef class CyHalfMultinomialLoss(CyLossFunction): free(p) else: with nogil, parallel(num_threads=n_threads): - p = malloc(sizeof(floating_in) * (n_classes + 2)) + p = malloc(sizeof(floating_in) * (n_classes)) for i in prange(n_samples, schedule='static'): - sum_exp_minus_max(i, raw_prediction, p) - sum_exps = p[n_classes + 1] # p[-1] + max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p) + sum_exps = max_value_and_sum_exps.val2 for k in range(n_classes): p[k] /= sum_exps # p_k = y_pred_k = prob of class k @@ -1369,17 +1449,18 @@ cdef class CyHalfMultinomialLoss(CyLossFunction): int n_classes = raw_prediction.shape[1] floating_in sum_exps floating_in* p # temporary buffer + double_pair max_value_and_sum_exps if sample_weight is None: # inner loop over n_classes with nogil, parallel(num_threads=n_threads): # Define private buffer variables as each thread might use its # own. 
- p = malloc(sizeof(floating_in) * (n_classes + 2)) + p = malloc(sizeof(floating_in) * (n_classes)) for i in prange(n_samples, schedule='static'): - sum_exp_minus_max(i, raw_prediction, p) - sum_exps = p[n_classes + 1] # p[-1] + max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p) + sum_exps = max_value_and_sum_exps.val2 for k in range(n_classes): proba_out[i, k] = p[k] / sum_exps # y_pred_k = prob of class k @@ -1389,11 +1470,11 @@ cdef class CyHalfMultinomialLoss(CyLossFunction): free(p) else: with nogil, parallel(num_threads=n_threads): - p = malloc(sizeof(floating_in) * (n_classes + 2)) + p = malloc(sizeof(floating_in) * (n_classes)) for i in prange(n_samples, schedule='static'): - sum_exp_minus_max(i, raw_prediction, p) - sum_exps = p[n_classes + 1] # p[-1] + max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p) + sum_exps = max_value_and_sum_exps.val2 for k in range(n_classes): proba_out[i, k] = p[k] / sum_exps # y_pred_k = prob of class k diff --git a/sklearn/_loss/tests/test_loss.py b/sklearn/_loss/tests/test_loss.py index b3cea8eb9f452..ea3931b00c883 100644 --- a/sklearn/_loss/tests/test_loss.py +++ b/sklearn/_loss/tests/test_loss.py @@ -985,6 +985,36 @@ def test_multinomial_loss_fit_intercept_only(): assert_all_finite(baseline_prediction) +def test_multinomial_cy_gradient(global_random_seed): + """Test that Multinomial cy_gradient gives the same as gradient. + + CyHalfMultinomialLoss does not inherit from CyLossFunction and has a different API. + As a consequence, the functions like `loss` and `gradient` do no rely on `cy_loss` + and cy_gradient. + """ + n_samples = 100 + n_classes = 5 + loss = HalfMultinomialLoss(n_classes=n_classes) + y_true, raw_prediction = random_y_true_raw_prediction( + loss=loss, + n_samples=n_samples, + seed=global_random_seed, + ) + sample_weight = np.linspace(0.1, 2, num=n_samples) + + grad1 = loss.closs._test_cy_gradient( + y_true=y_true, + raw_prediction=raw_prediction, # needs to be C-contiguous + sample_weight=sample_weight, + ) + grad2 = loss.gradient( + y_true=y_true, + raw_prediction=raw_prediction, + sample_weight=sample_weight, + ) + assert_allclose(grad1, grad2) + + def test_binomial_and_multinomial_loss(global_random_seed): """Test that multinomial loss with n_classes = 2 is the same as binomial loss.""" rng = np.random.RandomState(global_random_seed) diff --git a/sklearn/linear_model/_sag_fast.pyx.tp b/sklearn/linear_model/_sag_fast.pyx.tp index 2abb8e5f3b9ae..4cbd208f4e778 100644 --- a/sklearn/linear_model/_sag_fast.pyx.tp +++ b/sklearn/linear_model/_sag_fast.pyx.tp @@ -31,7 +31,12 @@ from libc.math cimport fabs, exp, log from libc.time cimport time, time_t from libc.stdio cimport printf -from .._loss._loss cimport CyLossFunction, CyHalfSquaredError, CyHalfBinomialLoss +from .._loss._loss cimport ( + CyLossFunction, + CyHalfBinomialLoss, + CyHalfMultinomialLoss, + CyHalfSquaredError, +) from ..utils._seq_dataset cimport SequentialDataset32, SequentialDataset64 @@ -340,11 +345,11 @@ def sag{{name_suffix}}( # Whether the loss function is multinomial cdef bint multinomial = False # Multinomial loss function - cdef MultinomialLogLoss{{name_suffix}} multiloss + cdef CyHalfMultinomialLoss multiloss if loss_function == "multinomial": multinomial = True - multiloss = MultinomialLogLoss{{name_suffix}}() + multiloss = CyHalfMultinomialLoss() elif loss_function == "log": loss = CyHalfBinomialLoss() elif loss_function == "squared": @@ -411,7 +416,12 @@ def sag{{name_suffix}}( # compute the gradient for this sample, given the 
prediction if multinomial: - multiloss.cy_gradient(y, &prediction[0], n_classes, sample_weight, &gradient[0]) + multiloss.cy_gradient( + y_true=y, + raw_prediction=prediction, + sample_weight=sample_weight, + gradient_out=gradient, + ) else: gradient[0] = loss.cy_gradient(y, prediction[0]) * sample_weight From 0e3db396d8b35389b6d0c2445171f2ff01327a13 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sun, 31 Dec 2023 09:57:33 +0100 Subject: [PATCH 08/12] MNT remove SAGA MultinomialLogLoss --- sklearn/linear_model/_sag_fast.pyx.tp | 202 ------------------------- sklearn/linear_model/tests/test_sag.py | 86 +---------- 2 files changed, 1 insertion(+), 287 deletions(-) diff --git a/sklearn/linear_model/_sag_fast.pyx.tp b/sklearn/linear_model/_sag_fast.pyx.tp index 4cbd208f4e778..a6e8ee6ef169a 100644 --- a/sklearn/linear_model/_sag_fast.pyx.tp +++ b/sklearn/linear_model/_sag_fast.pyx.tp @@ -57,137 +57,6 @@ cdef inline {{c_type}} fmax{{name_suffix}}({{c_type}} x, {{c_type}} y) noexcept {{endfor}} - -{{for name_suffix, c_type, np_type in dtypes}} - -cdef {{c_type}} _logsumexp{{name_suffix}}({{c_type}}* arr, int n_classes) noexcept nogil: - """Computes the sum of arr assuming arr is in the log domain. - - Returns log(sum(exp(arr))) while minimizing the possibility of - over/underflow. - """ - # Use the max to normalize, as with the log this is what accumulates - # the less errors - cdef {{c_type}} vmax = arr[0] - cdef {{c_type}} out = 0.0 - cdef int i - - for i in range(1, n_classes): - if vmax < arr[i]: - vmax = arr[i] - - for i in range(n_classes): - out += exp(arr[i] - vmax) - - return log(out) + vmax - -{{endfor}} - - -{{for name_suffix, c_type, np_type in dtypes}} - -cdef class MultinomialLogLoss{{name_suffix}}: - cdef {{c_type}} cy_loss(self, {{c_type}} y, {{c_type}}* prediction, int n_classes, - {{c_type}} sample_weight) noexcept nogil: - r"""Multinomial Logistic regression loss. - - The multinomial logistic loss for one sample is: - loss = - sw \sum_c \delta_{y,c} (prediction[c] - logsumexp(prediction)) - = sw (logsumexp(prediction) - prediction[y]) - - where: - prediction = dot(x_sample, weights) + intercept - \delta_{y,c} = 1 if (y == c) else 0 - sw = sample_weight - - Parameters - ---------- - y : {{c_type}}, between 0 and n_classes - 1 - Indice of the correct class for current sample (i.e. label encoded). - - prediction : pointer to a np.ndarray[{{c_type}}] of shape (n_classes,) - Prediction of the multinomial classifier, for current sample. - - n_classes : integer - Total number of classes. - - sample_weight : {{c_type}} - Weight of current sample. - - Returns - ------- - loss : {{c_type}} - Multinomial loss for current sample. - - Reference - --------- - Bishop, C. M. (2006). Pattern recognition and machine learning. - Springer. (Chapter 4.3.4) - """ - cdef {{c_type}} logsumexp_prediction = _logsumexp{{name_suffix}}(prediction, n_classes) - cdef {{c_type}} loss - - # y is the indice of the correct class of current sample. - loss = (logsumexp_prediction - prediction[int(y)]) * sample_weight - return loss - - cdef void cy_gradient(self, {{c_type}} y, {{c_type}}* prediction, int n_classes, - {{c_type}} sample_weight, {{c_type}}* gradient_ptr) noexcept nogil: - r"""Multinomial Logistic regression gradient of the loss. 
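The loss formula in the docstring being removed here, loss = sw * (logsumexp(prediction) - prediction[y]), is easy to check on its own; the same quantity is now supplied by CyHalfMultinomialLoss from the common loss module. A standalone sketch with made-up numbers:

    import numpy as np
    from scipy.special import logsumexp

    prediction = np.array([0.1, 1.2, -0.3])  # raw scores for one sample
    y = 1                                     # label-encoded true class
    sample_weight = 0.8

    loss = sample_weight * (logsumexp(prediction) - prediction[y])

    # equivalently: minus the weighted log of the softmax probability of the true class
    proba = np.exp(prediction - logsumexp(prediction))
    np.testing.assert_allclose(loss, -sample_weight * np.log(proba[y]))
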
- - The gradient of the multinomial logistic loss with respect to a class c, - and for one sample is: - grad_c = - sw * (p[c] - \delta_{y,c}) - - where: - p[c] = exp(logsumexp(prediction) - prediction[c]) - prediction = dot(sample, weights) + intercept - \delta_{y,c} = 1 if (y == c) else 0 - sw = sample_weight - - Note that to obtain the true gradient, this value has to be multiplied - by the sample vector x. - - Parameters - ---------- - prediction : pointer to a np.ndarray[{{c_type}}] of shape (n_classes,) - Prediction of the multinomial classifier, for current sample. - - y : {{c_type}}, between 0 and n_classes - 1 - Indice of the correct class for current sample (i.e. label encoded) - - n_classes : integer - Total number of classes. - - sample_weight : {{c_type}} - Weight of current sample. - - gradient_ptr : pointer to a np.ndarray[{{c_type}}] of shape (n_classes,) - Gradient vector to be filled. - - Reference - --------- - Bishop, C. M. (2006). Pattern recognition and machine learning. - Springer. (Chapter 4.3.4) - """ - cdef {{c_type}} logsumexp_prediction = _logsumexp{{name_suffix}}(prediction, n_classes) - cdef int class_ind - - for class_ind in range(n_classes): - gradient_ptr[class_ind] = exp(prediction[class_ind] - - logsumexp_prediction) - - # y is the indice of the correct class of current sample. - if class_ind == y: - gradient_ptr[class_ind] -= 1.0 - - gradient_ptr[class_ind] *= sample_weight - - def __reduce__(self): - return MultinomialLogLoss{{name_suffix}}, () - -{{endfor}} - {{for name_suffix, c_type, np_type in dtypes}} cdef inline {{c_type}} _soft_thresholding{{name_suffix}}({{c_type}} x, {{c_type}} shrinkage) noexcept nogil: @@ -784,74 +653,3 @@ cdef void predict_sample{{name_suffix}}( {{endfor}} - - -def _multinomial_grad_loss_all_samples( - SequentialDataset64 dataset, - double[:, ::1] weights_array, - double[::1] intercept_array, - int n_samples, - int n_features, - int n_classes -): - """Compute multinomial gradient and loss across all samples. - - Used for testing purpose only. 
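The docstring's note that the per-class value still has to be multiplied by the sample vector x corresponds, over a full dataset, to the accumulation performed by _multinomial_grad_loss_all_samples, whose removal continues below. A dense NumPy equivalent (an illustrative sketch; Y_onehot and sw are assumed names, not identifiers from the codebase):

    import numpy as np

    def multinomial_full_gradient(X, Y_onehot, weights, intercept, sw):
        raw = X @ weights + intercept                    # (n_samples, n_classes)
        P = np.exp(raw - raw.max(axis=1, keepdims=True))
        P /= P.sum(axis=1, keepdims=True)                # row-wise softmax
        # sum over samples of x_i outer (p_i - onehot_i), weighted by sw_i
        return X.T @ (sw[:, None] * (P - Y_onehot))      # (n_features, n_classes)

    rng = np.random.default_rng(0)
    X = rng.normal(size=(5, 3))
    Y_onehot = np.eye(4)[rng.integers(0, 4, size=5)]
    G = multinomial_full_gradient(X, Y_onehot, rng.normal(size=(3, 4)), np.zeros(4), np.ones(5))
    assert G.shape == (3, 4)                             # (n_features, n_classes)
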
- """ - cdef double *x_data_ptr = NULL - cdef int *x_ind_ptr = NULL - cdef int xnnz = -1 - cdef double y - cdef double sample_weight - - cdef double wscale = 1.0 - cdef int i, j, class_ind, feature_ind - cdef double val - cdef double sum_loss = 0.0 - - cdef MultinomialLogLoss64 multiloss = MultinomialLogLoss64() - - cdef double[:, ::1] sum_gradient_array = np.zeros((n_features, n_classes), dtype=np.double, order="c") - cdef double* sum_gradient = &sum_gradient_array[0, 0] - - cdef double[::1] prediction = np.zeros(n_classes, dtype=np.double, order="c") - - cdef double[::1] gradient = np.zeros(n_classes, dtype=np.double, order="c") - - with nogil: - for i in range(n_samples): - # get next sample on the dataset - dataset.next( - &x_data_ptr, - &x_ind_ptr, - &xnnz, - &y, - &sample_weight - ) - - # prediction of the multinomial classifier for the sample - predict_sample64( - x_data_ptr, - x_ind_ptr, - xnnz, - &weights_array[0, 0], - wscale, - &intercept_array[0], - &prediction[0], - n_classes - ) - - # compute the gradient for this sample, given the prediction - multiloss.cy_gradient(y, &prediction[0], n_classes, sample_weight, &gradient[0]) - - # compute the loss for this sample, given the prediction - sum_loss += multiloss.cy_loss(y, &prediction[0], n_classes, sample_weight) - - # update the sum of the gradient - for j in range(xnnz): - feature_ind = x_ind_ptr[j] - val = x_data_ptr[j] - for class_ind in range(n_classes): - sum_gradient[feature_ind * n_classes + class_ind] += gradient[class_ind] * val - - return sum_loss, sum_gradient_array diff --git a/sklearn/linear_model/tests/test_sag.py b/sklearn/linear_model/tests/test_sag.py index b1f675a54196e..995f7982bdbc2 100644 --- a/sklearn/linear_model/tests/test_sag.py +++ b/sklearn/linear_model/tests/test_sag.py @@ -8,17 +8,12 @@ import numpy as np import pytest -from scipy.special import logsumexp -from sklearn._loss.loss import HalfMultinomialLoss from sklearn.base import clone from sklearn.datasets import load_iris, make_blobs, make_classification from sklearn.linear_model import LogisticRegression, Ridge -from sklearn.linear_model._base import make_dataset -from sklearn.linear_model._linear_loss import LinearModelLoss from sklearn.linear_model._sag import get_auto_step_size -from sklearn.linear_model._sag_fast import _multinomial_grad_loss_all_samples -from sklearn.preprocessing import LabelBinarizer, LabelEncoder +from sklearn.preprocessing import LabelEncoder from sklearn.utils import check_random_state, compute_class_weight from sklearn.utils._testing import ( assert_allclose, @@ -926,85 +921,6 @@ def test_step_size_alpha_error(): clf2.fit(X, y) -def test_multinomial_loss(): - # test if the multinomial loss and gradient computations are consistent - X, y = iris.data, iris.target.astype(np.float64) - n_samples, n_features = X.shape - n_classes = len(np.unique(y)) - - rng = check_random_state(42) - weights = rng.randn(n_features, n_classes) - intercept = rng.randn(n_classes) - sample_weights = np.abs(rng.randn(n_samples)) - - # compute loss and gradient like in multinomial SAG - dataset, _ = make_dataset(X, y, sample_weights, random_state=42) - loss_1, grad_1 = _multinomial_grad_loss_all_samples( - dataset, weights, intercept, n_samples, n_features, n_classes - ) - # compute loss and gradient like in multinomial LogisticRegression - loss = LinearModelLoss( - base_loss=HalfMultinomialLoss(n_classes=n_classes), - fit_intercept=True, - ) - weights_intercept = np.vstack((weights, intercept)).T - loss_2, grad_2 = loss.loss_gradient( - 
weights_intercept, X, y, l2_reg_strength=0.0, sample_weight=sample_weights - ) - grad_2 = grad_2[:, :-1].T - # convert to same convention, i.e. LinearModelLoss uses average(loss, weight=sw) - loss_2 *= np.sum(sample_weights) - grad_2 *= np.sum(sample_weights) - - # comparison - assert_array_almost_equal(grad_1, grad_2) - assert_almost_equal(loss_1, loss_2) - - -def test_multinomial_loss_ground_truth(): - # n_samples, n_features, n_classes = 4, 2, 3 - n_classes = 3 - X = np.array([[1.1, 2.2], [2.2, -4.4], [3.3, -2.2], [1.1, 1.1]]) - y = np.array([0, 1, 2, 0], dtype=np.float64) - lbin = LabelBinarizer() - Y_bin = lbin.fit_transform(y) - - weights = np.array([[0.1, 0.2, 0.3], [1.1, 1.2, -1.3]]) - intercept = np.array([1.0, 0, -0.2]) - sample_weights = np.array([0.8, 1, 1, 0.8]) - - prediction = np.dot(X, weights) + intercept - logsumexp_prediction = logsumexp(prediction, axis=1) - p = prediction - logsumexp_prediction[:, np.newaxis] - loss_1 = -(sample_weights[:, np.newaxis] * p * Y_bin).sum() - diff = sample_weights[:, np.newaxis] * (np.exp(p) - Y_bin) - grad_1 = np.dot(X.T, diff) - - loss = LinearModelLoss( - base_loss=HalfMultinomialLoss(n_classes=n_classes), - fit_intercept=True, - ) - weights_intercept = np.vstack((weights, intercept)).T - loss_2, grad_2 = loss.loss_gradient( - weights_intercept, X, y, l2_reg_strength=0.0, sample_weight=sample_weights - ) - grad_2 = grad_2[:, :-1].T - # convert to same convention, i.e. LinearModelLoss uses average(loss, weight=sw) - loss_2 *= np.sum(sample_weights) - grad_2 *= np.sum(sample_weights) - - assert_almost_equal(loss_1, loss_2) - assert_array_almost_equal(grad_1, grad_2) - - # ground truth - loss_gt = 11.680360354325961 - grad_gt = np.array( - [[-0.557487, -1.619151, +2.176638], [-0.903942, +5.258745, -4.354803]] - ) - assert_almost_equal(loss_1, loss_gt) - assert_array_almost_equal(grad_1, grad_gt) - - @pytest.mark.parametrize("solver", ["sag", "saga"]) def test_sag_classifier_raises_error(solver): # Following #13316, the error handling behavior changed in cython sag. This From dfc84f46892ebab6f6ff78b52a027c6f9c457516 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Wed, 24 Jul 2024 18:00:36 +0500 Subject: [PATCH 09/12] Update sklearn/_loss/_loss.pyx.tp --- sklearn/_loss/_loss.pyx.tp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/_loss/_loss.pyx.tp b/sklearn/_loss/_loss.pyx.tp index 2baa06aa71592..799f206b5dcce 100644 --- a/sklearn/_loss/_loss.pyx.tp +++ b/sklearn/_loss/_loss.pyx.tp @@ -289,7 +289,7 @@ cdef inline double_pair sum_exp_minus_max( cdef: int k int n_classes = raw_prediction.shape[1] - double_pair max_value_and_sum_exps # va1 = max_value, val2 = sum_exps + double_pair max_value_and_sum_exps # val1 = max_value, val2 = sum_exps max_value_and_sum_exps.val1 = raw_prediction[i, 0] max_value_and_sum_exps.val2 = 0 From 59b1df63aad6ced832daecb34899b5705a692a20 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Wed, 24 Jul 2024 18:00:43 +0500 Subject: [PATCH 10/12] Update sklearn/_loss/_loss.pyx.tp --- sklearn/_loss/_loss.pyx.tp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/_loss/_loss.pyx.tp b/sklearn/_loss/_loss.pyx.tp index 799f206b5dcce..581c7c78a34d9 100644 --- a/sklearn/_loss/_loss.pyx.tp +++ b/sklearn/_loss/_loss.pyx.tp @@ -1161,7 +1161,7 @@ cdef class CyHalfMultinomialLoss(): const floating_in sample_weight, floating_out[::1] gradient_out, # OUT ) noexcept nogil: - """Compute gradient of loss w.r.t. raw_prediction for a single sample. + """Compute gradient of loss w.r.t. 
`raw_prediction` for a single sample. The gradient of the multinomial logistic loss with respect to a class k, and for one sample is: From 0d4fbe5c2e4907c64d3cfaedf9716cce9af21393 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Wed, 24 Jul 2024 18:00:49 +0500 Subject: [PATCH 11/12] Update sklearn/_loss/tests/test_loss.py --- sklearn/_loss/tests/test_loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/_loss/tests/test_loss.py b/sklearn/_loss/tests/test_loss.py index 0a62bd4c382cf..403274115fdd9 100644 --- a/sklearn/_loss/tests/test_loss.py +++ b/sklearn/_loss/tests/test_loss.py @@ -1069,11 +1069,11 @@ def test_multinomial_loss_fit_intercept_only(): def test_multinomial_cy_gradient(global_random_seed): - """Test that Multinomial cy_gradient gives the same as gradient. + """Test that Multinomial cy_gradient gives the same result as gradient. CyHalfMultinomialLoss does not inherit from CyLossFunction and has a different API. - As a consequence, the functions like `loss` and `gradient` do no rely on `cy_loss` - and cy_gradient. + As a consequence, the functions like `loss` and `gradient` do not rely on `cy_loss` + and `cy_gradient`. """ n_samples = 100 n_classes = 5 From 09a6cc2b73590d072281a70f73e446d7771635c4 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Wed, 24 Jul 2024 21:50:21 +0500 Subject: [PATCH 12/12] Address PR suggestions --- sklearn/_loss/_loss.pyx.tp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sklearn/_loss/_loss.pyx.tp b/sklearn/_loss/_loss.pyx.tp index 581c7c78a34d9..56d3aebb6c6f1 100644 --- a/sklearn/_loss/_loss.pyx.tp +++ b/sklearn/_loss/_loss.pyx.tp @@ -278,9 +278,7 @@ cdef inline double_pair sum_exp_minus_max( # return.val2 = sum_exps = sum(p[k], k = 0 to n_classes-1) = sum of exponentials # len(p) must be n_classes # Notes: - # - Using "by reference" arguments doesn't work well, see - # https://github.com/cython/cython/issues/1863 - # Therefore we return a double_pair and also store in p. + # - We return the max value and sum of exps (stored in p) as a double_pair. # - i needs to be passed (and stays constant) because otherwise Cython does # not generate optimal code, see # https://github.com/scikit-learn/scikit-learn/issues/17299 @@ -1136,7 +1134,7 @@ cdef class {{name}}(CyLossFunction): # The multinomial deviance loss is also known as categorical cross-entropy or # multinomial log-likelihood. -# Here, we do not inherit from CyLossFunction as it's cy_gradient method deviates +# Here, we do not inherit from CyLossFunction as its cy_gradient method deviates # from the API. cdef class CyHalfMultinomialLoss(): """Half Multinomial deviance loss with multinomial logit link.
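As background for the max-shift convention documented in the hunk above (subtract the row maximum before exponentiating, add it back after the log), a short NumPy illustration with arbitrary large scores:

    import numpy as np

    raw_row = np.array([800.0, 802.0, 795.0])       # large raw scores
    with np.errstate(over="ignore"):
        naive = np.log(np.sum(np.exp(raw_row)))     # exp(800.) overflows float64 -> inf

    max_value = raw_row.max()
    sum_exps = np.sum(np.exp(raw_row - max_value))  # shifted exponentials stay finite
    stable = np.log(sum_exps) + max_value           # well-defined logsumexp, ~802.13
    assert np.isinf(naive) and np.isfinite(stable)
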