10000 Merge pull request #19879 from BvB93/cls_getitem · numpy/numpy@ac78192 · GitHub
[go: up one dir, main page]

Skip to content

Commit ac78192

Browse files
authored
Merge pull request #19879 from BvB93/cls_getitem
ENH: Add `__class_getitem__` to `ndarray`, `dtype` and `number`
2 parents 05fcb65 + 8c89fef commit ac78192

File tree

14 files changed

+462
-29
lines changed

14 files changed

+462
-29
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
``ndarray``, ``dtype`` and ``number`` are now runtime-subscriptable
2+
-------------------------------------------------------------------
3+
Mimicking :pep:`585`, the `~numpy.ndarray`, `~numpy.dtype` and `~numpy.number`
4+
classes are now subscriptable for python 3.9 and later.
5+
Consequently, expressions that were previously only allowed in .pyi stub files
6+
or with the help of ``from __future__ import annotations`` are now also legal
7+
during runtime.
8+
9+
.. code-block:: python
10+
11+
>>> import numpy as np
12+
>>> from typing import Any
13+
14+
>>> np.ndarray[Any, np.dtype[np.float64]]
15+
numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]]

doc/source/reference/arrays.dtypes.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,3 +562,10 @@ The following methods implement the pickle protocol:
562562

563563
dtype.__reduce__
564564
dtype.__setstate__
565+
566+
Utility method for typing:
567+
568+
.. autosummary::
569+
:toctree: generated/
570+
571+
dtype.__class_getitem__

doc/source/reference/arrays.ndarray.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,3 +621,10 @@ String representations:
621621

622622
ndarray.__str__
623623
ndarray.__repr__
624+
625+
Utility method for typing:
626+
627+
.. autosummary::
628+
:toctree: generated/
629+
630+
ndarray.__class_getitem__

doc/source/reference/arrays.scalars.rst

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,10 @@ Inexact types
196196
``f16`` prints as ``0.1`` because it is as close to that value as possible,
197197
whereas the other types do not as they have more precision and therefore have
198198
closer values.
199-
199+
200200
Conversely, floating-point scalars of different precisions which approximate
201201
the same decimal value may compare unequal despite printing identically:
202-
202+
203203
>>> f16 = np.float16("0.1")
204204
>>> f32 = np.float32("0.1")
205205
>>> f64 = np.float64("0.1")
@@ -498,6 +498,13 @@ The exceptions to the above rules are given below:
498498
generic.__setstate__
499499
generic.setflags
500500

501+
Utility method for typing:
502+
503+
.. autosummary::
504+
:toctree: generated/
505+
506+
number.__class_getitem__
507+
501508

502509
Defining new types
503510
==================

numpy/__init__.pyi

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ from abc import abstractmethod
99
from types import TracebackType, MappingProxyType
1010
from contextlib import ContextDecorator
1111

12+
if sys.version_info >= (3, 9):
13+
from types import GenericAlias
14+
1215
from numpy._pytesttester import PytestTester
1316
from numpy.core.multiarray import flagsobj
1417
from numpy.core._internal import _ctypes
@@ -1052,6 +1055,9 @@ class dtype(Generic[_DTypeScalar_co]):
10521055
copy: bool = ...,
10531056
) -> dtype[object_]: ...
10541057

1058+
if sys.version_info >= (3, 9):
1059+
def __class_getitem__(self, item: Any) -> GenericAlias: ...
1060+
10551061
@overload
10561062
def __getitem__(self: dtype[void], key: List[str]) -> dtype[void]: ...
10571063
@overload
@@ -1661,6 +1667,10 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
16611667
strides: None | _ShapeLike = ...,
16621668
order: _OrderKACF = ...,
16631669
) -> _ArraySelf: ...
1670+
1671+
if sys.version_info >= (3, 9):
1672+
def __class_getitem__(self, item: Any) -> GenericAlias: ...
1673+
16641674
@overload
16651675
def __array__(self, dtype: None = ..., /) -> ndarray[Any, _DType_co]: ...
16661676
@overload
@@ -2850,6 +2860,8 @@ class number(generic, Generic[_NBit1]): # type: ignore
28502860
def real(self: _ArraySelf) -> _ArraySelf: ...
28512861
@property
28522862
def imag(self: _ArraySelf) -> _ArraySelf: ...
2863+
if sys.version_info >= (3, 9):
2864+
def __class_getitem__(self, item: Any) -> GenericAlias: ...
28532865
def __int__(self) -> int: ...
28542866
def __float__(self) -> float: ...
28552867
def __complex__(self) -> complex: ...

numpy/core/_add_newdocs.py

