@@ -266,20 +266,19 @@ cdef inline double log1pexp(double x) noexcept nogil:
     return x


-cdef inline void sum_exp_minus_max(
+cdef inline double_pair sum_exp_minus_max(
     const int i,
     const floating_in[:, :] raw_prediction,  # IN
-    floating_in *p  # OUT
+    floating_out *p  # OUT
 ) noexcept nogil:
-    # Thread local buffers are used to store results of this function via p.
+    # Thread local buffers are used to store part of the results via p.
     # The results are stored as follows:
     #     p[k] = exp(raw_prediction_i_k - max_value) for k = 0 to n_classes-1
-    #     p[-2] = max(raw_prediction_i_k, k = 0 to n_classes-1)
-    #     p[-1] = sum(p[k], k = 0 to n_classes-1) = sum of exponentials
-    # len(p) must be n_classes + 2
+    #     return.val1 = max_value = max(raw_prediction_i_k, k = 0 to n_classes-1)
+    #     return.val2 = sum_exps = sum(p[k], k = 0 to n_classes-1) = sum of exponentials
+    # len(p) must be n_classes
     # Notes:
-    # - Using "by reference" arguments doesn't work well, therefore we use a
-    #   longer p, see https://github.com/cython/cython/issues/1863
+    # - We return the max value and the sum of exps (stored in p) as a double_pair.
     # - i needs to be passed (and stays constant) because otherwise Cython does
     #   not generate optimal code, see
     #   https://github.com/scikit-learn/scikit-learn/issues/17299
@@ -288,19 +287,20 @@ cdef inline void sum_exp_minus_max(
     cdef:
         int k
         int n_classes = raw_prediction.shape[1]
-        double max_value = raw_prediction[i, 0]
-        double sum_exps = 0
+        double_pair max_value_and_sum_exps  # val1 = max_value, val2 = sum_exps
+
+    max_value_and_sum_exps.val1 = raw_prediction[i, 0]
+    max_value_and_sum_exps.val2 = 0

     for k in range(1, n_classes):
         # Compute max value of array for numerical stability
-        if max_value < raw_prediction[i, k]:
-            max_value = raw_prediction[i, k]
+        if max_value_and_sum_exps.val1 < raw_prediction[i, k]:
+            max_value_and_sum_exps.val1 = raw_prediction[i, k]

     for k in range(n_classes):
-        p[k] = exp(raw_prediction[i, k] - max_value)
-        sum_exps += p[k]
+        p[k] = exp(raw_prediction[i, k] - max_value_and_sum_exps.val1)
+        max_value_and_sum_exps.val2 += p[k]

-    p[n_classes] = max_value     # same as p[-2]
-    p[n_classes + 1] = sum_exps  # same as p[-1]
+    return max_value_and_sum_exps


# -------------------------------------
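Note on this refactor: `double_pair` is evidently a small two-field C struct (`val1`, `val2`) defined elsewhere in the loss module, which a `nogil` Cython function can return by value; this replaces the old trick of appending the max and the sum to an oversized buffer `p`. A minimal NumPy sketch of the new contract, with a hypothetical name (`sum_exp_minus_max_ref` is not part of the patch):

```python
import numpy as np

def sum_exp_minus_max_ref(i, raw_prediction, p):
    """Fill p with exp(raw - max) for row i; return (max_value, sum_exps).

    Mirrors the (val1, val2) double_pair convention of the Cython helper:
    subtracting the row maximum before exponentiating avoids overflow.
    """
    row = raw_prediction[i]
    max_value = row.max()
    p[:] = np.exp(row - max_value)  # p now holds only the shifted exponentials
    sum_exps = p.sum()
    return max_value, sum_exps
```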
@@ -1133,8 +1133,10 @@ cdef class {{name}}(CyLossFunction):


 # The multinomial deviance loss is also known as categorical cross-entropy or
-# multinomial log-likelihood
-cdef class CyHalfMultinomialLoss(CyLossFunction):
+# multinomial log-likelihood.
+# Here, we do not inherit from CyLossFunction because the cy_gradient method
+# below deviates from that API.
+cdef class CyHalfMultinomialLoss():
     """Half Multinomial deviance loss with multinomial logit link.

     Domain:
@@ -1148,6 +1150,77 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
     mapped to (y_true == k) for k = 0 .. n_classes - 1 which is either 0 or 1.
     """

+    # Here we deviate from the CyLossFunction API. SAG/SAGA needs direct access
+    # to sample-wise gradients, which we provide here.
+    cdef inline void cy_gradient(
+        self,
+        const floating_in y_true,
+        const floating_in[::1] raw_prediction,  # IN
+        const floating_in sample_weight,
+        floating_out[::1] gradient_out,         # OUT
+    ) noexcept nogil:
+        """Compute gradient of loss w.r.t. `raw_prediction` for a single sample.
+
+        The gradient of the multinomial logistic loss with respect to class k,
+        for one sample, is:
+
+            grad_k = sw * (p[k] - (y == k))
+
+        where:
+            p[k] = proba[k] = exp(raw_prediction[k] - logsumexp(raw_prediction))
+            sw = sample_weight
+
+        Parameters
+        ----------
+        y_true : double
+            Observed, true target value.
+        raw_prediction : array of shape (n_classes,)
+            Raw prediction values (in link space).
+        sample_weight : double
+            Sample weight.
+        gradient_out : array of shape (n_classes,)
+            A location into which the gradient is stored.
+        """
+        cdef:
+            int k
+            int n_classes = raw_prediction.shape[0]
+            double_pair max_value_and_sum_exps
+            const floating_in[:, :] raw = raw_prediction[None, :]
+
+        max_value_and_sum_exps = sum_exp_minus_max(0, raw, &gradient_out[0])
+        for k in range(n_classes):
+            # gradient_out[k] = p_k = y_pred_k = prob of class k
+            gradient_out[k] /= max_value_and_sum_exps.val2
+            # gradient_k = (p_k - (y_true == k)) * sw
+            gradient_out[k] = (gradient_out[k] - (y_true == k)) * sample_weight
+
+    def _test_cy_gradient(
+        self,
+        const floating_in[::1] y_true,              # IN
+        const floating_in[:, ::1] raw_prediction,   # IN
+        const floating_in[::1] sample_weight,       # IN
+    ):
+        """For testing only."""
+        cdef:
+            int i
+            int n_samples = y_true.shape[0]
+            int n_classes = raw_prediction.shape[1]
+            floating_in[:, ::1] gradient_out
+        gradient = np.empty(
+            (n_samples, n_classes),
+            dtype=np.float32 if floating_in is float else np.float64,
+        )
+        gradient_out = gradient
+
+        for i in range(n_samples):
+            self.cy_gradient(
+                y_true=y_true[i],
+                raw_prediction=raw_prediction[i, :],
+                sample_weight=1.0 if sample_weight is None else sample_weight[i],
+                gradient_out=gradient_out[i, :],
+            )
+        return gradient
+

 # Note that we do not assume memory alignment/contiguity of 2d arrays.
 # There seems to be little benefit in doing so. Benchmarks proofing the
 # opposite are welcome.
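For intuition, the per-sample gradient computed by `cy_gradient` is the classic softmax-minus-one-hot expression. A plain NumPy reference that `_test_cy_gradient` output could be checked against; the helper name is hypothetical and not part of the patch:

```python
import numpy as np

def multinomial_gradient_ref(y_true, raw_prediction, sample_weight):
    """Reference gradient for one sample: sw * (softmax(raw) - onehot(y))."""
    p = np.exp(raw_prediction - raw_prediction.max())  # stable softmax
    p /= p.sum()                                       # p[k] = prob of class k
    grad = p.copy()
    grad[int(y_true)] -= 1.0                           # subtract one-hot(y_true)
    return sample_weight * grad

# Sanity check: the per-sample gradient sums to zero because the
# probabilities sum to one.
g = multinomial_gradient_ref(2, np.array([0.1, -0.3, 1.2, 0.0]), 1.5)
assert abs(g.sum()) < 1e-12
```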
@@ -1165,6 +1239,7 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
             int n_classes = raw_prediction.shape[1]
             floating_in max_value, sum_exps
             floating_in* p  # temporary buffer
+            double_pair max_value_and_sum_exps

         # We assume n_samples > n_classes. In this case having the inner loop
         # over n_classes is a good default.
@@ -1176,12 +1251,12 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    max_value = p[n_classes]     # p[-2]
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    max_value = max_value_and_sum_exps.val1
+                    sum_exps = max_value_and_sum_exps.val2
                    loss_out[i] = log(sum_exps) + max_value

                    # label encoded y_true
@@ -1191,12 +1266,12 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    max_value = p[n_classes]     # p[-2]
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    max_value = max_value_and_sum_exps.val1
+                    sum_exps = max_value_and_sum_exps.val2
                    loss_out[i] = log(sum_exps) + max_value

                    # label encoded y_true
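The loss in both branches above is the numerically stable log-sum-exp form: `loss_i = log(sum_exps) + max_value - raw_prediction[i, y_i]`, i.e. `logsumexp(raw_i) - raw_i[y_i]`. A vectorized NumPy sketch of the same computation (illustrative name, not part of the patch):

```python
import numpy as np

def multinomial_loss_ref(y_true, raw_prediction):
    """Per-sample half multinomial loss: logsumexp(raw_i) - raw_i[y_i]."""
    max_value = raw_prediction.max(axis=1)          # row-wise max
    sum_exps = np.exp(raw_prediction - max_value[:, None]).sum(axis=1)
    logsumexp = np.log(sum_exps) + max_value        # stable logsumexp
    rows = np.arange(raw_prediction.shape[0])
    return logsumexp - raw_prediction[rows, y_true.astype(int)]
```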
@@ -1222,18 +1297,19 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
            int n_classes = raw_prediction.shape[1]
            floating_in max_value, sum_exps
            floating_in* p  # temporary buffer
+            double_pair max_value_and_sum_exps

        if sample_weight is None:
            # inner loop over n_classes
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    max_value = p[n_classes]     # p[-2]
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    max_value = max_value_and_sum_exps.val1
+                    sum_exps = max_value_and_sum_exps.val2
                    loss_out[i] = log(sum_exps) + max_value

                    for k in range(n_classes):
@@ -1247,12 +1323,12 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    max_value = p[n_classes]     # p[-2]
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    max_value = max_value_and_sum_exps.val1
+                    sum_exps = max_value_and_sum_exps.val2
                    loss_out[i] = log(sum_exps) + max_value

                    for k in range(n_classes):
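These `loss_gradient` hunks combine both quantities: after normalizing the buffered exponentials into probabilities, the same intermediate serves the gradient. In NumPy terms, with a hypothetical name and applying the sample weight to both outputs for illustration:

```python
import numpy as np

def multinomial_loss_gradient_ref(y_true, raw_prediction, sample_weight=1.0):
    """Per-sample loss and gradient sharing one softmax computation."""
    max_value = raw_prediction.max()
    p = np.exp(raw_prediction - max_value)
    sum_exps = p.sum()
    loss = np.log(sum_exps) + max_value - raw_prediction[int(y_true)]
    p /= sum_exps                 # reuse buffer: now probabilities
    p[int(y_true)] -= 1.0         # gradient = sw * (proba - onehot)
    return sample_weight * loss, sample_weight * p
```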
@@ -1281,17 +1357,18 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
            int n_classes = raw_prediction.shape[1]
            floating_in sum_exps
            floating_in* p  # temporary buffer
+            double_pair max_value_and_sum_exps

        if sample_weight is None:
            # inner loop over n_classes
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
@@ -1301,11 +1378,11 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
@@ -1329,17 +1406,18 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
            int n_classes = raw_prediction.shape[1]
            floating_in sum_exps
            floating_in* p  # temporary buffer
+            double_pair max_value_and_sum_exps

        if sample_weight is None:
            # inner loop over n_classes
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
@@ -1351,11 +1429,11 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        p[k] /= sum_exps  # p_k = y_pred_k = prob of class k
@@ -1384,17 +1462,18 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
            int n_classes = raw_prediction.shape[1]
            floating_in sum_exps
            floating_in* p  # temporary buffer
+            double_pair max_value_and_sum_exps

        if sample_weight is None:
            # inner loop over n_classes
            with nogil, parallel(num_threads=n_threads):
                # Define private buffer variables as each thread might use its
                # own.
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        proba_out[i, k] = p[k] / sum_exps  # y_pred_k = prob of class k
@@ -1404,11 +1483,11 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
                free(p)
        else:
            with nogil, parallel(num_threads=n_threads):
-                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes))

                for i in prange(n_samples, schedule='static'):
-                    sum_exp_minus_max(i, raw_prediction, p)
-                    sum_exps = p[n_classes + 1]  # p[-1]
+                    max_value_and_sum_exps = sum_exp_minus_max(i, raw_prediction, p)
+                    sum_exps = max_value_and_sum_exps.val2

                    for k in range(n_classes):
                        proba_out[i, k] = p[k] / sum_exps  # y_pred_k = prob of class k
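The `predict_proba` hunks above normalize the buffered exponentials into probabilities; the result is an ordinary row-wise softmax with the same max-subtraction trick. An equivalent NumPy sketch (the name is illustrative only):

```python
import numpy as np

def predict_proba_ref(raw_prediction):
    """Row-wise softmax with max subtraction for numerical stability."""
    p = np.exp(raw_prediction - raw_prediction.max(axis=1, keepdims=True))
    return p / p.sum(axis=1, keepdims=True)  # proba_out[i, k] = p[k] / sum_exps
```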