8000 array API support for mean_poisson_deviance (#29227) · scikit-learn/scikit-learn@68b71b5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 68b71b5

Browse files
EmilyXinyiogriselOmarManzoorlesteve
authored
array API support for mean_poisson_deviance (#29227)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Omar Salman <omar.salman@arbisoft.com> Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
1 parent 8133eca commit 68b71b5

File tree

7 files changed

+41
-3
lines changed

7 files changed

+41
-3
lines changed

.github/workflows/cuda-ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,4 @@ jobs:
4242
run: |
4343
source "${HOME}/conda/etc/profile.d/conda.sh"
4444
conda activate sklearn
45-
pytest -k 'array_api'
45+
SCIPY_ARRAY_API=1 pytest -k 'array_api'

azure-pipelines.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ jobs:
138138
# Here we make sure, that they are still run on a regular basis.
139139
${{ if eq(variables['Build.Reason'], 'Schedule') }}:
140140
SKLEARN_SKIP_NETWORK_TESTS: '0'
141+
SCIPY_ARRAY_API: '1'
141142

142143
# Check compilation with Ubuntu 22.04 LTS (Jammy Jellyfish) and scipy from conda-forge
143144
# By default the CI is sequential, where `Ubuntu_Jammy_Jellyfish` runs first and
@@ -221,6 +222,7 @@ jobs:
221222
# makes sure that they are single threaded in each xdist subprocess.
222223
PYTEST_XDIST_VERSION: 'none'
223224
PIP_BUILD_ISOLATION: 'true'
225+
SCIPY_ARRAY_API: '1'
224226

225227
- template: build_tools/azure/posix-docker.yml
226228
parameters:
@@ -259,6 +261,7 @@ jobs:
259261
DISTRIB: 'conda'
260262
LOCK_FILE: './build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock'
261263
SKLEARN_TESTS_GLOBAL_RANDOM_SEED: '5' # non-default seed
264+
SCIPY_ARRAY_API: '1'
262265
pylatest_conda_mkl_no_openmp:
263266
DISTRIB: 'conda'
264267
LOCK_FILE: './build_tools/azure/pylatest_conda_mkl_no_openmp_osx-64_conda.lock'

doc/modules/array_api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ Metrics
119119
- :func:`sklearn.metrics.mean_absolute_error`
120120
- :func:`sklearn.metrics.mean_absolute_percentage_error`
121121
- :func:`sklearn.metrics.mean_gamma_deviance`
122+
- :func:`sklearn.metrics.mean_poisson_deviance` (requires `enabling array API support for SciPy <https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html#using-array-api-standard-support>`_)
122123
- :func:`sklearn.metrics.mean_squared_error`
123124
- :func:`sklearn.metrics.mean_tweedie_deviance`
124125
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel`

doc/whats_new/v1.6.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ See :ref:`array_api` for more details.
3939
and :pr:`29143` by :user:`Tialo <Tialo>` and :user:`Loïc Estève <lesteve>`;
4040
- :func:`sklearn.metrics.mean_absolute_percentage_error` :pr:`29300` by :user:`Emily Chen <EmilyXinyi>`;
4141
- :func:`sklearn.metrics.mean_gamma_deviance` :pr:`29239` by :user:`Emily Chen <EmilyXinyi>`;
42+
- :func:`sklearn.metrics.mean_poisson_deviance` :pr:`29227` by :user:`Emily Chen <EmilyXinyi>`;
4243
- :func:`sklearn.metrics.mean_squared_error` :pr:`29142` by :user:`Yaroslav Korobko <Tialo>`;
4344
- :func:`sklearn.metrics.mean_tweedie_deviance` :pr:`28106` by :user:`Thomas Li <lithomas1>`;
4445
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel` :pr:`29144` by :user:`Yaroslav Korobko <Tialo>`;

sklearn/metrics/tests/test_common.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
assert_array_less,
7777
ignore_warnings,
7878
)
79-
from sklearn.utils.fixes import COO_CONTAINERS
79+
from sklearn.utils.fixes import COO_CONTAINERS, parse_version, sp_version
8080
from sklearn.utils.multiclass import type_of_target
8181
from sklearn.utils.validation import _num_samples, check_random_state
8282

@@ -1867,6 +1867,12 @@ def check_array_api_multilabel_classification_metric(
18671867

18681868

18691869
def check_array_api_regression_metric(metric, array_namespace, device, dtype_name):
1870+
func_name = metric.func.__name__ if isinstance(metric, partial) else metric.__name__
1871+
if func_name == "mean_poisson_deviance" and sp_version < parse_version("1.14.0"):
1872+
pytest.skip(
1873+
"mean_poisson_deviance's dependency `xlogy` is available as of scipy 1.14.0"
1874+
)
1875+
18701876
y_true_np = np.array([2.0, 0.1, 1.0, 4.0], dtype=dtype_name)
18711877
y_pred_np = np.array([0.5, 0.5, 2, 2], dtype=dtype_name)
18721878

@@ -2012,6 +2018,7 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
20122018
check_array_api_regression_metric,
20132019
],
20142020
paired_cosine_distances: [check_array_api_metric_pairwise],
2021+
mean_poisson_deviance: [check_array_api_regression_metric],
20152022
additive_chi2_kernel: [check_array_api_metric_pairwise],
20162023
mean_gamma_deviance: [check_array_api_regression_metric],
20172024
max_error: [check_array_api_regression_metric],

sklearn/utils/_array_api.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import itertools
77
import math
8+
import os
9+
import warnings
810
from functools import wraps
911

1012
import numpy
@@ -106,6 +108,17 @@ def _check_array_api_dispatch(array_api_dispatch):
106108
f"NumPy must be {min_numpy_version} or newer to dispatch array using"
107109
" the API specification"
108110
)
111+
if os.environ.get("SCIPY_ARRAY_API") != "1":
112+
warnings.warn(
113+
(
114+
"Some scikit-learn array API features might rely on enabling "
115+
"SciPy's own support for array API to function properly. "
116+
"Please set the SCIPY_ARRAY_API=1 environment variable "
117+
"before importing sklearn or scipy. More details at: "
118+
"https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html"
119+
),
120+
UserWarning,
121+
)
109122

110123

111124
def _single_array_device(array):

sklearn/utils/tests/test_array_api.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import re
23
from functools import partial
34

@@ -77,7 +78,7 @@ def test_get_namespace_ndarray_with_dispatch():
7778

7879

7980
@skip_if_array_api_compat_not_configured
80-
def test_get_namespace_array_api():
81+
def test_get_namespace_array_api(monkeypatch):
8182
"""Test get_namespace for ArrayAPI arrays."""
8283
xp = pytest.importorskip("array_api_strict")
8384

@@ -90,6 +91,18 @@ def test_get_namespace_array_api():
9091
with pytest.raises(TypeError):
9192
xp_out, is_array_api_compliant = get_namespace(X_xp, X_np)
9293

94+
def mock_getenv(key):
95+
if key == "SCIPY_ARRAY_API":
96+
return "0"
97+
98+
monkeypatch.setattr("os.environ.get", mock_getenv)
99+
assert os.environ.get("SCIPY_ARRAY_API") != "1"
100+
with pytest.warns(
101+
UserWarning,
102+
match="enabling SciPy's own support for array API to function properly. ",
103+
):
104+
xp_out, is_array_api_compliant = get_namespace(X_xp)
105+
93106

94107
class _AdjustableNameAPITestWrapper(_ArrayAPIWrapper):
95108
"""API wrapper that has an adjustable name. Used for testing."""

0 commit comments

Comments
 (0)
0