MNT replace Cython loss functions in SGD part 3 (#28037) · scikit-learn/scikit-learn@05c0992 · GitHub
Commit 05c0992

lorentzenchr, OmarManzoor, and Omar Salman authored

MNT replace Cython loss functions in SGD part 3 (#28037)

Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
Co-authored-by: Omar Salman <omar.salman2007@gmail.com>
Co-authored-by: Omar Salman <omar.salman@arbisoft>

1 parent 156ef1b commit 05c0992

File tree

5 files changed: +186 -343 lines changed


sklearn/_loss/_loss.pxd

Lines changed: 10 additions & 0 deletions

@@ -89,3 +89,13 @@ cdef class CyExponentialLoss(CyLossFunction):
     cdef double cy_loss(self, double y_true, double raw_prediction) noexcept nogil
     cdef double cy_gradient(self, double y_true, double raw_prediction) noexcept nogil
     cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) noexcept nogil
+
+
+cdef class CyHalfMultinomialLoss():
+    cdef void cy_gradient(
+        self,
+        const floating_in y_true,
+        const floating_in[::1] raw_prediction,
+        const floating_in sample_weight,
+        floating_out[::1] gradient_out,
+    ) noexcept nogil

sklearn/_loss/_loss.pyx.tp

Lines changed: 131 additions & 52 deletions
@@ -266,20 +266,19 @@ cdef inline double log1pexp(double x) noexcept nogil:
         return x


-cdef inline void sum_exp_minus_max(
+cdef inline double_pair sum_exp_minus_max(
     const int i,
     const floating_in[:, :] raw_prediction,  # IN
-    floating_in *p                           # OUT
+    floating_out *p                          # OUT
 ) noexcept nogil:
-    # Thread local buffers are used to store results of this function via p.
+    # Thread local buffers are used to store part of the results via p.
     # The results are stored as follows:
     #     p[k] = exp(raw_prediction_i_k - max_value) for k = 0 to n_classes-1
-    #     p[-2] = max(raw_prediction_i_k, k = 0 to n_classes-1)
-    #     p[-1] = sum(p[k], k = 0 to n_classes-1) = sum of exponentials
-    # len(p) must be n_classes + 2
+    #     return.val1 = max_value = max(raw_prediction_i_k, k = 0 to n_classes-1)
+    #     return.val2 = sum_exps = sum(p[k], k = 0 to n_classes-1) = sum of exponentials
+    # len(p) must be n_classes
     # Notes:
-    # - Using "by reference" arguments doesn't work well, therefore we use a
-    #   longer p, see https://github.com/cython/cython/issues/1863
+    # - We return the max value and sum of exps (stored in p) as a double_pair.
     # - i needs to be passed (and stays constant) because otherwise Cython does
     #   not generate optimal code, see
     #   https://github.com/scikit-learn/scikit-learn/issues/17299
@@ -288,19 +287,20 @@ cdef inline void sum_exp_minus_max(
     cdef:
         int k
         int n_classes = raw_prediction.shape[1]
-        double max_value = raw_prediction[i, 0]
-        double sum_exps = 0
+        double_pair max_value_and_sum_exps  # val1 = max_value, val2 = sum_exps
+
+    max_value_and_sum_exps.val1 = raw_prediction[i, 0]
+    max_value_and_sum_exps.val2 = 0
     for k in range(1, n_classes):
         # Compute max value of array for numerical stability
-        if max_value < raw_prediction[i, k]:
-            max_value = raw_prediction[i, k]
+        if max_value_and_sum_exps.val1 < raw_prediction[i, k]:
+            max_value_and_sum_exps.val1 = raw_prediction[i, k]

     for k in range(n_classes):
-        p[k] = exp(raw_prediction[i, k] - max_value)
-        sum_exps += p[k]
+        p[k] = exp(raw_prediction[i, k] - max_value_and_sum_exps.val1)
+        max_value_and_sum_exps.val2 += p[k]

-    p[n_classes] = max_value     # same as p[-2]
-    p[n_classes + 1] = sum_exps  # same as p[-1]
+    return max_value_and_sum_exps


 # -------------------------------------
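For reference, a minimal NumPy sketch of what the refactored sum_exp_minus_max now computes for one row (the function name, shapes, and return convention here are illustrative; the Cython version fills the thread-local buffer p in place and returns the max/sum pair as a C double_pair struct):

    import numpy as np

    def sum_exp_minus_max_reference(i, raw_prediction):
        # raw_prediction has shape (n_samples, n_classes); operate on row i only.
        row = raw_prediction[i]
        max_value = row.max()         # val1: row maximum, for numerical stability
        p = np.exp(row - max_value)   # the values the Cython code writes into p[k]
        sum_exps = p.sum()            # val2: sum of the shifted exponentials
        return p, max_value, sum_exps

Before this change, max_value and sum_exps were passed out through the two extra slots p[n_classes] and p[n_classes + 1]; returning them as a double_pair lets p shrink to exactly n_classes entries, which is why the malloc calls in the hunks below drop the "+ 2".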
@@ -1133,8 +1133,10 @@ cdef class {{name}}(CyLossFunction):


 # The multinomial deviance loss is also known as categorical cross-entropy or
-# multinomial log-likelihood
-cdef class CyHalfMultinomialLoss(CyLossFunction):
+# multinomial log-likelihood.
+# Here, we do not inherit from CyLossFunction as its cy_gradient method deviates
+# from the API.
+cdef class CyHalfMultinomialLoss():
     """Half Multinomial deviance loss with multinomial logit link.

     Domain:
@@ -1148,6 +1150,78 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
     mapped to (y_true == k) for k = 0 .. n_classes - 1 which is either 0 or 1.
     """

+    # Here we deviate from the CyLossFunction API. SAG/SAGA needs direct access to
+    # sample-wise gradients which we provide here.
+    cdef inline void cy_gradient(
+        self,
+        const floating_in y_true,
+        const floating_in[::1] raw_prediction,  # IN
+        const floating_in sample_weight,
+        floating_out[::1] gradient_out,         # OUT
+    ) noexcept nogil:
+        """Compute gradient of loss w.r.t. `raw_prediction` for a single sample.
+
+        The gradient of the multinomial logistic loss with respect to a class k,
+        and for one sample is:
+
+            grad_k = - sw * (p[k] - (y==k))
+
+        where:
+            p[k] = proba[k] = exp(raw_prediction[k] - logsumexp(raw_prediction))
+            sw = sample_weight
+
+        Parameters
+        ----------
+        y_true : double
+            Observed, true target value.
+        raw_prediction : array of shape (n_classes,)
+            Raw prediction values (in link space).
+        sample_weight : double
+            Sample weight.
+        gradient_out : array of shape (n_classs,)
+            A location into which the gradient is stored.
+
+        Returns
+        -------
+        gradient : double
+            The derivative of the loss function w.r.t. `raw_prediction`.
+        """
+        cdef:
+            int k
+            int n_classes = raw_prediction.shape[0]
+            double_pair max_value_and_sum_exps
+            const floating_in[:, :] raw = raw_prediction[None, :]
+
+        max_value_and_sum_exps = sum_exp_minus_max(0, raw, &gradient_out[0])
+        for k in range(n_classes):
+            # gradient_out[k] = p_k = y_pred_k = prob of class k
+            gradient_out[k] /= max_value_and_sum_exps.val2
+            # gradient_k = (p_k - (y_true == k)) * sw
+            gradient_out[k] = (gradient_out[k] - (y_true == k)) * sample_weight
+
+    def _test_cy_gradient(
+        self,
+        const floating_in[::1] y_true,              # IN
+        const floating_in[:, ::1] raw_prediction,   # IN
+        const floating_in[::1] sample_weight,       # IN
+    ):
+        """For testing only."""
+        cdef:
+            int i, k
+            int n_samples = y_true.shape[0]
+            int n_classes = raw_prediction.shape[1]
+            floating_in [:, ::1] gradient_out
+        gradient = np.empty((n_samples, n_classes), dtype=np.float64)
+        gradient_out = gradient
+
+        for i in range(n_samples):
+            self.cy_gradient(
+                y_true=y_true[i],
+                raw_prediction=raw_prediction[i, :],
+                sample_weight=1.0 if sample_weight is None else sample_weight[i],
+                gradient_out=gradient_out[i, :],
+            )
+        return gradient
+
     # Note that we do not assume memory alignment/contiguity of 2d arrays.
     # There seems to be little benefit in doing so. Benchmarks proofing the
     # opposite are welcome.
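A rough NumPy equivalent of the new per-sample cy_gradient (the function name and the explicit one-hot comparison are illustrative; the Cython code reuses gradient_out as the probability buffer to avoid an extra allocation):

    import numpy as np

    def multinomial_gradient_reference(y_true, raw_prediction, sample_weight=1.0):
        # raw_prediction: shape (n_classes,); y_true: class label in {0, ..., n_classes - 1}.
        p = np.exp(raw_prediction - raw_prediction.max())
        p /= p.sum()                              # softmax probabilities, p_k
        k = np.arange(raw_prediction.shape[0])
        # Matches the loop body: (gradient_out[k] - (y_true == k)) * sample_weight
        return (p - (y_true == k)) * sample_weight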
@@ -1165,6 +1239,7 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
             int n_classes = raw_prediction.shape[1]
             floating_in max_value, sum_exps
             floating_in* p  # temporary buffer
+            double_pair max_value_and_sum_exps

         # We assume n_samples > n_classes. In this case having the inner loop
         # over n_classes is a good default.
@@ -1176,12 +1251,12 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    max_value = p[n_classes]     # p[-2]
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    max_value = max_value_and_sum_exps.val1
+                    sum_exps = max_value_and_sum_exps.val2
                    loss_out[i] = log(sum_exps) + max_value

                    # label encoded y_true
@@ -1191,12 +1266,12 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    max_value = p[n_classes]     # p[-2]
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    max_value = max_value_and_sum_exps.val1
+                    sum_exps = max_value_and_sum_exps.val2
                    loss_out[i] = log(sum_exps) + max_value

                    # label encoded y_true
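Both loss branches rely on the standard max-shift identity log(sum_k exp(x_k)) = m + log(sum_k exp(x_k - m)) with m = max_k x_k, which is what makes loss_out[i] = log(sum_exps) + max_value numerically stable. A small illustrative check (values chosen so the naive formula would overflow):

    import numpy as np

    x = np.array([1000.0, 1001.0, 999.0])
    m = x.max()
    sum_exps = np.exp(x - m).sum()      # finite, unlike np.exp(x).sum()
    print(np.log(sum_exps) + m)         # ~1001.41, same pattern as loss_out[i]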
@@ -1222,18 +1297,19 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
            int n_classes = raw_prediction.shape[1]
            floating_in max_value, sum_exps
            floating_in* p  # temporary buffer
+            double_pair max_value_and_sum_exps

        if sample_weight is None:
            # inner loop over n_classes
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    max_value = p[n_classes]     # p[-2]
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    max_value = max_value_and_sum_exps.val1
+                    sum_exps = max_value_and_sum_exps.val2
                    loss_out[i] = log(sum_exps) + max_value

                    for k in range(n_classes):
@@ -1247,12 +1323,12 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    max_value = p[n_classes]     # p[-2]
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    max_value = max_value_and_sum_exps.val1
+                    sum_exps = max_value_and_sum_exps.val2
                    loss_out[i] = log(sum_exps) + max_value

                    for k in range(n_classes):
@@ -1281,17 +1357,18 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
            int n_classes = raw_prediction.shape[1]
            floating_in sum_exps
            floating_in* p  # temporary buffer
+            double_pair max_value_and_sum_exps

        if sample_weight is None:
            # inner loop over n_classes
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
@@ -1301,11 +1378,11 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
@@ -1329,17 +1406,18 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
            int n_classes = raw_prediction.shape[1]
            floating_in sum_exps
            floating_in* p  # temporary buffer
+            double_pair max_value_and_sum_exps

        if sample_weight is None:
            # inner loop over n_classes
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
@@ -1351,11 +1429,11 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
@@ -1384,17 +1462,18 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
            int n_classes = raw_prediction.shape[1]
            floating_in sum_exps
            floating_in* p  # temporary buffer
+            double_pair max_value_and_sum_exps

        if sample_weight is None:
            # inner loop over n_classes
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        proba_out[i, k] = p[k] / sum_exps  # y_pred_k = prob of class k
@@ -1404,11 +1483,11 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        proba_out[i, k] = p[k] / sum_exps  # y_pred_k = prob of class k
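In the gradient and probability hunks, p already holds exp(raw_prediction[i, k] - max_value) after sum_exp_minus_max, so dividing by sum_exps yields the softmax probabilities directly; the exp(-max_value) factor cancels. A brief illustration (array values are arbitrary):

    import numpy as np

    raw = np.array([0.2, -1.3, 2.5])
    p = np.exp(raw - raw.max())          # what the buffer p holds for one sample
    proba = p / p.sum()                  # p[k] / sum_exps in the diff
    assert np.allclose(proba, np.exp(raw) / np.exp(raw).sum())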

sklearn/_loss/tests/test_loss.py

Lines changed: 30 additions & 0 deletions

@@ -1068,6 +1068,36 @@ def test_multinomial_loss_fit_intercept_only():
     assert_all_finite(baseline_prediction)


+def test_multinomial_cy_gradient(global_random_seed):
+    """Test that Multinomial cy_gradient gives the same result as gradient.
+
+    CyHalfMultinomialLoss does not inherit from CyLossFunction and has a different API.
+    As a consequence, the functions like `loss` and `gradient` do not rely on `cy_loss`
+    and `cy_gradient`.
+    """
+    n_samples = 100
+    n_classes = 5
+    loss = HalfMultinomialLoss(n_classes=n_classes)
+    y_true, raw_prediction = random_y_true_raw_prediction(
+        loss=loss,
+        n_samples=n_samples,
+        seed=global_random_seed,
+    )
+    sample_weight = np.linspace(0.1, 2, num=n_samples)
+
+    grad1 = loss.closs._test_cy_gradient(
+        y_true=y_true,
+        raw_prediction=raw_prediction,  # needs to be C-contiguous
+        sample_weight=sample_weight,
+    )
+    grad2 = loss.gradient(
+        y_true=y_true,
+        raw_prediction=raw_prediction,
+        sample_weight=sample_weight,
+    )
+    assert_allclose(grad1, grad2)
+
+
 def test_binomial_and_multinomial_loss(global_random_seed):
     """Test that multinomial loss with n_classes = 2 is the same as binomial loss."""
     rng = np.random.RandomState(global_random_seed)
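The new test checks the private Cython helper against the public gradient method. An independent full-batch NumPy reference of the same formula could be sketched as follows (helper name is hypothetical; comparing its output to loss.closs._test_cy_gradient(...) with assert_allclose would mirror what test_multinomial_cy_gradient does with loss.gradient):

    import numpy as np

    def reference_multinomial_gradients(y_true, raw_prediction, sample_weight):
        # Vectorized version of grad_k = (p_k - (y == k)) * sample_weight per sample.
        shifted = raw_prediction - raw_prediction.max(axis=1, keepdims=True)
        proba = np.exp(shifted)
        proba /= proba.sum(axis=1, keepdims=True)
        one_hot = y_true[:, None] == np.arange(raw_prediction.shape[1])[None, :]
        return (proba - one_hot) * sample_weight[:, None]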

0 commit comments