ENH Use Array API in r2_score #27904

Merged on Mar 11, 2024 (106 commits)

Commits
e0429db
update r2 score to use the array API, and write initial tests
elindgren Aug 18, 2023
b9c1720
Merge remote-tracking branch 'upstream/main'
elindgren Aug 18, 2023
5666ce5
Merge branch 'main' into ENH/r2_score_array_api
ogrisel Aug 19, 2023
4580d1c
Merge branch 'main' into ENH/r2_score_array_api
ogrisel Sep 7, 2023
a4dd594
Fix some review comments and move stuff to CPU
elindgren Sep 8, 2023
adc7680
Add regression tests to the test_common framework
elindgren Sep 28, 2023
85469a9
Update sklearn/metrics/tests/test_regression.py
elindgren Oct 5, 2023
b7efaa5
Update sklearn/metrics/tests/test_regression.py
elindgren Oct 5, 2023
ac533c2
Remove hardcoded device choice in _weighted_sum
betatim Aug 30, 2023
35be22e
Factor out max float precision determination
betatim Sep 7, 2023
7c53e19
Use convenience function to find highest accuracy float in r2_score
elindgren Oct 5, 2023
230ae46
add tests for _average for Array API
elindgren Oct 5, 2023
e4672d1
MNT Ignore ruff errors (#27094)
lesteve Aug 18, 2023
8ba9485
DOC fix docstring for `sklearn.datasets.get_data_home` (#27073)
kachayev Aug 18, 2023
490e0b4
TST Extend tests for `scipy.sparse.*array` in `sklearn/cluster/tests/…
jjerphan Aug 18, 2023
a8a820c
MNT Remove DeprecationWarning for scipy.sparse.linalg.cg tol vs rtol …
lesteve Aug 18, 2023
552e421
Merge branch 'main' into ENH/r2_score_array_api
elindgren Oct 5, 2023
ff52710
Merge remote-tracking branch 'upstream/main' into ENH/r2_score_array_api
elindgren Oct 5, 2023
fe9cc1c
remove temporary file
elindgren Oct 5, 2023
93257ba
WIP: solving dtype and device maze
fcharras Dec 5, 2023
45bbe4e
Fix changelog conflict
fcharras Dec 5, 2023
2145a6b
Tests fixups
fcharras Dec 6, 2023
bd4b224
Tests fixups
fcharras Dec 6, 2023
34aceb1
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
fcharras Dec 6, 2023
56d5308
Fix dtype parameterization in common metric tests
fcharras Dec 6, 2023
75cb3f3
Tests fixups
fcharras Dec 6, 2023
d9fff24
Tests fixups
fcharras Dec 6, 2023
d72137c
Adds lru_cache on device inspection function + user _convert_to_numpy…
fcharras Dec 11, 2023
16ab95f
Adequatly define hash of _ArrayAPIWrapper to avoid wrong equality
fcharras Dec 11, 2023
9862a85
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
fcharras Dec 19, 2023
143ce54
Remove _weighted_sum and only use _average
fcharras Dec 19, 2023
4e9401b
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
fcharras Dec 19, 2023
2b095c4
Linting on unrelated diff, pre-commit broken ? + fixes
fcharras Dec 19, 2023
42f5d8d
Merge branch 'main' into ENH/r2_score_array_api
fcharras Dec 26, 2023
ff0b860
re add faster, simpler code branch for _weighted_sum in _classificati…
fcharras Dec 27, 2023
efe36f3
re add faster, simpler code branch for _weighted_sum in _classificati…
fcharras Dec 27, 2023
abb9ee9
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
fcharras Dec 27, 2023
08f5433
fix
fcharras Dec 28, 2023
38f56af
fix tests with torch+cuda
fcharras Dec 29, 2023
c09a84b
fix tests with torch+cuda
fcharras Dec 29, 2023
13d9bd6
Merge branch 'main' into ENH/r2_score_array_api
fcharras Jan 2, 2024
c32fa92
FIX: always pass xp to _convert_to_numpy calls
ogrisel Jan 3, 2024
1555f8d
FIX also update device_ in case of numpy fallback
ogrisel Jan 3, 2024
fc1b9f1
FIX pass xp to _convert_to_numpy instead of copy=True
ogrisel Jan 3, 2024
1bf557d
Rename _weighted_sum to _weighted_sum_1d to make it explicit that tho…
ogrisel Jan 3, 2024
c41694a
Improve test coverage for _average function + some review changes
fcharras Jan 4, 2024
c71c3ce
fix torch+cuda
fcharras Jan 4, 2024
d2cd3ca
Merge branch 'main' into ENH/r2_score_array_api
fcharras Jan 4, 2024
4be2ac0
Fix docstring formatting
fcharras Jan 4, 2024
29260e1
Fix error for arrays on different devices
fcharras Jan 5, 2024
e47d53c
Merge branch 'main' into ENH/r2_score_array_api
fcharras Jan 5, 2024
ccbc92d
Adapt device inspection function to non hashable device objects
fcharras Jan 5, 2024
b09b653
Merge branch 'ENH/r2_score_array_api' of https://github.com/fcharras/…
fcharras Jan 5, 2024
0b5b550
CI Remove unused mkl_no_coverage lock file (#28061)
lesteve Jan 4, 2024
2266348
Fix device inspection function + adapt test to non-hashable device ob…
fcharras Jan 9, 2024
db22354
Merge branch 'ENH/r2_score_array_api' of https://github.com/fcharras/…
fcharras Jan 9, 2024
fc3b6e9
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
fcharras Jan 9, 2024
6ff37fb
Apply suggestion
fcharras Jan 9, 2024
3cda292
Fix device inspection test
fcharras Jan 10, 2024
6179f10
Merge branch 'main' into ENH/r2_score_array_api
ogrisel Jan 23, 2024
bcaa3d8
modify changelog
glemaitre Jan 24, 2024
647109c
Merge remote-tracking branch 'origin/main' into pr/fcharras/27904
glemaitre Jan 24, 2024
fc51090
Apply non-controversial suggestions from code review
ogrisel Jan 25, 2024
df14fca
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
fcharras Feb 9, 2024
db8a046
Adress review comments. NB:
fcharras Feb 9, 2024
2c12856
fixup
fcharras Feb 12, 2024
cd53bd6
Factorize array filtering by type for get_namespace and device helpers
ogrisel Mar 5, 2024
0d1c3bf
Do not upcast partial sums to float64 in r2_score
ogrisel Mar 5, 2024
40dd9d1
Skip strings by default and rename private helper
ogrisel Mar 6, 2024
ac07d4c
WIP fixing type promotion logic
ogrisel Mar 6, 2024
9ebc1ff
Merge branch 'main' into ENH/r2_score_array_api
ogrisel Mar 6, 2024
3550daf
Fix use implementation defined default floating dtype
ogrisel Mar 6, 2024
33177dd
Update test and remove non-reachable branch
ogrisel Mar 6, 2024
f35ea45
Fix error message when the default array filter leads to an empty lis…
ogrisel Mar 6, 2024
a67fe45
Use informative error message in _average while starting with the sam…
ogrisel Mar 7, 2024
d583d9e
Factorize floating point type promotion logic
ogrisel Mar 7, 2024
bae25f0
Merge branch 'main' into ENH/r2_score_array_api
ogrisel Mar 7, 2024
41f99d2
Fix adapt dtype matching logic to non-array inputs, prior to the call…
ogrisel Mar 7, 2024
87e4c8d
Simplification
ogrisel Mar 7, 2024
1c2ea78
Remove device-specific dtype support as its no longer needed by r2_score
ogrisel Mar 7, 2024
c2cbd98
More simplifications
ogrisel Mar 7, 2024
6429401
Improve numerical stability by scaling the weights prior to using the…
ogrisel Mar 7, 2024
6636e4c
Fix test_nan_reductions
ogrisel Mar 7, 2024
be5a474
Revert "Improve numerical stability by scaling the weights prior to u…
ogrisel Mar 7, 2024
6a728ac
Fix formatting
ogrisel Mar 7, 2024
1d4c49e
Skip test_average_raises_with_wrong_dtype on cupy
ogrisel Mar 7, 2024
98347c1
Simplify back _isdtype_single
ogrisel Mar 7, 2024
aff4840
Grammar.
ogrisel Mar 7, 2024
ad0a1fb
Need to conver to float explicitly
ogrisel Mar 7, 2024
d596494
Factorize the float conversion into _assemble_r2_explained_variance
ogrisel Mar 7, 2024
d6f0101
Move tuple conversion at the beginning of _skip_non_arrays
ogrisel Mar 7, 2024
ec84e44
Small fixes in comments and remove duplicated lines.
ogrisel Mar 8, 2024
08405a5
One more get_namespace simplification
ogrisel Mar 8, 2024
a09866d
Remove useless import added by vs code...
ogrisel Mar 8, 2024
b59a7be
Apply suggestions from code review
ogrisel Mar 10, 2024
ef1631b
Rename _skip_non_arrays to _remove_non_arrays & co
ogrisel Mar 11, 2024
388d670
Remove custom __hash__ method that is no longer needed
ogrisel Mar 11, 2024
8042795
Remove redundant calls to xp.astype
ogrisel Mar 11, 2024
92af1a8
Factorize the if xp is None: xp, _ = get_namespace(inputs) pattern
ogrisel Mar 11, 2024
47fed64
Fix handling of xp is not None in get_namespace
ogrisel Mar 11, 2024
3699353
get_namespace in _weighted_sum_1d
ogrisel Mar 11, 2024
c2b4b11
Merge _weighted_sum_1d into _average
ogrisel Mar 11, 2024
9c2d9ac
One final 'if xp is None' occurrence
ogrisel Mar 11, 2024
90076d3
DOC be explicit about return types
ogrisel Mar 11, 2024
3cc74e5
Merge branch 'main' into ENH/r2_score_array_api
ogrisel Mar 11, 2024
457531e
Update phrasing in the doc to avoid confusing array container type wi…
ogrisel Mar 11, 2024
17 changes: 17 additions & 0 deletions doc/modules/array_api.rst
@@ -105,6 +105,7 @@ Metrics
-------

- :func:`sklearn.metrics.accuracy_score`
- :func:`sklearn.metrics.r2_score`
- :func:`sklearn.metrics.zero_one_loss`

Tools
@@ -115,6 +116,22 @@ Tools
Coverage is expected to grow over time. Please follow the dedicated `meta-issue on GitHub
<https://github.com/scikit-learn/scikit-learn/issues/22352>`_ to track progress.

Type of return values and fitted attributes
-------------------------------------------

When calling functions or methods with Array API compatible inputs, the
convention is to return array values of the same array container type and
device as the input data.

Similarly, when an estimator is fitted with Array API compatible inputs, the
fitted attributes will be arrays from the same library as the input and stored
on the same device. The `predict` and `transform` methods subsequently expect
inputs from the same array library and device as the data passed to the `fit`
method.

Note however that scoring functions that return scalar values return Python
scalars (typically a `float` instance) instead of an array scalar value.

Common estimator checks
=======================

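As an illustration of the return-type convention documented above, here is a minimal, hypothetical usage sketch (it assumes PyTorch and array-api-compat are installed; the numbers are arbitrary):

import torch

from sklearn import config_context
from sklearn.metrics import r2_score

y_true = torch.asarray([3.0, -0.5, 2.0, 7.0])
y_pred = torch.asarray([2.5, 0.0, 2.0, 8.0])

with config_context(array_api_dispatch=True):
    score = r2_score(y_true, y_pred)

# Even though the inputs are torch tensors, the scalar score comes back as a
# plain Python float rather than a zero-dimensional tensor.
print(type(score))  # <class 'float'>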
16 changes: 16 additions & 0 deletions doc/whats_new/v1.5.rst
@@ -22,6 +22,22 @@ Version 1.5.0

**In Development**

Support for Array API
---------------------

Additional estimators and functions have been updated to include support for all
`Array API <https://data-apis.org/array-api/latest/>`_ compliant inputs.

See :ref:`array_api` for more details.

**Functions:**

- :func:`sklearn.metrics.r2_score` now supports Array API compliant inputs.
:pr:`27904` by :user:`Eric Lindgren <elindgren>`, `Franck Charras <fcharras>`,
`Olivier Grisel <ogrisel>` and `Tim Head <betatim>`.

**Classes:**

Support for building with Meson
-------------------------------

12 changes: 8 additions & 4 deletions sklearn/metrics/_classification.py
@@ -38,7 +38,11 @@
check_consistent_length,
column_or_1d,
)
from ..utils._array_api import _union1d, _weighted_sum, get_namespace
from ..utils._array_api import (
_average,
_union1d,
get_namespace,
)
from ..utils._param_validation import (
Hidden,
Interval,
@@ -224,7 +228,7 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
else:
score = y_true == y_pred

return _weighted_sum(score, sample_weight, normalize)
return float(_average(score, weights=sample_weight, normalize=normalize))


@validate_params(
@@ -2809,7 +2813,7 @@ def hamming_loss(y_true, y_pred, *, sample_weight=None):
return n_differences / (y_true.shape[0] * y_true.shape[1] * weight_average)

elif y_type in ["binary", "multiclass"]:
return _weighted_sum(y_true != y_pred, sample_weight, normalize=True)
return float(_average(y_true != y_pred, weights=sample_weight, normalize=True))
else:
raise ValueError("{0} is not supported".format(y_type))

@@ -2994,7 +2998,7 @@ def log_loss(
y_pred = y_pred / y_pred_sum[:, np.newaxis]
loss = -xlogy(transformed_labels, y_pred).sum(axis=1)

return _weighted_sum(loss, sample_weight, normalize)
return float(_average(loss, weights=sample_weight, normalize=normalize))


@validate_params(
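The classification metrics above now funnel their reductions through the namespace-agnostic `_average` helper and convert the result to a Python float. A rough sketch of what such a helper can look like, for illustration only (this is not scikit-learn's actual `_average` implementation):

import array_api_compat


def weighted_average(sample_score, weights=None, normalize=True):
    # Resolve the array namespace (numpy, torch, cupy, ...) from the input.
    xp = array_api_compat.array_namespace(sample_score)
    dev = array_api_compat.device(sample_score)
    # Boolean score arrays (e.g. y_true == y_pred) are promoted to a float dtype.
    sample_score = xp.astype(sample_score, xp.float64)
    if weights is None:
        total = xp.sum(sample_score)
        n = sample_score.shape[0]
        return float(total / n) if normalize else float(total)
    weights = xp.asarray(weights, dtype=sample_score.dtype, device=dev)
    total = xp.sum(sample_score * weights)
    return float(total / xp.sum(weights)) if normalize else float(total)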
56 changes: 41 additions & 15 deletions sklearn/metrics/_regression.py
@@ -34,6 +34,12 @@
from scipy.special import xlogy

from ..exceptions import UndefinedMetricWarning
from ..utils._array_api import (
_average,
_find_matching_floating_dtype,
device,
get_namespace,
)
from ..utils._param_validation import Hidden, Interval, StrOptions, validate_params
from ..utils.stats import _weighted_percentile
from ..utils.validation import (
@@ -65,7 +71,7 @@
]


def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric"):
def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
"""Check that y_true and y_pred belong to the same regression task.

Parameters
@@ -99,15 +105,17 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric"):
just the corresponding argument if ``multioutput`` is a
correct keyword.
"""
xp, _ = get_namespace(y_true, y_pred, multioutput, xp=xp)

check_consistent_length(y_true, y_pred)
y_true = check_array(y_true, ensure_2d=False, dtype=dtype)
y_pred = check_array(y_pred, ensure_2d=False, dtype=dtype)

if y_true.ndim == 1:
y_true = y_true.reshape((-1, 1))
y_true = xp.reshape(y_true, (-1, 1))

if y_pred.ndim == 1:
y_pred = y_pred.reshape((-1, 1))
y_pred = xp.reshape(y_pred, (-1, 1))

if y_true.shape[1] != y_pred.shape[1]:
raise ValueError(
@@ -855,9 +863,10 @@ def median_absolute_error(


def _assemble_r2_explained_variance(
numerator, denominator, n_outputs, multioutput, force_finite
numerator, denominator, n_outputs, multioutput, force_finite, xp, device
):
"""Common part used by explained variance score and :math:`R^2` score."""
dtype = numerator.dtype

nonzero_denominator = denominator != 0

@@ -868,12 +877,14 @@
nonzero_numerator = numerator != 0
# Default = Zero Numerator = perfect predictions. Set to 1.0
# (note: even if denominator is zero, thus avoiding NaN scores)
output_scores = np.ones([n_outputs])
output_scores = xp.ones([n_outputs], device=device, dtype=dtype)
# Non-zero Numerator and Non-zero Denominator: use the formula
valid_score = nonzero_denominator & nonzero_numerator

output_scores[valid_score] = 1 - (
numerator[valid_score] / denominator[valid_score]
)

# Non-zero Numerator and Zero Denominator:
# arbitrary set to 0.0 to avoid -inf scores
output_scores[nonzero_numerator & ~nonzero_denominator] = 0.0
@@ -887,15 +898,18 @@
avg_weights = None
elif multioutput == "variance_weighted":
avg_weights = denominator
if not np.any(nonzero_denominator):
if not xp.any(nonzero_denominator):
# All weights are zero, np.average would raise a ZeroDiv error.
# This only happens when all y are constant (or 1-element long)
# Since weights are all equal, fall back to uniform weights.
avg_weights = None
else:
avg_weights = multioutput

return np.average(output_scores, weights=avg_weights)
result = _average(output_scores, weights=avg_weights)
if result.size == 1:
return float(result)
return result


@validate_params(
@@ -1033,6 +1047,9 @@ def explained_variance_score(
n_outputs=y_true.shape[1],
multioutput=multioutput,
force_finite=force_finite,
xp=get_namespace(y_true)[0],
# TODO: update once Array API support is added to explained_variance_score.
device=None,
)


@@ -1177,8 +1194,14 @@ def r2_score(
>>> r2_score(y_true, y_pred, force_finite=False)
-inf
"""
y_type, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput
input_arrays = [y_true, y_pred, sample_weight, multioutput]
xp, _ = get_namespace(*input_arrays)
device_ = device(*input_arrays)

dtype = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)

_, y_true, y_pred, multioutput = _check_reg_targets(
y_true, y_pred, multioutput, dtype=dtype, xp=xp
)
check_consistent_length(y_true, y_pred, sample_weight)

Expand All @@ -1188,22 +1211,25 @@ def r2_score(
return float("nan")

if sample_weight is not None:
sample_weight = column_or_1d(sample_weight)
weight = sample_weight[:, np.newaxis]
sample_weight = column_or_1d(sample_weight, dtype=dtype)
weight = sample_weight[:, None]
else:
weight = 1.0

numerator = (weight * (y_true - y_pred) ** 2).sum(axis=0, dtype=np.float64)
denominator = (
weight * (y_true - np.average(y_true, axis=0, weights=sample_weight)) ** 2
).sum(axis=0, dtype=np.float64)
numerator = xp.sum(weight * (y_true - y_pred) ** 2, axis=0)
denominator = xp.sum(
weight * (y_true - _average(y_true, axis=0, weights=sample_weight, xp=xp)) ** 2,
axis=0,
)

return _assemble_r2_explained_variance(
numerator=numerator,
denominator=denominator,
n_outputs=y_true.shape[1],
multioutput=multioutput,
force_finite=force_finite,
xp=xp,
device=device_,
)


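For reference, the `numerator` and `denominator` computed above are the weighted sums of the usual R² definition, evaluated per output column:

R^2 = 1 - \frac{\sum_i w_i (y_i - \hat{y}_i)^2}{\sum_i w_i (y_i - \bar{y}_w)^2},
\qquad
\bar{y}_w = \frac{\sum_i w_i y_i}{\sum_i w_i}

`multioutput` then controls whether the per-output scores are returned as-is ("raw_values") or averaged, either uniformly ("uniform_average") or with `denominator` as the weights ("variance_weighted").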
31 changes: 30 additions & 1 deletion sklearn/metrics/tests/test_common.py
@@ -55,6 +55,7 @@
from sklearn.utils import shuffle
from sklearn.utils._array_api import (
_atol_for_type,
_convert_to_numpy,
yield_namespace_device_dtype_combinations,
)
from sklearn.utils._testing import (
@@ -1749,7 +1750,7 @@ def check_array_api_metric(
metric_xp = metric(y_true_xp, y_pred_xp, sample_weight=sample_weight)

assert_allclose(
metric_xp,
_convert_to_numpy(xp.asarray(metric_xp), xp),
metric_np,
atol=_atol_for_type(dtype_name),
)
@@ -1813,6 +1814,33 @@ def check_array_api_multiclass_classification_metric(
)


def check_array_api_regression_metric(metric, array_namespace, device, dtype_name):
y_true_np = np.array([[1, 3], [1, 2]], dtype=dtype_name)
y_pred_np = np.array([[1, 4], [1, 1]], dtype=dtype_name)

check_array_api_metric(
metric,
array_namespace,
device,
dtype_name,
y_true_np=y_true_np,
y_pred_np=y_pred_np,
sample_weight=None,
)

sample_weight = np.array([0.1, 2.0], dtype=dtype_name)

check_array_api_metric(
metric,
array_namespace,
device,
dtype_name,
y_true_np=y_true_np,
y_pred_np=y_pred_np,
sample_weight=sample_weight,
)


array_api_metric_checkers = {
accuracy_score: [
check_array_api_binary_classification_metric,
Expand All @@ -1822,6 +1850,7 @@ def check_array_api_multiclass_classification_metric(
check_array_api_binary_classification_metric,
check_array_api_multiclass_classification_metric,
],
r2_score: [check_array_api_regression_metric],
}


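A rough sketch of how such a checker table is typically consumed by a parametrized test; the exact helper names in scikit-learn's test suite may differ from this illustration:

import pytest

from sklearn.utils._array_api import yield_namespace_device_dtype_combinations


def yield_metric_checker_combinations(metric_checkers=array_api_metric_checkers):
    # Flatten the {metric: [checker, ...]} mapping into (metric, checker) pairs.
    for metric, checkers in metric_checkers.items():
        for checker in checkers:
            yield metric, checker


@pytest.mark.parametrize(
    "array_namespace, device, dtype_name",
    yield_namespace_device_dtype_combinations(),
)
@pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations())
def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func):
    check_func(metric, array_namespace, device, dtype_name)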