diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index a317059aa1415..a21c032e5d12a 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -5,11 +5,11 @@ import itertools import math -import os import warnings from functools import wraps import numpy +import scipy._lib._array_api import scipy.special as special from .._config import get_config @@ -86,11 +86,19 @@ def yield_namespace_device_dtype_combinations(include_numpy_namespaces=True): yield array_namespace, None, None -def _check_array_api_dispatch(array_api_dispatch): - """Check that array_api_compat is installed and NumPy version is compatible. +def _check_array_api_dispatch(array_api_dispatch, misconfigured_scipy="warn"): + """Check that required dependencies are installed in new enough versions. - array_api_compat follows NEP29, which has a higher minimum NumPy version than - scikit-learn. + We need the array_api_compat package as well as new enough versions of + NumPy and SciPy. + + Parameters + ---------- + array_api_dispatch : bool + Enable or disable array API checks. + + misconfigured_scipy : str, default="warn" + Warn or raise an exception when misconfigured SciPy is detected. """ if array_api_dispatch: try: @@ -108,17 +116,19 @@ def _check_array_api_dispatch(array_api_dispatch): f"NumPy must be {min_numpy_version} or newer to dispatch array using" " the API specification" ) - if os.environ.get("SCIPY_ARRAY_API") != "1": - warnings.warn( - ( - "Some scikit-learn array API features might rely on enabling " - "SciPy's own support for array API to function properly. " - "Please set the SCIPY_ARRAY_API=1 environment variable " - "before importing sklearn or scipy. More details at: " - "https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html" - ), - UserWarning, + + if not scipy._lib._array_api.SCIPY_ARRAY_API: + message = ( + "Some scikit-learn array API features rely on enabling " + "SciPy's own support for array API to function properly. " + "Please set the SCIPY_ARRAY_API=1 environment variable " + "before importing sklearn or scipy. More details at: " + "https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html" ) + if misconfigured_scipy == "raise": + raise RuntimeError(message) + else: + warnings.warn(message, UserWarning) def _single_array_device(array): diff --git a/sklearn/utils/_testing.py b/sklearn/utils/_testing.py index d75ca9e19cdff..4738a962158f2 100644 --- a/sklearn/utils/_testing.py +++ b/sklearn/utils/_testing.py @@ -987,6 +987,7 @@ def fit_transform(self, X, y=None): def _array_api_for_tests(array_namespace, device): + _check_array_api_dispatch(True, misconfigured_scipy="raise") try: array_mod = importlib.import_module(array_namespace) except ModuleNotFoundError: diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 5e3299781a531..4217ceb84c5f9 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -1,9 +1,9 @@ -import os import re from functools import partial import numpy import pytest +import scipy from numpy.testing import assert_allclose from sklearn._config import config_context @@ -91,12 +91,8 @@ def test_get_namespace_array_api(monkeypatch): with pytest.raises(TypeError): xp_out, is_array_api_compliant = get_namespace(X_xp, X_np) - def mock_getenv(key): - if key == "SCIPY_ARRAY_API": - return "0" - - monkeypatch.setattr("os.environ.get", mock_getenv) - assert os.environ.get("SCIPY_ARRAY_API") != "1" + monkeypatch.setattr("scipy._lib._array_api.SCIPY_ARRAY_API", False) + assert not scipy._lib._array_api.SCIPY_ARRAY_API with pytest.warns( UserWarning, match="enabling SciPy's own support for array API to function properly. ",