8000 Merge pull request #25595 from mtsokol/fix-array-namespace-none · numpy/numpy@5feea41 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5feea41

Browse files
authored
Merge pull request #25595 from mtsokol/fix-array-namespace-none
BUG: Allow ``None`` as ``api_version`` in ``__array_namespace__`` method
2 parents 174ac7b + 5275d9b commit 5feea41

File tree

3 files changed

+26
-9
lines changed

3 files changed

+26
-9
lines changed

numpy/__init__.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2548,7 +2548,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
25482548
def __dlpack__(self: NDArray[number[Any]], *, stream: None = ...) -> _PyCapsule: ...
25492549
def __dlpack_device__(self) -> tuple[int, L[0]]: ...
25502550

2551-
def __array_namespace__(self, *, api_version: str = ...) -> Any: ...
2551+
def __array_namespace__(self, *, api_version: str | None = ...) -> Any: ...
25522552

25532553
def bitwise_count(
25542554
self,

numpy/_core/src/multiarray/methods.c

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2760,19 +2760,28 @@ static PyObject *
27602760
array_array_namespace(PyArrayObject *self, PyObject *args, PyObject *kwds)
27612761
{
27622762
static char *kwlist[] = {"api_version", NULL};
2763-
char *array_api_version = "2022.12";
2763+
PyObject *array_api_version = Py_None;
27642764

2765-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|$s:__array_namespace__", kwlist,
2765+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|$O:__array_namespace__", kwlist,
27662766
&array_api_version)) {
27672767
return NULL;
27682768
}
27692769

2770-
if (strcmp(array_api_version, "2021.12") != 0 &&
2771-
strcmp(array_api_version, "2022.12") != 0) {
2772-
PyErr_Format(PyExc_ValueError,
2773-
"Version \"%s\" of the Array API Standard is not supported.",
2774-
array_api_version);
2775-
return NULL;
2770+
if (array_api_version != Py_None) {
2771+
if (!PyUnicode_Check(array_api_version))
2772+
{
2773+
PyErr_Format(PyExc_ValueError,
2774+
"Only None and strings are allowed as the Array API version, "
2775+
"but received: %S.", array_api_version);
2776+
return NULL;
2777+
} else if (PyUnicode_CompareWithASCIIString(array_api_version, "2021.12") != 0 &&
2778+
PyUnicode_CompareWithASCIIString(array_api_version, "2022.12") != 0)
2779+
{
2780+
PyErr_Format(PyExc_ValueError,
2781+
"Version \"%U\" of the Array API Standard is not supported.",
2782+
array_api_version);
2783+
return NULL;
2784+
}
27762785
}
27772786

27782787
PyObject *numpy_module = PyImport_ImportModule("numpy");

numpy/_core/tests/test_regression.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2567,6 +2567,8 @@ def test__array_namespace__(self):
25672567
assert xp is np
25682568
xp = arr.__array_namespace__(api_version="2022.12")
25692569
assert xp is np
2570+
xp = arr.__array_namespace__(api_version=None)
2571+
assert xp is np
25702572

25712573
with pytest.raises(
25722574
ValueError,
@@ -2575,6 +2577,12 @@ def test__array_namespace__(self):
25752577
):
25762578
arr.__array_namespace__(api_version="2023.12")
25772579

2580+
with pytest.raises(
2581+
ValueError,
2582+
match="Only None and strings are allowed as the Array API version"
2583+
):
2584+
arr.__array_namespace__(api_version=2023)
2585+
25782586
def test_isin_refcnt_bug(self):
25792587
# gh-25295
25802588
for _ in range(1000):

0 commit comments

Comments
 (0)
0