8000 Add getitem to array protocol (#8406) · pydata/xarray@0bf38c2 · GitHub
[go: up one dir, main page]

Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 0bf38c2

Browse files
Add getitem to array protocol (#8406)
* Update _typing.py * Update _typing.py * Update test_namedarray.py * fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update _typing.py * Update _typing.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 562f2f8 commit 0bf38c2

File tree

2 files changed

+51
-3
lines changed

2 files changed

+51
-3
lines changed

xarray/namedarray/_typing.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class Default(Enum):
2929
_T = TypeVar("_T")
3030
_T_co = TypeVar("_T_co", covariant=True)
3131

32-
32+
_dtype = np.dtype
3333
_DType = TypeVar("_DType", bound=np.dtype[Any])
3434
_DType_co = TypeVar("_DType_co", covariant=True, bound=np.dtype[Any])
3535
# A subset of `npt.DTypeLike` that can be parametrized w.r.t. `np.generic`
@@ -69,9 +69,16 @@ def dtype(self) -> _DType_co:
6969
_Dims = tuple[_Dim, ...]
7070

7171
_DimsLike = Union[str, Iterable[_Dim]]
72-
_AttrsLike = Union[Mapping[Any, Any], None]
7372

74-
_dtype = np.dtype
73+
# https://data-apis.org/array-api/latest/API_specification/indexing.html
74+
# TODO: np.array_api was bugged and didn't allow (None,), but should!
75+
# https://github.com/numpy/numpy/pull/25022
76+
# https://github.com/data-apis/array-api/pull/674
77+
_IndexKey = Union[int, slice, "ellipsis"]
78+
_IndexKeys = tuple[Union[_IndexKey], ...] # tuple[Union[_IndexKey, None], ...]
79+
_IndexKeyLike = Union[_IndexKey, _IndexKeys]
< 10000 /code>
80+
81+
_AttrsLike = Union[Mapping[Any, Any], None]
7582

7683

7784
class _SupportsReal(Protocol[_T_co]):
@@ -113,6 +120,25 @@ class _arrayfunction(
113120
Corresponds to np.ndarray.
114121
"""
115122

123+
@overload
124+
def __getitem__(
125+
self, key: _arrayfunction[Any, Any] | tuple[_arrayfunction[Any, Any], ...], /
126+
) -> _arrayfunction[Any, _DType_co]:
127+
...
128+
129+
@overload
130+
def __getitem__(self, key: _IndexKeyLike, /) -> Any:
131+
...
132+
133+
def __getitem__(
134+
self,
135+
key: _IndexKeyLike
136+
| _arrayfunction[Any, Any]
137+
| tuple[_arrayfunction[Any, Any], ...],
138+
/,
139+
) -> _arrayfunction[Any, _DType_co] | Any:
140+
...
141+
116142
@overload
117143
def __array__(self, dtype: None = ..., /) -> np.ndarray[Any, _DType_co]:
118144
...
@@ -165,6 +191,14 @@ class _arrayapi(_array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType
165191
Corresponds to np.ndarray.
166192
"""
167193

194+
def __getitem__(
195+
self,
196+
key: _IndexKeyLike
197+
| Any, # TODO: Any should be _arrayapi[Any, _dtype[np.integer]]
198+
/,
199+
) -> _arrayapi[Any, Any]:
200+
...
201+
168202
def __array_namespace__(self) -> ModuleType:
169203
...
170204

xarray/tests/test_namedarray.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
_AttrsLike,
2929
_DimsLike,
3030
_DType,
31+
_IndexKeyLike,
3132
_Shape,
3233
duckarray,
3334
)
@@ -58,6 +59,19 @@ class CustomArrayIndexable(
5859
ExplicitlyIndexed,
5960
Generic[_ShapeType_co, _DType_co],
6061
):
62+
def __getitem__(
63+
self, key: _IndexKeyLike | CustomArrayIndexable[Any, Any], /
64+
) -> CustomArrayIndexable[Any, _DType_co]:
65+
if isinstance(key, CustomArrayIndexable):
66+
if isinstance(key.array, type(self.array)):
67+
# TODO: key.array is duckarray here, can it be narrowed down further?
68+
# an _arrayapi cannot be used on a _arrayfunction for example.
69+
return type(self)(array=self.array[key.array]) # type: ignore[index]
70+
else:
71+
raise TypeError("key must have the same array type as self")
72+
else:
73+
return type(self)(array=self.array[key])
74+
6175
def __array_namespace__(self) -> ModuleType:
6276
return np
6377

0 commit comments

Comments
 (0)
0