8000 ENH Use Array API in mean_tweedie_deviance (#28106) · scikit-learn/scikit-learn@e12f192 · GitHub
[go: up one dir, main page]

Skip to content

Commit e12f192

Browse files
authored
ENH Use Array API in mean_tweedie_deviance (#28106)
1 parent 1fa3c75 commit e12f192

File tree

5 files changed

+67
-11
lines changed

5 files changed

+67
-11
lines changed

doc/modules/array_api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ Metrics
106106
-------
107107

108108
- :func:`sklearn.metrics.accuracy_score`
109+
- :func:`sklearn.metrics.mean_tweedie_deviance`
109110
- :func:`sklearn.metrics.r2_score`
110111
- :func:`sklearn.metrics.zero_one_loss`
111112

doc/whats_new/v1.6.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,24 @@ Version 1.6.0
2222

2323
**In Development**
2424

25+
Support for Array API
26+
---------------------
27+
28+
Additional estimators and functions have been updated to include support for all
29+
`Array API <https://data-apis.org/array-api/latest/>`_ compliant inputs.
30+
31+
See :ref:`array_api` for more details.
32+
33+
**Functions:**
34+
35+
- :func:`sklearn.metrics.mean_tweedie_deviance` now supports Array API compatible
36+
inputs.
37+
:pr:`28106` by :user:`Thomas Li <lithomas1>`
38+
39+
**Classes:**
40+
41+
-
42+
2543
Changelog
2644
---------
2745

sklearn/metrics/_regression.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,13 +1276,14 @@ def max_error(y_true, y_pred):
12761276

12771277
def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):
12781278
"""Mean Tweedie deviance regression loss."""
1279+
xp, _ = get_namespace(y_true, y_pred)
12791280
p = power
12801281
if p < 0:
12811282
# 'Extreme stable', y any real number, y_pred > 0
12821283
dev = 2 * (
1283-
np.power(np.maximum(y_true, 0), 2 - p) / ((1 - p) * (2 - p))
1284-
- y_true * np.power(y_pred, 1 - p) / (1 - p)
1285-
+ np.power(y_pred, 2 - p) / (2 - p)
1284+
xp.pow(xp.where(y_true > 0, y_true, 0), 2 - p) / ((1 - p) * (2 - p))
1285+
- y_true * xp.pow(y_pred, 1 - p) / (1 - p)
1286+
+ xp.pow(y_pred, 2 - p) / (2 - p)
12861287
)
12871288
elif p == 0:
12881289
# Normal distribution, y and y_pred any real number
@@ -1292,15 +1293,14 @@ def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):
12921293
dev = 2 * (xlogy(y_true, y_true / y_pred) - y_true + y_pred)
12931294
elif p == 2:
12941295
# Gamma distribution
1295-
dev = 2 * (np.log(y_pred / y_true) + y_true / y_pred - 1)
1296+
dev = 2 * (xp.log(y_pred / y_true) + y_true / y_pred - 1)
12961297
else:
12971298
dev = 2 * (
1298-
np.power(y_true, 2 - p) / ((1 - p) * (2 - p))
1299-
- y_true * np.power(y_pred, 1 - p) / (1 - p)
1300-
+ np.power(y_pred, 2 - p) / (2 - p)
1299+
xp.pow(y_true, 2 - p) / ((1 - p) * (2 - p))
1300+
- y_true * xp.pow(y_pred, 1 - p) / (1 - p)
1301+
+ xp.pow(y_pred, 2 - p) / (2 - p)
13011302
)
1302-
1303-
return np.average(dev, weights=sample_weight)
1303+
return float(_average(dev, weights=sample_weight))
13041304

13051305

13061306
@validate_params(
@@ -1363,8 +1363,9 @@ def mean_tweedie_deviance(y_true, y_pred, *, sample_weight=None, power=0):
13631363
>>> mean_tweedie_deviance(y_true, y_pred, power=1)
13641364
1.4260...
13651365
"""
1366+
xp, _ = get_namespace(y_true, y_pred)
13661367
y_type, y_true, y_pred, _ = _check_reg_targets(
1367-
y_true, y_pred, None, dtype=[np.float64, np.float32]
1368+
y_true, y_pred, None, dtype=[xp.float64, xp.float32]
13681369
)
13691370
if y_type == "continuous-multioutput":
13701371
raise ValueError("Multioutput not supported in mean_tweedie_deviance")

sklearn/metrics/tests/test_common.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1824,6 +1824,35 @@ def check_array_api_multiclass_classification_metric(
18241824

18251825

18261826
def check_array_api_regression_metric(metric, array_namespace, device, dtype_name):
1827+
y_true_np = np.array([2, 0, 1, 4], dtype=dtype_name)
1828+
y_pred_np = np.array([0.5, 0.5, 2, 2], dtype=dtype_name)
1829+
1830+
check_array_api_metric(
1831+
metric,
1832+
array_namespace,
1833+
device,
1834+
dtype_name,
1835+
y_true_np=y_true_np,
1836+
y_pred_np=y_pred_np,
1837+
sample_weight=None,
1838+
)
1839+
1840+
sample_weight = np.array([0.1, 2.0, 1.5, 0.5], dtype=dtype_name)
1841+
1842+
check_array_api_metric(
1843+
metric,
1844+
array_namespace,
1845+
device,
1846+
dtype_name,
1847+
y_true_np=y_true_np,
1848+
y_pred_np=y_pred_np,
1849+
sample_weight=sample_weight,
1850+
)
1851+
1852+
1853+
def check_array_api_regression_metric_multioutput(
1854+
metric, array_namespace, device, dtype_name
1855+
):
18271856
y_true_np = np.array([[1, 3], [1, 2]], dtype=dtype_name)
18281857
y_pred_np = np.array([[1, 4], [1, 1]], dtype=dtype_name)
18291858

@@ -1859,7 +1888,11 @@ def check_array_api_regression_metric(metric, array_namespace, device, dtype_nam
18591888
check_array_api_binary_classification_metric,
18601889
check_array_api_multiclass_classification_metric,
18611890
],
1862-
r2_score: [check_array_api_regression_metric],
1891+
mean_tweedie_deviance: [check_array_api_regression_metric],
1892+
r2_score: [
1893+
check_array_api_regression_metric,
1894+
check_array_api_regression_metric_multioutput,
1895+
],
18631896
}
18641897

18651898

sklearn/utils/_array_api.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,9 @@ def reshape(self, x, shape, *, copy=None):
427427
def isdtype(self, dtype, kind):
428428
return isdtype(dtype, kind, xp=self)
429429

430+
def pow(self, x1, x2):
431+
return numpy.power(x1, x2)
432+
430433

431434
_NUMPY_API_WRAPPER_INSTANCE = _NumPyAPIWrapper()
432435

0 commit comments

Comments
 (0)
0