Lines changed: 95 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -796,7 +796,7 @@
796796
object : array_like
797797
An array, any object exposing the array interface, an object whose
798798
__array__ method returns an array, or any (nested) sequence.
799-
If object is a scalar, a 0-dimensional array containing object is
799+
If object is a scalar, a 0-dimensional array containing object is
800800
returned.
801801
dtype : data-type, optional
802802
The desired data-type for the array. If not given, then the type will
@@ -2201,8 +2201,8 @@
22012201
empty : Create an array, but leave its allocated memory unchanged (i.e.,
22022202
it contains "garbage").
22032203
dtype : Create a data-type.
2204-
numpy.typing.NDArray : A :term:`generic <generic type>` version
2205-
of ndarray.
2204+
numpy.typing.NDArray : An ndarray alias :term:`generic <generic type>`
2205+
w.r.t. its `dtype.type <numpy.dtype.type>`.
22062206
22072207
Notes
22082208
-----
@@ -2798,6 +2798,39 @@
27982798
"""))
27992799

28002800

2801+
add_newdoc('numpy.core.multiarray', 'ndarray', ('__class_getitem__',
2802+
"""a.__class_getitem__(item, /)
2803+
2804+
Return a parametrized wrapper around the `~numpy.ndarray` type.
2805+
2806+
.. versionadded:: 1.22
2807+
2808+
Returns
2809+
-------
2810+
alias : types.GenericAlias
2811+
A parametrized `~numpy.ndarray` type.
2812+
2813+
Examples
2814+
--------
2815+
>>> from typing import Any
2816+
>>> import numpy as np
2817+
2818+
>>> np.ndarray[Any, np.dtype[Any]]
2819+
numpy.ndarray[typing.Any, numpy.dtype[Any]]
2820+
2821+
Note
2822+
----
2823+
This method is only available for python 3.9 and later.
2824+
2825+
See Also
2826+
--------
2827+
:pep:`585` : Type hinting generics in standard collections.
2828+
numpy.typing.NDArray : An ndarray alias :term:`generic <generic type>`
2829+
w.r.t. its `dtype.type <numpy.dtype.type>`.
2830+
2831+
"""))
2832+
2833+
28012834
add_newdoc('numpy.core.multiarray', 'ndarray', ('__deepcopy__',
28022835
"""a.__deepcopy__(memo, /) -> Deep copy of array.
28032836
@@ -6044,6 +6077,35 @@
60446077
60456078
"""))
60466079

6080+
add_newdoc('numpy.core.multiarray', 'dtype', ('__class_getitem__',
6081+
"""
6082+
__class_getitem__(item, /)
6083+
6084+
Return a parametrized wrapper around the `~numpy.dtype` type.
6085+
6086+
.. versionadded:: 1.22
6087+
6088+
Returns
6089+
-------
6090+
alias : types.GenericAlias
6091+
A parametrized `~numpy.dtype` type.
6092+
6093+
Examples
6094+
--------
6095+
>>> import numpy as np
6096+
6097+
>>> np.dtype[np.int64]
6098+
numpy.dtype[numpy.int64]
6099+
6100+
Note
6101+
----
6102+
This method is only available for python 3.9 and later.
6103+
6104+
See Also
6105+
--------
6106+
:pep:`585` : Type hinting generics in standard collections.
6107+
6108+
"""))
60476109

60486110
##############################################################################
60496111
#
@@ -6465,6 +6527,36 @@ def refer_to_array_attribute(attr, method=True):
64656527
add_newdoc('numpy.core.numerictypes', 'generic',
64666528
refer_to_array_attribute('view'))
64676529

6530+
add_newdoc('numpy.core.numerictypes', 'number', ('__class_getitem__',
6531+
"""
6532+
__class_getitem__(item, /)
6533+
6534+
Return a parametrized wrapper around the `~numpy.number` type.
6535+
6536+
.. versionadded:: 1.22
6537+
6538+
Returns
6539+
-------
6540+
alias : types.GenericAlias
6541+
A parametrized `~numpy.number` type.
6542+
6543+
Examples
6544+
--------
6545+
>>> from typing import Any
6546+
>>> import numpy as np
6547+
6548+
>>> np.signedinteger[Any]
6549+
numpy.signedinteger[typing.Any]
6550+
6551+
Note
6552+
----
6553+
This method is only available for python 3.9 and later.
6554+
6555+
See Also
6556+
--------
6557+
:pep:`585` : Type hinting generics in standard collections.
6558+
6559+
"""))
64686560

64696561
##############################################################################
64706562
#

numpy/core/src/multiarray/descriptor.c

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ static PyArray_Descr *
257257
_convert_from_tuple(PyObject *obj, int align)
258258
{
259259
if (PyTuple_GET_SIZE(obj) != 2) {
260-
PyErr_Format(PyExc_TypeError,
260+
PyErr_Format(PyExc_TypeError,
261261
"Tuple must have size 2, but has size %zd",
262262
PyTuple_GET_SIZE(obj));
263263
return NULL;
@@ -449,8 +449,8 @@ _convert_from_array_descr(PyObject *obj, int align)
449449
for (int i = 0; i < n; i++) {
450450
PyObject *item = PyList_GET_ITEM(obj, i);
451451
if (!PyTuple_Check(item) || (PyTuple_GET_SIZE(item) < 2)) {
452-
PyErr_Format(PyExc_TypeError,
453-
"Field elements must be 2- or 3-tuples, got '%R'",
452+
PyErr_Format(PyExc_TypeError,
453+
"Field elements must be 2- or 3-tuples, got '%R'",
454454
item);
455455
goto fail;
456456
}
@@ -461,7 +461,7 @@ _convert_from_array_descr(PyObject *obj, int align)
461461
}
462462
else if (PyTuple_Check(name)) {
463463
if (PyTuple_GET_SIZE(name) != 2) {
464-
PyErr_Format(PyExc_TypeError,
464+
PyErr_Format(PyExc_TypeError,
465465
"If a tuple, the first element of a field tuple must have "
466466
"two elements, not %zd",
467467
PyTuple_GET_SIZE(name));
@@ -475,7 +475,7 @@ _convert_from_array_descr(PyObject *obj, int align)
475475
}
476476
}
477477
else {
478-
PyErr_SetString(PyExc_TypeError,
478+
PyErr_SetString(PyExc_TypeError,
479479
"First element of field tuple is "
480480
"neither a tuple nor str");
481481
goto fail;
@@ -3101,6 +3101,30 @@ arraydescr_newbyteorder(PyArray_Descr *self, PyObject *args)
31013101
return (PyObject *)PyArray_DescrNewByteorder(self, endian);
31023102
}
31033103

3104+
static PyObject *
3105+
arraydescr_class_getitem(PyObject *cls, PyObject *args)
3106+
{
3107+
PyObject *generic_alias;
3108+
3109+
#ifdef Py_GENERICALIASOBJECT_H
3110+
Py_ssize_t args_len;
3111+
3112+
args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1;
3113+
if (args_len != 1) {
3114+
return PyErr_Format(PyExc_TypeError,
3115+
"Too %s arguments for %s",
3116+
args_len > 1 ? "many" : "few",
3117+
((PyTypeObject *)cls)->tp_name);
3118+
}
3119+
generic_alias = Py_GenericAlias(cls, args);
3120+
#else
3121+
PyErr_SetString(PyExc_TypeError,
3122+
"Type subscription requires python >= 3.9");
3123+
generic_alias = NULL;
3124+
#endif
3125+
return generic_alias;
3126+
}
3127+
31043128
static PyMethodDef arraydescr_methods[] = {
31053129
/* for pickling */
31063130
{"__reduce__",
@@ -3112,6 +3136,10 @@ static PyMethodDef arraydescr_methods[] = {
31123136
{"newbyteorder",
31133137
(PyCFunction)arraydescr_newbyteorder,
31143138
METH_VARARGS, NULL},
3139+
/* for typing; requires python >= 3.9 */
3140+
{"__class_getitem__",
3141+
(PyCFunction)arraydescr_class_getitem,
3142+
METH_CLASS | METH_O, NULL},
31153143
{NULL, NULL, 0, NULL} /* sentinel */
31163144
};
31173145

numpy/core/src/multiarray/methods.c

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2699,6 +2699,30 @@ array_complex(PyArrayObject *self, PyObject *NPY_UNUSED(args))
26992699
return c;
27002700
}
27012701

2702+
static PyObject *
2703+
array_class_getitem(PyObject *cls, PyObject *args)
2704+
{
2705+
PyObject *generic_alias;
2706+
2707+
#ifdef Py_GENERICALIASOBJECT_H
2708+
Py_ssize_t args_len;
2709+
2710+
args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1;
2711+
if (args_len != 2) {
2712+
return PyErr_Format(PyExc_TypeError,
2713+
"Too %s arguments for %s",
2714+
args_len > 2 ? "many" : "few",
2715+
((PyTypeObject *)cls)->tp_name);
2716+
}
2717+
generic_alias = Py_GenericAlias(cls, args);
2718+
#else
2719+
PyErr_SetString(PyExc_TypeError,
2720+
"Type subscription requires python >= 3.9");
2721+
generic_alias = NULL;
2722+
#endif
2723+
return generic_alias;
2724+
}
2725+
27022726
NPY_NO_EXPORT PyMethodDef array_methods[] = {
27032727

27042728
/* for subtypes */
@@ -2756,6 +2780,11 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
27562780
(PyCFunction) array_format,
27572781
METH_VARARGS, NULL},
27582782

2783+
/* for typing; requires python >= 3.9 */
2784+
{"__class_getitem__",
2785+
(PyCFunction)array_class_getitem,
2786+
METH_CLASS | METH_O, NULL},
2787+
27592788
/* Original and Extended methods added 2005 */
27602789
{"all",
27612790
(PyCFunction)array_all,

0 commit comments

Comments
 (0)
0