Fix tests for numpy 2 and array api compat by ogrisel · Pull Request #29436 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

Fix tests for numpy 2 and array api compat #29436

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,5 @@ dependencies:
- numpydoc
- lightgbm
- scikit-image
- array-api-compat
- array-api-strict
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Generated by conda-lock.
# platform: linux-64
# input_hash: af52e4ce613b7668e1e28daaea07461722275d345395a5eaced4e07a16998179
# input_hash: 11d97b96088b6b1eaf3b774050152e7899f0a6ab757350df2efd44b2de3a5f75
@EXPLICIT
https://repo.anaconda.com/pkgs/main/linux-64/_libgcc_mutex-0.1-main.conda#c3473ff8bdb3d124ed5ff11ec380d6f9
https://repo.anaconda.com/pkgs/main/linux-64/ca-certificates-2024.3.11-h06a4308_0.conda#08529eb3504712baabcbda266a19feb7
Expand All @@ -24,6 +24,7 @@ https://repo.anaconda.com/pkgs/main/linux-64/setuptools-69.5.1-py39h06a4308_0.co
https://repo.anaconda.com/pkgs/main/linux-64/wheel-0.43.0-py39h06a4308_0.conda#40bb60408c7433d767fd8c65b35bc4a0
https://repo.anaconda.com/pkgs/main/linux-64/pip-24.0-py39h06a4308_0.conda#7f8ce3af15cfecd12e4dda8c5cef5fb7
# pip alabaster @ https://files.pythonhosted.org/packages/32/34/d4e1c02d3bee589efb5dfa17f88ea08bdb3e3eac12bc475462aec52ed223/alabaster-0.7.16-py3-none-any.whl#sha256=b46733c07dce03ae4e150330b975c75737fa60f0a7c591b6c8bf4928a28e2c92
# pip array-api-compat @ https://files.pythonhosted.org/packages/05/ae/2f11031bb9f819f6efaaa66b720b37928fbb0087161fcbae3465ae374a18/array_api_compat-1.7.1-py3-none-any.whl#sha256=6974f51775972f39edbca39e08f1c2e43c51401c093a0fea5ac7159875095d8a
# pip babel @ https://files.pythonhosted.org/packages/27/45/377f7e32a5c93d94cd56542349b34efab5ca3f9e2fd5a68c5e93169aa32d/Babel-2.15.0-py3-none-any.whl#sha256=08706bdad8d0a3413266ab61bd6c34d0c28d6e1e7badf40a2cebe67644e2e1fb
# pip certifi @ https://files.pythonhosted.org/packages/1c/d5/c84e1a17bf61d4df64ca866a1c9a913874b4e9bdc131ec689a0ad013fb36/certifi-2024.7.4-py3-none-any.whl#sha256=c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90
# pip charset-normalizer @ https://files.pythonhosted.org/packages/98/69/5d8751b4b670d623aa7a47bef061d69c279e9f922f6705147983aa76c3ce/charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796
Expand Down Expand Up @@ -63,6 +64,7 @@ https://repo.anaconda.com/pkgs/main/linux-64/pip-24.0-py39h06a4308_0.conda#7f8ce
# pip tzdata @ https://files.pythonhosted.org/packages/65/58/f9c9e6be752e9fcb8b6a0ee9fb87e6e7a1f6bcab2cdc73f02bb7ba91ada0/tzdata-2024.1-py2.py3-none-any.whl#sha256=9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252
# pip urllib3 @ https://files.pythonhosted.org/packages/ca/1c/89ffc63a9605b583d5df2be791a27bc1a42b7c32bab68d3c8f2f73a98cd4/urllib3-2.2.2-py3-none-any.whl#sha256=a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472
# pip zipp @ https://files.pythonhosted.org/packages/20/38/f5c473fe9b90c8debdd29ea68d5add0289f1936d6f923b6b9cc0b931194c/zipp-3.19.2-py3-none-any.whl#sha256=f091755f667055f2d02b32c53771a7a6c8b47e1fdbc4b72a8b9072b3eef8015c
# pip array-api-strict @ https://files.pythonhosted.org/packages/08/06/aba69bce257fd1cda0d1db616c12728af0f46878a5cc1923fcbb94201947/array_api_strict-2.0.1-py3-none-any.whl#sha256=f74cbf0d0c182fcb45c5ee7f28f9c7b77e6281610dfbbdd63be60b1a5a7872b3
# pip contourpy @ https://files.pythonhosted.org/packages/31/a2/2f12e3a6e45935ff694654b710961b03310b0e1ec997ee9f416d3c873f87/contourpy-1.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=e1d59258c3c67c865435d8fbeb35f8c59b8bef3d6f46c1f29f6123556af28445
# pip coverage @ https://files.pythonhosted.org/packages/c4/b4/0cbc18998613f8caaec793ad5878d2450382dfac80e65d352fb7cd9cc1dc/coverage-7.5.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=dbc5958cb471e5a5af41b0ddaea96a37e74ed289535e8deca404811f6cb0bc3d
# pip imageio @ https://files.pythonhosted.org/packages/3d/84/f1647217231f6cc46883e5d26e870cc3e1520d458ecd52d6df750810d53c/imageio-2.34.2-py3-none-any.whl#sha256=a0bb27ec9d5bab36a9f4835e51b21d2cb099e1f78451441f94687ff3404b79f8
Expand Down
5 changes: 5 additions & 0 deletions build_tools/update_environments_and_lock_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,14 @@ def remove_from(alist, to_remove):
"pip_dependencies": (
remove_from(common_dependencies, ["python", "blas", "pip"])
+ docstring_test_dependencies
# Test with some optional dependencies
+ ["lightgbm", "scikit-image"]
# Test array API on CPU without PyTorch
+ ["array-api-compat", "array-api-strict"]
),
"package_constraints": {
# XXX: we would like to use the latest version of Python but this makes
# the CI much slower. We need to investigate why.
"python": "3.9",
},
},
Expand Down
21 changes: 10 additions & 11 deletions sklearn/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,16 +672,10 @@ def _average(a, axis=None, weights=None, normalize=True, xp=None):
f"weights {tuple(weights.shape)} differ."
)

