8000 BUG: Allow ``None`` as ``api_version`` in ``__array_namespace__`` method by mtsokol · Pull Request #25595 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

BUG: Allow None as api_version in __array_namespace__ method #25595

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
FIX: Allow None as api_version in __array_namespace__
  • Loading branch information
mtsokol committed Jan 18, 2024
commit 5275d9b1adf4aa215c50f29a5ec47d90e09f1d91
2 changes: 1 addition & 1 deletion numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2548,7 +2548,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
def __dlpack__(self: NDArray[number[Any]], *, stream: None = ...) -> _PyCapsule: ...
def __dlpack_device__(self) -> tuple[int, L[0]]: ...

def __array_namespace__(self, *, api_version: str = ...) -> Any: ...
def __array_namespace__(self, *, api_version: str | None = ...) -> Any: ...

def bitwise_count(
self,
Expand Down
25 changes: 17 additions & 8 deletions numpy/_core/src/multiarray/methods.c
Original file line number Diff line number Diff line change
Expand Up @@ -2760,19 +2760,28 @@ static PyObject *
array_array_namespace(PyArrayObject *self, PyObject *args, PyObject *kwds)
{
static char *kwlist[] = {"api_version", NULL};
char *array_api_version = "2022.12";
PyObject *array_api_version = Py_None;

if (!PyArg_ParseTupleAndKeywords(args, kwds, "|$s:__array_namespace__", kwlist,
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|$O:__array_namespace__", kwlist,
&array_api_version)) {
return NULL;
}

if (strcmp(array_api_version, "2021.12") != 0 &&
strcmp(array_api_version, "2022.12") != 0) {
PyErr_Format(PyExc_ValueError,
"Version \"%s\" of the Array API Standard is not supported.",
array_api_version);
return NULL;
if (array_api_version != Py_None) {
if (!PyUnicode_Check(array_api_version))
{
PyErr_Format(PyExc_ValueError,
"Only None and strings are allowed as the Array API version, "
"but received: %S.", array_api_version);
return NULL;
} else if (PyUnicode_CompareWithASCIIString(array_api_version, "2021.12") != 0 &&
PyUnicode_CompareWithASCIIString(array_api_version, "2022.12& 8000 quot;) != 0)
{
PyErr_Format(PyExc_ValueError,
"Version \"%U\" of the Array API Standard is not supported.",
array_api_version);
return NULL;
}
}

PyObject *numpy_module = PyImport_ImportModule("numpy");
Expand Down
8 changes: 8 additions & 0 deletions numpy/_core/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2567,6 +2567,8 @@ def test__array_namespace__(self):
assert xp is np
xp = arr.__array_namespace__(api_version="2022.12")
assert xp is np
xp = arr.__array_namespace__(api_version=None)
assert xp is np

with pytest.raises(
ValueError,
Expand All @@ -2575,6 +2577,12 @@ def test__array_namespace__(self):
):
arr.__array_namespace__(api_version="2023.12")

with pytest.raises(
ValueError,
match="Only None and strings are allowed as the Array API version"
):
arr.__array_namespace__(api_version=2023)

def test_isin_refcnt_bug(self):
# gh-25295
for _ in range(1000):
Expand Down
0