8000 ENH: Add `__class_getitem__` to `ndarray`, `dtype` and `number` by BvB93 · Pull Request #19879 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

ENH: Add __class_getitem__ to ndarray, dtype and number #19879

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Sep 25, 2021
Merged
Prev Previous commit
ENH: Add special-casing for complexfloating so that it can take 2 p…
…arameters
  • Loading branch information
BvB93 committed Sep 18, 2021
commit 8c89fef9e677afd3ee7777f242b6a53d3b7dfef4
14 changes: 12 additions & 2 deletions numpy/core/src/multiarray/scalartypes.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -1812,12 +1812,22 @@ numbertype_class_getitem_abc(PyObject *cls, PyObject *args)

#ifdef Py_GENERICALIASOBJECT_H
Py_ssize_t args_len;
int args_len_expected;

/* complexfloating should take 2 parameters, all others take 1 */
if (PyType_IsSubtype((PyTypeObject *)cls,
&PyComplexFloatingArrType_Type)) {
args_len_expected = 2;
}
else {
args_len_expected = 1;
}

args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1;
if (args_len != 1) {
if (args_len != args_len_expected) {
return PyErr_Format(PyExc_TypeError,
"Too %s arguments for %s",
args_len > 1 ? "many" : "few",
args_len > args_len_expected ? "many" : "few",
((PyTypeObject *)cls)->tp_name);
}
generic_alias = Py_GenericAlias(cls, args);
Expand Down
8 changes: 6 additions & 2 deletions numpy/core/tests/test_scalar_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,17 @@ class TestClassGetItem:
np.unsignedinteger,
np.signedinteger,
np.floating,
np.complexfloating,
])
def test_abc(self, cls: Type[np.number]) -> None:
alias = cls[Any]
assert isinstance(alias, types.GenericAlias)
assert alias.__origin__ is cls

def test_abc_complexfloating(self) -> None:
alias = np.complexfloating[Any, Any]
assert isinstance(alias, types.GenericAlias)
assert alias.__origin__ is np.complexfloating

@pytest.mark.parametrize("cls", [np.generic, np.flexible, np.character])
def test_abc_non_numeric(self, cls: Type[np.generic]) -> None:
with pytest.raises(TypeError):
Expand All @@ -174,7 +178,7 @@ def test_subscript_scalar(self) -> None:


@pytest.mark.skipif(sys.version_info >= (3, 9), reason="Requires python 3.8")
@pytest.mark.parametrize("cls", [np.number, np.int64])
@pytest.mark.parametrize("cls", [np.number, np.complexfloating, np.int64])
def test_class_getitem_38(cls: Type[np.number]) -> None:
match = "Type subscription requires python >= 3.9"
with pytest.raises(TypeError, match=match):
Expand Down
0