if weights.ndim != 1:
raise TypeError(
f"1D weights expected when a.shape={tuple(a.shape)} and "
f"weights.shape={tuple(weights.shape)} differ."
)

if size(weights) != a.shape[axis]:
if tuple(weights.shape) != (a.shape[axis],):
raise ValueError(
f"Length of weights {size(weights)} not compatible with "
f" a.shape={tuple(a.shape)} and {axis=}."
f"Shape of weights weights.shape={tuple(weights.shape)} must be "
f"consistent with a.shape={tuple(a.shape)} and {axis=}."
)

# If weights are 1D, add singleton dimensions for broadcasting
Expand Down Expand Up @@ -839,9 +833,14 @@ def _estimator_with_converted_arrays(estimator, converter):
return new_estimator


def _atol_for_type(dtype):
def _atol_for_type(dtype_or_dtype_name):
"""Return the absolute tolerance for a given numpy dtype."""
return numpy.finfo(dtype).eps * 100
if dtype_or_dtype_name is None:
# If no dtype is specified when running tests for a given namespace, we
# expect the same floating precision level as NumPy's default floating
# point dtype.
dtype_or_dtype_name = numpy.float64
return numpy.finfo(dtype_or_dtype_name).eps * 100


def indexing_dtype(xp):
Expand Down
47 changes: 30 additions & 17 deletions sklearn/utils/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
assert_array_equal,
skip_if_array_api_compat_not_configured,
)
from sklearn.utils.fixes import _IS_32BIT, CSR_CONTAINERS
from sklearn.utils.fixes import _IS_32BIT, CSR_CONTAINERS, np_version, parse_version


@pytest.mark.parametrize("X", [numpy.asarray([1, 2, 3]), [1, 2, 3]])
Expand Down Expand Up @@ -67,7 +67,12 @@ def test_get_namespace_ndarray_with_dispatch():
with config_context(array_api_dispatch=True):
xp_out, is_array_api_compliant = get_namespace(X_np)
assert is_array_api_compliant
assert xp_out is array_api_compat.numpy
if np_version >= parse_version("2.0.0"):
# NumPy 2.0+ is an array API compliant library.
assert xp_out is numpy
else:
# Older NumPy versions require the compatibility layer.
assert xp_out is array_api_compat.numpy


