MNT replace Cython loss functions in SGD part 3 by lorentzenchr · Pull Request #28037 · scikit-learn/scikit-learn · GitHub
Merged

Changes from all commits
10 changes: 10 additions & 0 deletions sklearn/_loss/_loss.pxd
@@ -89,3 +89,13 @@ cdef class CyExponentialLoss(CyLossFunction):
cdef double cy_loss(self, double y_true, double raw_prediction) noexcept nogil
cdef double cy_gradient(self, double y_true, double raw_prediction) noexcept nogil
cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) noexcept nogil


cdef class CyHalfMultinomialLoss():
cdef void cy_gradient(
self,
const floating_in y_true,
const floating_in[::1] raw_prediction,
const floating_in sample_weight,
floating_out[::1] gradient_out,
) noexcept nogil
Comment on lines +94 to +101
Member:
It is much nicer to be able to use fused types directly!
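For context, `floating_in` and `floating_out` are Cython fused types, so the one declared signature is compiled for both 32-bit and 64-bit floating point buffers. At the Python level, that is what lets the loss objects operate on either dtype without casting; a rough sketch of such a call (the keyword API matches the test further down in this diff, while float32 support is assumed here from the fused-type declarations):

import numpy as np
from sklearn._loss.loss import HalfMultinomialLoss

loss = HalfMultinomialLoss(n_classes=3)
for dtype in (np.float64, np.float32):
    # Label-encoded targets and raw (link-space) predictions in the given dtype.
    y_true = np.array([0.0, 2.0, 1.0], dtype=dtype)
    raw_prediction = np.zeros((3, 3), dtype=dtype)
    grad = loss.gradient(y_true=y_true, raw_prediction=raw_prediction)
    print(dtype, grad.shape)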

183 changes: 131 additions & 52 deletions sklearn/_loss/_loss.pyx.tp
@@ -266,20 +266,19 @@ cdef inline double log1pexp(double x) noexcept nogil:
return x


cdef inline void sum_exp_minus_max(
cdef inline double_pair sum_exp_minus_max(
const int i,
const floating_in[:, :] raw_prediction, # IN
floating_in *p # OUT
floating_out *p # OUT
) noexcept nogil:
# Thread local buffers are used to store results of this function via p.
# Thread local buffers are used to store part of the results via p.
# The results are stored as follows:
# p[k] = exp(raw_prediction_i_k - max_value) for k = 0 to n_classes-1
# p[-2] = max(raw_prediction_i_k, k = 0 to n_classes-1)
# p[-1] = sum(p[k], k = 0 to n_classes-1) = sum of exponentials
# len(p) must be n_classes + 2
# return.val1 = max_value = max(raw_prediction_i_k, k = 0 to n_classes-1)
# return.val2 = sum_exps = sum(p[k], k = 0 to n_classes-1) = sum of exponentials
# len(p) must be n_classes
# Notes:
# - Using "by reference" arguments doesn't work well, therefore we use a
# longer p, see https://github.com/cython/cython/issues/1863
# - We return the max value and sum of exps (stored in p) as a double_pair.
# - i needs to be passed (and stays constant) because otherwise Cython does
# not generate optimal code, see
# https://github.com/scikit-learn/scikit-learn/issues/17299
@@ -288,19 +287,20 @@ cdef inline void sum_exp_minus_max(
cdef:
int k
int n_classes = raw_prediction.shape[1]
double max_value = raw_prediction[i, 0]
double sum_exps = 0
double_pair max_value_and_sum_exps # val1 = max_value, val2 = sum_exps

max_value_and_sum_exps.val1 = raw_prediction[i, 0]
max_value_and_sum_exps.val2 = 0
for k in range(1, n_classes):
# Compute max value of array for numerical stability
if max_value < raw_prediction[i, k]:
max_value = raw_prediction[i, k]
if max_value_and_sum_exps.val1 < raw_prediction[i, k]:
max_value_and_sum_exps.val1 = raw_prediction[i, k]

for k in range(n_classes):
p[k] = exp(raw_prediction[i, k] - max_value)
sum_exps += p[k]
p[k] = exp(raw_prediction[i, k] - max_value_and_sum_exps.val1)
max_value_and_sum_exps.val2 += p[k]

p[n_classes] = max_value # same as p[-2]
p[n_classes + 1] = sum_exps # same as p[-1]
return max_value_and_sum_exps
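
In plain NumPy, the bookkeeping this helper now returns can be sketched roughly as follows (illustrative only; `stable_softmax_parts` is a hypothetical name):

import numpy as np

def stable_softmax_parts(raw_prediction, i):
    # Mirror of sum_exp_minus_max for row i: subtract the row maximum before
    # exponentiating so exp() cannot overflow, and hand back max and sum.
    row = np.asarray(raw_prediction)[i]
    max_value = row.max()                 # return.val1
    p = np.exp(row - max_value)           # p[k] = exp(raw_i_k - max_value)
    sum_exps = p.sum()                    # return.val2
    return p, max_value, sum_exps

# Callers recover the per-sample quantities from these parts, e.g.
#   loss_i    = log(sum_exps) + max_value - raw_prediction[i, y_i]
#   proba_i_k = p[k] / sum_exps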


# -------------------------------------
@@ -1133,8 +1133,10 @@ cdef class {{name}}(CyLossFunction):


# The multinomial deviance loss is also known as categorical cross-entropy or
# multinomial log-likelihood
cdef class CyHalfMultinomialLoss(CyLossFunction):
# multinomial log-likelihood.
# Here, we do not inherit from CyLossFunction as its cy_gradient method deviates
# from the API.
cdef class CyHalfMultinomialLoss():
"""Half Multinomial deviance loss with multinomial logit link.

Domain:
@@ -1148,6 +1150,78 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
mapped to (y_true == k) for k = 0 .. n_classes - 1 which is either 0 or 1.
"""

# Here we deviate from the CyLossFunction API. SAG/SAGA needs direct access to
# sample-wise gradients which we provide here.
cdef inline void cy_gradient(
self,
const floating_in y_true,
const floating_in[::1] raw_prediction, # IN
const floating_in sample_weight,
floating_out[::1] gradient_out, # OUT
) noexcept nogil:
"""Compute gradient of loss w.r.t. `raw_prediction` for a single sample.

The gradient of the multinomial logistic loss with respect to a class k,
and for one sample is:
grad_k = sw * (p[k] - (y==k))

where:
p[k] = proba[k] = exp(raw_prediction[k] - logsumexp(raw_prediction))
sw = sample_weight

Parameters
----------
y_true : double
Observed, true target value.
raw_prediction : array of shape (n_classes,)
Raw prediction values (in link space).
sample_weight : double
Sample weight.
gradient_out : array of shape (n_classes,)
A location into which the gradient is stored.

Returns
-------
None
The gradient of the loss w.r.t. `raw_prediction` is written into `gradient_out`.
"""
cdef:
int k
int n_classes = raw_prediction.shape[0]
double_pair max_value_and_sum_exps
const floating_in[:, :] raw = raw_prediction[None, :]

max_value_and_sum_exps = sum_exp_minus_max(0, raw, &gradient_out[0])
for k in range(n_classes):
# gradient_out[k] = p_k = y_pred_k = prob of class k
gradient_out[k] /= max_value_and_sum_exps.val2
# gradient_k = (p_k - (y_true == k)) * sw
gradient_out[k] = (gradient_out[k] - (y_true == k)) * sample_weight
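
Written out in NumPy, the per-sample gradient computed above corresponds to roughly the following sketch (`multinomial_gradient_one_sample` is a hypothetical name, for illustration only):

import numpy as np

def multinomial_gradient_one_sample(y_true, raw_prediction, sample_weight=1.0):
    # grad_k = sample_weight * (softmax(raw_prediction)_k - 1{y_true == k}),
    # computed with the same max-subtraction trick as the Cython kernel.
    p = np.exp(raw_prediction - raw_prediction.max())
    p /= p.sum()                          # p[k] = predicted probability of class k
    one_hot = (np.arange(p.shape[0]) == int(y_true)).astype(p.dtype)
    return sample_weight * (p - one_hot)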

def _test_cy_gradient(
self,
const floating_in[::1] y_true, # IN
const floating_in[:, ::1] raw_prediction, # IN
const floating_in[::1] sample_weight, # IN
):
"""For testing only."""
cdef:
int i, k
int n_samples = y_true.shape[0]
int n_classes = raw_prediction.shape[1]
floating_in [:, ::1] gradient_out
gradient = np.empty((n_samples, n_classes), dtype=np.float64)
gradient_out = gradient

for i in range(n_samples):
self.cy_gradient(
y_true=y_true[i],
raw_prediction=raw_prediction[i, :],
sample_weight=1.0 if sample_weight is None else sample_weight[i],
gradient_out=gradient_out[i, :],
)
return gradient

# Note that we do not assume memory alignment/contiguity of 2d arrays.
# There seems to be little benefit in doing so. Benchmarks proving the
# opposite are welcome.
@@ -1165,6 +1239,7 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
int n_classes = raw_prediction.shape[1]
floating_in max_value, sum_exps
floating_in* p # temporary buffer
double_pair max_value_and_sum_exps

# We assume n_samples > n_classes. In this case having the inner loop
# over n_classes is a good default.
@@ -1176,12 +1251,12 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
with nogil, parallel(num_threads=n_threads):
# Define private buffer variables as each thread might use its
# own.
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
max_value = p[n_classes] # p[-2]
sum_exps = p[n_classes + 1] # p[-1]
max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
max_value = max_value_and_sum_exps.val1
sum_exps = max_value_and_sum_exps.val2
loss_out[i] = log(sum_exps) + max_value

# label encoded y_true
@@ -1191,12 +1266,12 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
free(p)
else:
with nogil, parallel(num_threads=n_threads):
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
max_value = p[n_classes] # p[-2]
sum_exps = p[n_classes + 1] # p[-1]
max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
max_value = max_value_and_sum_exps.val1
sum_exps = max_value_and_sum_exps.val2
loss_out[i] = log(sum_exps) + max_value

# label encoded y_true
@@ -1222,18 +1297,19 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
int n_classes = raw_prediction.shape[1]
floating_in max_value, sum_exps
floating_in* p # temporary buffer
double_pair max_value_and_sum_exps

if sample_weight is None:
# inner loop over n_classes
with nogil, parallel(num_threads=n_threads):
# Define private buffer variables as each thread might use its
# own.
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
max_value = p[n_classes] # p[-2]
sum_exps = p[n_classes + 1] # p[-1]
max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
max_value = max_value_and_sum_exps.val1
sum_exps = max_value_and_sum_exps.val2
loss_out[i] = log(sum_exps) + max_value

for k in range(n_classes):
@@ -1247,12 +1323,12 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
free(p)
else:
with nogil, parallel(num_threads=n_threads):
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
max_value = p[n_classes] # p[-2]
sum_exps = p[n_classes + 1] # p[-1]
max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
max_value = max_value_and_sum_exps.val1
sum_exps = max_value_and_sum_exps.val2
loss_out[i] = log(sum_exps) + max_value

for k in range(n_classes):
@@ -1281,17 +1357,18 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
int n_classes = raw_prediction.shape[1]
floating_in sum_exps
floating_in* p # temporary buffer
double_pair max_value_and_sum_exps

if sample_weight is None:
# inner loop over n_classes
with nogil, parallel(num_threads=n_threads):
# Define private buffer variables as each thread might use its
# own.
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
sum_exps = p[n_classes + 1] # p[-1]
max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
sum_exps = max_value_and_sum_exps.val2

for k in range(n_classes):
p[k] /= sum_exps # p_k = y_pred_k = prob of class k
@@ -1301,11 +1378,11 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
free(p)
else:
with nogil, parallel(num_threads=n_threads):
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
sum_exps = p[n_classes + 1] # p[-1]
max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
sum_exps = max_value_and_sum_exps.val2

for k in range(n_classes):
p[k] /= sum_exps # p_k = y_pred_k = prob of class k
@@ -1329,17 +1406,18 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
int n_classes = raw_prediction.shape[1]
floating_in sum_exps
floating_in* p # temporary buffer
double_pair max_value_and_sum_exps

if sample_weight is None:
# inner loop over n_classes
with nogil, parallel(num_threads=n_threads):
# Define private buffer variables as each thread might use its
# own.
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
sum_exps = p[n_classes + 1] # p[-1]
max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
sum_exps = max_value_and_sum_exps.val2

for k in range(n_classes):
p[k] /= sum_exps # p_k = y_pred_k = prob of class k
@@ -1351,11 +1429,11 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
free(p)
else:
with nogil, parallel(num_threads=n_threads):
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
sum_exps = p[n_classes + 1] # p[-1]
max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
sum_exps = max_value_and_sum_exps.val2

for k in range(n_classes):
p[k] /= sum_exps # p_k = y_pred_k = prob of class k
@@ -1384,17 +1462,18 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
int n_classes = raw_prediction.shape[1]
floating_in sum_exps
floating_in* p # temporary buffer
double_pair max_value_and_sum_exps

if sample_weight is None:
# inner loop over n_classes
with nogil, parallel(num_threads=n_threads):
# Define private buffer variables as each thread might use its
# own.
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
sum_exps = p[n_classes + 1] # p[-1]
max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
sum_exps = max_value_and_sum_exps.val2

for k in range(n_classes):
proba_out[i, k] = p[k] / sum_exps # y_pred_k = prob of class k
@@ -1404,11 +1483,11 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
free(p)
else:
with nogil, parallel(num_threads=n_threads):
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
sum_exps = p[n_classes + 1] # p[-1]
max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
sum_exps = max_value_and_sum_exps.val2

for k in range(n_classes):
proba_out[i, k] = p[k] / sum_exps # y_pred_k = prob of class k
30 changes: 30 additions & 0 deletions sklearn/_loss/tests/test_loss.py
@@ -1068,6 +1068,36 @@ def test_multinomial_loss_fit_intercept_only():
assert_all_finite(baseline_prediction)


def test_multinomial_cy_gradient(global_random_seed):
"""Test that Multinomial cy_gradient gives the same result as gradient.

CyHalfMultinomialLoss does not inherit from CyLossFunction and has a different API.
As a consequence, functions like `loss` and `gradient` do not rely on `cy_loss`
and `cy_gradient`.
"""
n_samples = 100
n_classes = 5
loss = HalfMultinomialLoss(n_classes=n_classes)
y_true, raw_prediction = random_y_true_raw_prediction(
loss=loss,
n_samples=n_samples,
seed=global_random_seed,
)
sample_weight = np.linspace(0.1, 2, num=n_samples)

grad1 = loss.closs._test_cy_gradient(
y_true=y_true,
raw_prediction=raw_prediction, # needs to be C-contiguous
sample_weight=sample_weight,
)
grad2 = loss.gradient(
y_true=y_true,
raw_prediction=raw_prediction,
sample_weight=sample_weight,
)
assert_allclose(grad1, grad2)


def test_binomial_and_multinomial_loss(global_random_seed):
"""Test that multinomial loss with n_classes = 2 is the same as binomial loss."""
rng = np.random.RandomState(global_random_seed)