-
-
Notifications
You must be signed in to change notification settings - Fork 10.9k
ENH: Implementation of the NEP 47 (adopting the array API standard) #18585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
012343d
9934cf3
e00760c
fcff4e1
f36b648
c8efdbb
10427b0
d965102
9578636
e521b16
ba4e21c
4bd5d15
a78d20a
00dda8d
df698f8
be1b193
5df8ec9
ad19f7f
1efd55e
affc5f0
f2ac67e
d9438ad
e233a0a
c791956
853a18d
a42f71a
cd70920
3b9c910
6e36bfc
061fecb
587613f
892b536
b7856e3
2ff635c
b933ebb
73d2c1e
16030e4
d40985c
f1f9dca
63be085
7132764
58c2a99
cdd6bbc
1ccbe68
be45fa1
0ac7de9
9e50716
6c17d4b
48a2d8c
8671a36
da5dc95
e42ae01
ed05662
8968bc3
479c8a2
7ce435c
9fe4fc7
fb5c697
b75a135
d40d2bc
844fcd3
9af1cc6
6c196f5
b0b2539
6115cce
edf68c5
2199687
533d046
4817784
96f40fe
be1ee6c
f6015d2
cad21e9
0178080
1379623
fc1ff6f
aee3a56
4240314
5780a9b
29b7a69
74478e2
60add4a
6379138
5febef5
c5999e2
29974fb
8ca96b2
639aa7c
6765494
5217236
2bdc5c3
49bd660
b58fbd2
c558013
6855a8a
6f98f9e
48df7af
3cab20e
185d06d
d9b9582
56345ff
7c5380d
687e2a3
a566cd1
f20be6a
e34c097
f74b359
4fd028d
9d5d0ec
63a9a87
a16d763
776b117
6265676
1e835f9
64bb971
deaf0bf
e7f6dfe
e4b7205
65ed981
5882962
8680a12
1823e7e
d93aad2
09a4f8c
3b91f47
1596415
6e57d82
ee852b4
7e6a026
c23abdc
5605d68
bc20d33
6789a74
310929d
3730fc0
d74a7d0
fcdadee
b6f71c8
5c7074f
2fe8643
4063752
1ae8084
f13f08f
21923a5
8f7d00e
d5956c1
90537b5
22cb4f3
9978cc5
f12fa6f
06ec0ec
7091e4c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
…espace The private function _validate_indices describes the cases that are disallowed. This functionality should be tested (it isn't yet), as the array API test suite will only test the cases that are allowed, not that non-required cases are rejected.
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,9 +15,11 @@ | |
|
||
from __future__ import annotations | ||
|
||
import operator | ||
from enum import IntEnum | ||
from ._types import Optional, PyCapsule, Tuple, Union, array | ||
from ._creation_functions import asarray | ||
from ._dtypes import _boolean_dtypes, _integer_dtypes | ||
|
||
import numpy as np | ||
|
||
|
@@ -140,10 +142,120 @@ def __ge__(x1: array, x2: array, /) -> array: | |
res = x1._array.__ge__(asarray(x2)._array) | ||
return x1.__class__._new(res) | ||
|
||
# Note: A large fraction of allowed indices are disallowed here (see the | ||
# docstring below) | ||
@staticmethod | ||
def _validate_index(key, shape): | ||
""" | ||
Validate an index according to the array API. | ||
|
||
The array API specification only requires a subset of indices that are | ||
supported by NumPy. This function will reject any index that is | ||
allowed by NumPy but not required by the array API specification. We | ||
always raise ``IndexError`` on such indices (the spec does not require | ||
any specific behavior on them, but this makes the NumPy array API | ||
namespace a minimal implementation of the spec). | ||
|
||
This function either raises IndexError if the index ``key`` is | ||
invalid, or a new key to be used in place of ``key`` in indexing. It | ||
only raises ``IndexError`` on indices that are not already rejected by | ||
NumPy, as NumPy will already raise the appropriate error on such | ||
indices. ``shape`` may be None, in which case, only cases that are | ||
independent of the array shape are checked. | ||
|
||
The following cases are allowed by NumPy, but not specified by the array | ||
API specification: | ||
|
||
- The start and stop of a slice may not be out of bounds. In | ||
particular, for a slice ``i:j:k`` on an axis of size ``n``, only the | ||
following are allowed: | ||
|
||
- ``i`` or ``j`` omitted (``None``). | ||
- ``-n <= i <= max(0, n - 1)``. | ||
- For ``k > 0`` or ``k`` omitted (``None``), ``-n <= j <= n``. | ||
- For ``k < 0``, ``-n - 1 <= j <= max(0, n - 1)``. | ||
|
||
- Boolean array indices are not allowed as part of a larger tuple | ||
index. | ||
|
||
- Integer array indices are not allowed (with the exception of shape | ||
() arrays, which are treated the same as scalars). | ||
|
||
Additionally, it should be noted that indices that would return a | ||
scalar in NumPy will return a shape () array. Array scalars are not allowed | ||
in the specification, only shape () arrays. This is done in the | ||
``ndarray._new`` constructor, not this function. | ||
|
||
""" | ||
if isinstance(key, slice): | ||
if shape is None: | ||
return key | ||
if shape == (): | ||
return key | ||
size = shape[0] | ||
# Ensure invalid slice entries are passed through. | ||
if key.start is not None: | ||
try: | ||
operator.index(key.start) | ||
except TypeError: | ||
return key | ||
if not (-size <= key.start <= max(0, size - 1)): | ||
raise IndexError("Slices with out-of-bounds start are not allowed in the array API namespace") | ||
if key.stop is not None: | ||
try: | ||
operator.index(key.stop) | ||
except TypeError: | ||
return key | ||
step = 1 if key.step is None else key.step | ||
if (step > 0 and not (-size <= key.stop <= size) | ||
or step < 0 and not (-size - 1 <= key.stop <= max(0, size - 1))): | ||
raise IndexError("Slices with out-of-bounds stop are not allowed in the array API namespace") | ||
return key | ||
|
||
elif isinstance(key, tuple): | ||
key = tuple(ndarray._validate_index(idx, None) for idx in key) | ||
|
||
for idx in key: | ||
if isinstance(idx, np.ndarray) and idx.dtype in _boolean_dtypes or isinstance(idx, (bool, np.bool_)): | ||
rgommers marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if len(key) == 1: | ||
return key | ||
raise IndexError("Boolean array indices combined with other indices are not allowed in the array API namespace") | ||
|
||
if shape is None: | ||
return key | ||
n_ellipsis = key.count(...) | ||
if n_ellipsis > 1: | ||
return key | ||
ellipsis_i = key.index(...) if n_ellipsis else len(key) | ||
|
||
for idx, size in list(zip(key[:ellipsis_i], shape)) + list(zip(key[:ellipsis_i:-1], shape[:ellipsis_i:-1])): | ||
ndarray._validate_index(idx, (size,)) | ||
return key | ||
6D40
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This feels overly complicated to me. I think you should check for As far as I can tell, that also removes the need for I realize that this is not a specific problem, since you disallow any arrays at this point, but it would avoid this type of thing:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I want to get some actual tests written for this, and once I do, I'll see if it can be simplified while still passing the tests. |
||
elif isinstance(key, bool): | ||
return key | ||
elif isinstance(key, ndarray): | ||
if key.dtype in _integer_dtypes: | ||
if key.shape != (): | ||
raise IndexError("Integer array indices with shape != () are not allowed in the array API namespace") | ||
return key._array | ||
elif key is Ellipsis: | ||
return key | ||
elif key is None: | ||
raise IndexError("newaxis indices are not allowed in the array API namespace") | ||
try: | ||
return operator.index(key) | ||
except TypeError: | ||
# Note: This also omits boolean arrays that are not already in | ||
# ndarray() form, like a list of booleans. | ||
raise IndexError("Only integers, slices (`:`), ellipsis (`...`), and boolean arrays are valid indices in the array API namespace") | ||
|
||
def __getitem__(x: array, key: Union[int, slice, Tuple[Union[int, slice], ...], array], /) -> array: | ||
""" | ||
Performs the operation __getitem__. | ||
""" | ||
# Note: Only indices required by the spec are allowed. See the | ||
# docstring of _validate_index | ||
key = x._validate_index(key, x.shape) | ||
res = x._array.__getitem__(key) | ||
return x.__class__._new(res) | ||
|
||
|
@@ -266,6 +378,9 @@ def __setitem__(x, key, value, /): | |
""" | ||
Performs the operation __setitem__. | ||
""" | ||
# Note: Only indices required by the spec are allowed. See the | ||
# docstring of _validate_index | ||
key = x._validate_index(key, x.shape) | ||
res = x._array.__setitem__(key, asarray(value)._array) | ||
return x.__class__._new(res) | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.