@skip_if_array_api_compat_not_configured
Expand Down Expand Up @@ -135,7 +140,7 @@ def test_asarray_with_order_ignored():


@pytest.mark.parametrize(
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
"array_namespace, device_, dtype_name", yield_namespace_device_dtype_combinations()
)
@pytest.mark.parametrize(
"weights, axis, normalize, expected",
Expand Down Expand Up @@ -167,19 +172,22 @@ def test_asarray_with_order_ignored():
],
)
def test_average(
array_namespace, device, dtype_name, weights, axis, normalize, expected
array_namespace, device_, dtype_name, weights, axis, normalize, expected
):
xp = _array_api_for_tests(array_namespace, device)
    xp = _array_api_for_tests(array_namespace, device_)
array_in = numpy.asarray([[1, 2, 3], [4, 5, 6]], dtype=dtype_name)
array_in = xp.asarray(array_in, device=device)
array_in = xp.asarray(array_in, device=device_)
if weights is not None:
weights = numpy.asarray(weights, dtype=dtype_name)
weights = xp.asarray(weights, device=device)
weights = xp.asarray(weights, device=device_)

with config_context(array_api_dispatch=True):
result = _average(array_in, axis=axis, weights=weights, normalize=normalize)

assert getattr(array_in, "device", None) == getattr(result, "device", None)
if np_version < parse_version("2.0.0") or np_version >= parse_version("2.1.0"):
# NumPy 2.0 has a problem with the device attribute of scalar arrays:
# https://github.com/numpy/numpy/issues/26850
assert device(array_in) == device(result)

result = _convert_to_numpy(result, xp)
assert_allclose(result, expected, atol=_atol_for_type(dtype_name))
Expand Down Expand Up @@ -226,14 +234,15 @@ def test_average_raises_with_wrong_dtype(array_namespace, device, dtype_name):
(
0,
[[1, 2]],
TypeError,
"1D weights expected",
# NumPy 2 raises ValueError, NumPy 1 raises TypeError
(ValueError, TypeError),
"weights", # the message is different for NumPy 1 and 2...
),
(
0,
[1, 2, 3, 4],
ValueError,
"Length of weights",
"weights",
),
(0, [-1, 1], ZeroDivisionError, "Weights sum to zero, can't be normalized"),
),
Expand Down Expand Up @@ -580,18 +589,18 @@ def test_get_namespace_and_device():


@pytest.mark.parametrize(
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
"array_namespace, device_, dtype_name", yield_namespace_device_dtype_combinations()
)
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
@pytest.mark.parametrize("axis", [0, 1, None, -1, -2])
@pytest.mark.parametrize("sample_weight_type", [None, "int", "float"])
def test_count_nonzero(
array_namespace, device, dtype_name, csr_container, axis, sample_weight_type
array_namespace, device_, dtype_name, csr_container, axis, sample_weight_type
):

from sklearn.utils.sparsefuncs import count_nonzero as sparse_count_nonzero

xp = _array_api_for_tests(array_namespace, device)
xp = _array_api_for_tests(array_namespace, device_)
array = numpy.array([[0, 3, 0], [2, -1, 0], [0, 0, 0], [9, 8, 7], [4, 0, 5]])
if sample_weight_type == "int":
sample_weight = numpy.asarray([1, 2, 2, 3, 1])
Expand All @@ -602,12 +611,16 @@ def test_count_nonzero(
expected = sparse_count_nonzero(
csr_container(array), axis=axis, sample_weight=sample_weight
)
array_xp = xp.asarray(array, device=device)
array_xp = xp.asarray(array, device=device_)

with config_context(array_api_dispatch=True):
result = _count_nonzero(
array_xp, xp=xp, device=device, axis=axis, sample_weight=sample_weight
array_xp, xp=xp, device=device_, axis=axis, sample_weight=sample_weight
)

assert_allclose(_convert_to_numpy(result, xp=xp), expected)
assert getattr(array_xp, "device", None) == getattr(result, "device", None)

if np_version < parse_version("2.0.0") or np_version >= parse_version("2.1.0"):
# NumPy 2.0 has a problem with the device attribute of scalar arrays:
# https://github.com/numpy/numpy/issues/26850
assert device(array_xp) == device(result)
0