8000 MNT Update array-api-compat to 1.12 (#31388) · scikit-learn/scikit-learn@9b40cbc · GitHub
[go: up one dir, main page]

Skip to content

Commit 9b40cbc

Browse files
authored
MNT Update array-api-compat to 1.12 (#31388)
1 parent ff6bf36 commit 9b40cbc

31 files changed

+1823
-1103
lines changed

maint_tools/vendor_array_api_compat.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ set -o nounset
66
set -o errexit
77

88
URL="https://github.com/data-apis/array-api-compat.git"
9-
VERSION="1.11.2"
9+
VERSION="1.12"
1010

1111
ROOT_DIR=sklearn/externals/array_api_compat
1212

sklearn/externals/array_api_compat/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@
1717
this implementation for the default when working with NumPy arrays.
1818
1919
"""
20-
__version__ = '1.11.2'
20+
__version__ = '1.12.0'
2121

2222
from .common import * # noqa: F401, F403

sklearn/externals/array_api_compat/_internal.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,16 @@
22
Internal helpers
33
"""
44

5+
from collections.abc import Callable
56
from functools import wraps
67
from inspect import signature
8+
from types import ModuleType
9+
from typing import TypeVar
710

8-
def get_xp(xp):
11+
_T = TypeVar("_T")
12+
13+
14+
def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
915
"""
1016
Decorator to automatically replace xp with the corresponding array module.
1117
@@ -22,14 +28,14 @@ def func(x, /, xp, kwarg=None):
2228
2329
"""
2430

25-
def inner(f):
31+
def inner(f: Callable[..., _T], /) -> Callable[..., _T]:
2632
@wraps(f)
27-
def wrapped_f(*args, **kwargs):
33+
def wrapped_f(*args: object, **kwargs: object) -> object:
2834
return f(*args, xp=xp, **kwargs)
2935

3036
sig = signature(f)
3137
new_sig = sig.replace(
32-
parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"]
38+
parameters=[par for i, par in sig.parameters.items() if i != "xp"]
3339
)
3440

3541
if wrapped_f.__doc__ is None:
@@ -40,7 +46,14 @@ def wrapped_f(*args, **kwargs):
4046
specification for more details.
4147
4248
"""
43-
wrapped_f.__signature__ = new_sig
44-
return wrapped_f
49+
wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue]
50+
return wrapped_f # pyright: ignore[reportReturnType]
4551

4652
return inner
53+
54+
55+
__all__ = ["get_xp"]
56+
57+
58+
def __dir__() -> list[str]:
59+
return __all__
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from ._helpers import * # noqa: F403
1+
from ._helpers import * # noqa: F403

0 commit comments

Comments
 (0)
0