8000 Merge pull request #559 from numpy/rank-to-shape · numpy/numtype@f5adce6 · GitHub
[go: up one dir, main page]

Skip to content

Commit f5adce6

Browse files
authored
Merge pull request #559 from numpy/rank-to-shape
2 parents d7dcd8d + c2bee24 commit f5adce6

23 files changed

< A3DB div class="ml-1 text-small text-bold fgColor-success">+325
-349
lines changed

src/_numtype/@test/generated/test_rank.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# @generated 2025-05-14T01:47:22Z with tool/testgen.py
1+
# @generated 2025-05-19T03:25:34Z with tool/testgen.py
22
from typing import Any
33

44
import _numtype as _nt
@@ -31,7 +31,7 @@ r0n_le_r0n: _nt.HasRankLE[_nt.Rank0N] = r0n
3131

3232
r0_ge_s0: _nt.HasRankGE[_nt.Shape0] = r0
3333
r0_ge_r0: _nt.HasRankGE[_nt.Rank0] = r0
34-
r0_ge_s0n: _nt.HasRankGE[_nt.Shape0N] = r0
34+
r0_ge_s0n: _nt.HasRankGE[_nt.Shape0N] = r0 # type: ignore[assignment]
3535
r0_ge_r0n: _nt.HasRankGE[_nt.Rank0N] = r0
3636
r0n_ge_s0: _nt.HasRankGE[_nt.Shape0] = r0n
3737
r0n_ge_r0: _nt.HasRankGE[_nt.Rank0] = r0n
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from typing import assert_type
2+
3+
import _numtype as _nt
4+
5+
# TODO: remove the `# type: ignore`s once python/mypy#19110 is fixed
6+
7+
a0: _nt.Array0D
8+
assert_type(a0.__inner_shape__, _nt.Rank0)
9+
assert_type(a0.shape, _nt.Shape0) # type: ignore[assert-type]
10+
11+
a1: _nt.Array1D
12+
assert_type(a1.__inner_shape__, _nt.Rank1)
13+
assert_type(a1.shape, _nt.Shape1) # type: ignore[assert-type]
14+
15+
a2: _nt.Array2D
16+
assert_type(a2.__inner_shape__, _nt.Rank2)
17+
assert_type(a2.shape, _nt.Shape2) # type: ignore[assert-type]
18+
19+
a3: _nt.Array3D
20+
assert_type(a3.__inner_shape__, _nt.Rank3)
21+
assert_type(a3.shape, _nt.Shape3) # type: ignore[assert-type]
22+
23+
a4: _nt.Array4D
24+
assert_type(a4.__inner_shape__, _nt.Rank4)
25+
assert_type(a4.shape, _nt.Shape4) # type: ignore[assert-type]

src/_numtype/__init__.pyi

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,11 @@ from ._array import (
1919
Array2D as Array2D,
2020
Array3D as Array3D,
2121
Array4D as Array4D,
22-
ArrayND as ArrayND,
2322
MArray as MArray,
2423
MArray0D as MArray0D,
2524
MArray1D as MArray1D,
2625
MArray2D as MArray2D,
2726
MArray3D as MArray3D,
28-
MArrayND as MArrayND,
2927
Matrix as Matrix,
3028
StringArray as StringArray,
3129
StringArray0D as StringArray0D,
@@ -86,6 +84,7 @@ from ._nep50 import (
8684
CastsWithScalar as CastsWithScalar,
8785
)
8886
from ._rank import (
87+
HasInnerShape as HasInnerShape,
8988
HasRankGE as HasRankGE,
9089
HasRankLE as HasRankLE,
9190
Rank as Rank,
@@ -141,6 +140,7 @@ from ._scalar_co import (
141140
co_ulong as co_ulong,
142141
)
143142
from ._shape import (
143+
AnyShape as AnyShape,
144144
Shape as Shape,
145145
Shape0 as Shape0,
146146
Shape0N as Shape0N,
@@ -170,7 +170,8 @@ _ToT = TypeVar("_ToT")
170170

171171
@type_check_only
172172
class CanArray0D(Protocol[_ScalarT_co]):
173-
def __array__(self, /) -> np.ndarray[Shape0, np.dtype[_ScalarT_co]]: ...
173+
# TODO: remove `| Rank0` once python/mypy#19110 is fixed
174+
def __array__(self, /) -> np.ndarray[Shape0 | Rank0, np.dtype[_ScalarT_co]]: ...
174175

175176
@type_check_only
176177
class CanArray1D(Protocol[_ScalarT_co]):
@@ -186,11 +187,13 @@ class CanArray3D(Protocol[_ScalarT_co]):
186187

187188
@type_check_only
188189
class CanArrayND(Protocol[_ScalarT_co]):
189-
def __array__(self, /) -> np.ndarray[Shape, np.dtype[_ScalarT_co]]: ...
190+
# TODO: remove `| Rank0` once python/mypy#19110 is fixed
191+
def __array__(self, /) -> np.ndarray[Shape | Rank0, np.dtype[_ScalarT_co]]: ...
190192

191193
@type_check_only
192194
class CanLenArrayND(Protocol[_ScalarT_co]):
193195
def __len__(self, /) -> int: ...
196+
# TODO: remove `| Rank0` once python/mypy#19110 is fixed
194197
def __array__(self, /) -> np.ndarray[Shape, np.dtype[_ScalarT_co]]: ...
195198

196199
@type_check_only

src/_numtype/_array.pyi

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ from typing_extensions import TypeAliasType, TypeVar
55

66
import numpy as np
77

8-
from ._rank import Rank, Rank0, Rank1, Rank2, Rank3, Rank4
9-
from ._shape import Shape
8+
from ._rank import Rank0, Rank1, Rank2, Rank3, Rank4
9+
from ._shape import AnyShape, Shape
1010

1111
__all__ = [
1212
"Array",
@@ -15,13 +15,11 @@ __all__ = [
1515
"Array2D",
1616
"Array3D",
1717
"Array4D",
18-
"ArrayND",
1918
"MArray",
2019
"MArray0D",
2120
"MArray1D",
2221
"MArray2D",
2322
"MArray3D",
24-
"MArrayND",
2523
"Matrix",
2624
"StringArray",
2725
"StringArray0D",
@@ -33,7 +31,8 @@ __all__ = [
3331

3432
###
3533

36-
_RankT = TypeVar("_RankT", bound=Shape, default=Shape)
34+
# TODO: use `Shape` instead of `AnyShape` once python/mypy#19110 is fixed
35+
_RankT = TypeVar("_RankT", bound=AnyShape, default=Shape)
3736
_ScalarT = TypeVar("_ScalarT", bound=np.generic, default=Any)
3837
_NaT = TypeVar("_NaT", default=Never)
3938

@@ -45,7 +44,6 @@ Array1D = TypeAliasType("Array1D", np.ndarray[Rank1, np.dtype[_ScalarT]], type_p
4544
Array2D = TypeAliasType("Array2D", np.ndarray[Rank2, np.dtype[_ScalarT]], type_params=(_ScalarT,))
4645
Array3D = TypeAliasType("Array3D", np.ndarray[Rank3, np.dtype[_ScalarT]], type_params=(_ScalarT,))
4746
Array4D = TypeAliasType("Array4D", np.ndarray[Rank4, np.dtype[_ScalarT]], type_params=(_ScalarT,))
48-
ArrayND = TypeAliasType("ArrayND", np.ndarray[Rank, np.dtype[_ScalarT]], type_params=(_ScalarT,))
4947

5048
###
5149

@@ -58,7 +56,6 @@ MArray0D = TypeAliasType("MArray0D", np.ma.MaskedArray[Rank0, np.dtype[_ScalarT]
5856
MArray1D = TypeAliasType("MArray1D", np.ma.MaskedArray[Rank1, np.dtype[_ScalarT]], type_params=(_ScalarT,))
5957
MArray2D = TypeAliasType("MArray2D", np.ma.MaskedArray[Rank2, np.dtype[_ScalarT]], type_params=(_ScalarT,))
6058
MArray3D = TypeAliasType("MArray3D", np.ma.MaskedArray[Rank3, np.dtype[_ScalarT]], type_params=(_ScalarT,))
61-
MArrayND = TypeAliasType("MArrayND", np.ma.MaskedArray[Rank, np.dtype[_ScalarT]], type_params=(_ScalarT, _RankT))
6259

6360
###
6461

@@ -89,6 +86,6 @@ StringArray3D = TypeAliasType(
8986
)
9087
StringArrayND = TypeAliasType(
9188
"StringArrayND",
92-
np.ndarray[Rank, np.dtypes.StringDType[_NaT]],
89+
np.ndarray[Shape, np.dtypes.StringDType[_NaT]],
9390
type_params=(_NaT,),
9491
)

src/_numtype/_nep50.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ _ScalarOutT = TypeVar("_ScalarOutT", bound=_ScalarOut, default=Any)
4040
_ScalarOutT_co = TypeVar("_ScalarOutT_co", bound=_ScalarOut, covariant=True)
4141
_ScalarOutT_contra = TypeVar("_ScalarOutT_contra", bound=_ScalarOut, contravariant=True)
4242

43-
_ShapeT = TypeVar("_ShapeT", bound=_shape.Shape, default=_shape.Shape)
43+
_ShapeT = TypeVar("_ShapeT", bound=_shape.Shape, default=Any)
4444
_ShapeT_co = TypeVar("_ShapeT_co", bound=_shape.Shape, covariant=True)
4545

4646
###

src/_numtype/_rank.pyi

Lines changed: 73 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from typing import Any, Generic, Protocol, Self, TypeAlias, final, type_check_only
2-
from typing_extensions import TypeAliasType, TypeVar
2+
from typing_extensions import TypeAliasType, TypeVar, TypeVarTuple, override
33

4-
from ._shape import Shape, Shape0, Shape0N, Shape1, Shape1N, Shape2, Shape2N, Shape3, Shape3N, Shape4, Shape4N
4+
from ._shape import AnyShape, Shape, Shape0, Shape1, Shape1N, Shape2, Shape2N, Shape3, Shape3N, Shape4, Shape4N
55

66
__all__ = [
7+
"HasInnerShape",
78
"HasRankGE",
89
"HasRankLE",
910
"Rank",
@@ -21,56 +22,71 @@ __all__ = [
2122

2223
###
2324

24-
_Shape00: TypeAlias = Shape0
25-
_Shape01: TypeAlias = _Shape00 | Shape1
25+
_Shape01: TypeAlias = Shape0 | Shape1
2626
_Shape02: TypeAlias = _Shape01 | Shape2
2727
_Shape03: TypeAlias = _Shape02 | Shape3
2828
_Shape04: TypeAlias = _Shape03 | Shape4
2929

3030
###
3131

32-
_UpperT = TypeVar("_UpperT", bound=Shape)
33-
_LowerT = TypeVar("_LowerT", bound=Shape)
32+
# TODO(jorenham): remove `| Rank0 | Rank` once python/mypy#19110 is fixed
33+
_UpperT = TypeVar("_UpperT", bound=Shape | Rank0 | Rank)
34+
_LowerT = TypeVar("_LowerT", bound=Shape | Rank0 | Rank)
3435
_RankT = TypeVar("_RankT", bound=Shape, default=Any)
3536

37+
# TODO(jorenham): remove `| Rank0 | Rank` once python/mypy#19110 is fixed
38+
_RankLE: TypeAlias = _CanBroadcast[Any, _UpperT, _RankT] | Shape0 | Rank0 | Rank
39+
# TODO(jorenham): remove `| Rank` once python/mypy#19110 is fixed
40+
_RankGE: TypeAlias = _CanBroadcast[_LowerT, Any, _RankT] | _LowerT | Rank
41+
3642
HasRankLE = TypeAliasType(
3743
"HasRankLE",
38-
_HasShape[Shape0 | _HasOwnShape[_UpperT] | _CanBroadcast[Any, _UpperT, _RankT]],
44+
_HasInnerShape[_RankLE[_UpperT, _RankT]],
3945
type_params=(_UpperT, _RankT),
4046
)
4147
HasRankGE = TypeAliasType(
4248
"HasRankGE",
43-
_HasShape[_LowerT | _CanBroadcast[_LowerT, Any, _RankT]],
49+
_HasInnerShape[_RankGE[_LowerT, _RankT]],
4450
type_params=(_LowerT, _RankT),
4551
)
4652

47-
###
53+
_ShapeT = TypeVar("_ShapeT", bound=Shape)
4854

49-
_ShapeT_co = TypeVar("_ShapeT_co", bound=Shape | _HasOwnShape | _CanBroadcast, covariant=True)
55+
# for unwrapping potential rank types as shape tuples
56+
HasInnerShape = TypeAliasType(
57+
"HasInnerShape",
58+
_HasInnerShape[_HasOwnShape[Any, _ShapeT]],
59+
type_params=(_ShapeT,),
60+
)
5061

51-
@type_check_only
52-
class _HasShape(Protocol[_ShapeT_co]):
53-
@property
54-
def shape(self, /) -> _ShapeT_co: ...
62+
###
63+
64+
_ShapeLikeT_co = TypeVar("_ShapeLikeT_co", bound=Shape | _HasOwnShape | _CanBroadcast[Any, Any], covariant=True)
5565

56-
_FromT_contra = TypeVar("_FromT_contra", default=Any, contravariant=True)
57-
_ToT_contra = TypeVar("_ToT_contra", bound=Shape, default=Any, contravariant=True)
66+
_FromT_contra = TypeVar("_FromT_contra", contravariant=True)
67+
_ToT_contra = TypeVar("_ToT_contra", bound=tuple[Any, ...], contravariant=True)
5868
_EquivT_co = TypeVar("_EquivT_co", bound=Shape, default=Any, covariant=True)
5969

70+
# __broadcast__ is the type-check-only interface order of ranks
6071
@final
6172
@type_check_only
6273
class _CanBroadcast(Protocol[_FromT_contra, _ToT_contra, _EquivT_co]):
6374
def __broadcast__(self, from_: _FromT_contra, to: _ToT_contra, /) -> _EquivT_co: ...
6475

76+
# __inner_shape__ is similar to `shape`, but directly exposes the `Rank` type.
77+
@final
78+
@type_check_only
79+
class _HasInnerShape(Protocol[_ShapeLikeT_co]):
80+
@property
81+
def __inner_shape__(self, /) -> _ShapeLikeT_co: ...
82+
83+
_OwnShapeT_contra = TypeVar("_OwnShapeT_contra", bound=tuple[Any, ...], default=Any, contravariant=True)
84+
_OwnShapeT_co = TypeVar("_OwnShapeT_co", bound=Shape, default=_OwnShapeT_contra, covariant=True)
85+
6586
# This double shape-type parameter is a sneaky way to annotate a doubly-bound nominal type range,
6687
# e.g. `_HasOwnShape[Shape2N, Shape0N]` accepts `Shape2N`, `Shape1N`, and `Shape0N`, but
6788
# rejects `Shape3N` and `Shape1`. Besides brevity, it also works around several mypy bugs that
6889
# are related to "unions vs joins".
69-
70-
_OwnShapeT_contra = TypeVar("_OwnShapeT_contra", bound=Shape, default=Any, contravariant=True)
71-
_OwnShapeT_co = TypeVar("_OwnShapeT_co", bound=Shape, default=_OwnShapeT_contra, covariant=True)
72-
_OwnShapeT = TypeVar("_OwnShapeT", bound=tuple[Any, ...], default=Any)
73-
7490
@final
7591
@type_check_only
7692
class _HasOwnShape(Protocol[_OwnShapeT_contra, _OwnShapeT_co]):
@@ -79,59 +95,74 @@ class _HasOwnShape(Protocol[_OwnShapeT_contra, _OwnShapeT_co]):
7995
###
8096
# TODO(jorenham): embed the array-like types, e.g. `Sequence[Sequence[T]]`
8197

82-
@type_check_only
83-
class _BaseRank(Generic[_FromT_contra, _OwnShapeT, _ToT_contra]):
84-
def __broadcast__(self, from_: _FromT_contra, to: _ToT_contra, /) -> Self: ...
85-
def __own_shape__(self, shape: _OwnShapeT, /) -> _OwnShapeT: ...
98+
_Ts = TypeVarTuple("_Ts") # should only contain `int`s
8699

100+
# https://github.com/python/mypy/issues/19093
87101
@type_check_only
88-
class _BaseRankM(
89-
_BaseRank[_FromT_contra | _HasOwnShape[_ToT_contra, Shape], _OwnShapeT, _ToT_contra],
90-
Generic[_FromT_contra, _OwnShapeT, _ToT_contra],
91-
): ...
102+
class BaseRank(tuple[*_Ts], Generic[*_Ts]):
103+
def __broadcast__(self, from_: tuple[*_Ts], to: tuple[*_Ts], /) -> Self: ...
104+
def __own_shape__(self, shape: tuple[*_Ts], /) -> tuple[*_Ts]: ...
92105

93106
@final
94107
@type_check_only
95-
class Rank0(_BaseRankM[_Shape00, Shape0, Shape0N], tuple[()]): ...
108+
class Rank0(BaseRank[()]):
109+
@override
110+
def __broadcast__(self, from_: Shape0 | _HasOwnShape[Shape, Any], to: Shape, /) -> Self: ...
96111

97112
@final
98113
@type_check_only
99-
class Rank1(_BaseRankM[_Shape01, Shape1, Shape1N], tuple[int]): ...
114+
class Rank1(BaseRank[int]):
115+
@override
116+
def __broadcast__(self, from_: _Shape01 | _HasOwnShape[Shape1N, Any], to: Shape1N, /) -> Self: ...
100117

101118
@final
102119
@type_check_only
103-
class Rank2(_BaseRankM[_Shape02, Shape2, Shape2N], tuple[int, int]): ...
120+
class Rank2(BaseRank[int, int]):
121+
@override
122+
def __broadcast__(self, from_: _Shape02 | _HasOwnShape[Shape2N, Any], to: Shape2N, /) -> Self: ...
104123

105124
@final
106125
@type_check_only
107-
class Rank3(_BaseRankM[_Shape03, Shape3, Shape3N], tuple[int, int, int]): ...
126+
class Rank3(BaseRank[int, int, int]):
127+
@override
128+
def __broadcast__(self, from_: _Shape03 | _HasOwnShape[Shape3N, Any], to: Shape3N, /) -> Self: ...
108129

109130
@final
110131
@type_check_only
111-
class Rank4(_BaseRankM[_Shape04, Shape4, Shape4N], tuple[int, int, int, int]): ...
132+
class Rank4(BaseRank[int, int, int, int]):
133+
@override
134+
def __broadcast__(self, from_: _Shape04 | _HasOwnShape[Shape4N, Any], to: Shape4N, /) -> Self: ...
112135

113-
# this emulates `AnyOf`, rather than a `Union`.
114-
@type_check_only
115-
class _BaseRankMToN(_BaseRank[Shape0N, _OwnShapeT, _OwnShapeT], Generic[_OwnShapeT]): ...
136+
# these emulates `AnyOf` (gradual union), rather than a `Union`.
116137

117138
@final
118139
@type_check_only
119-
class Rank(_BaseRankMToN[Shape0N], tuple[int, ...]): ...
140+
class Rank(BaseRank[*tuple[int, ...]]):
141+
@override
142+
def __broadcast__(self, from_: AnyShape, to: tuple[*_Ts], /) -> Self: ...
120143

121144
@final
122145
@type_check_only
123-
class Rank1N(_BaseRankMToN[Shape1N], tuple[int, *tuple[int, ...]]): ...
146+
class Rank1N(BaseRank[int, *tuple[int, ...]]):
147+
@override
148+
def __broadcast__(self, from_: AnyShape, to: Shape1N, /) -> Self: ...
124149

125150
@final
126151
@type_check_only
127-
class Rank2N(_BaseRankMToN[Shape2N], tuple[int, int, *tuple[int, ...]]): ...
152+
class Rank2N(BaseRank[int, int, *tuple[int, ...]]):
153+
@override
154+
def __broadcast__(self, from_: AnyShape, to: Shape2N, /) -> Self: ...
128155

129156
@final
130157
@type_check_only
131-
class Rank3N(_BaseRankMToN[Shape3N], tuple[int, int, int, *tuple[int, ...]]): ...
158+
class Rank3N(BaseRank[int, int, int, *tuple[int, ...]]):
159+
@override
160+
def __broadcast__(self, from_: AnyShape, to: Shape3N, /) -> Self: ...
132161

133162
@final
134163
@type_check_only
135-
class Rank4N(_BaseRankMToN[Shape4N], tuple[int, int, int, int, *tuple[int, ...]]): ...
164+
class Rank4N(BaseRank[int, int, int, int, *tuple[int, ...]]):
165+
@override
166+
def __broadcast__(self, from_: AnyShape, to: Shape4N, /) -> Self: ...
136167

137168
Rank0N: TypeAlias = Rank

src/_numtype/_shape.pyi

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from typing import TypeAlias
1+
from typing import Any, TypeAlias
22
from typing_extensions import TypeAliasType
33

44
__all__ = [
5+
"AnyShape",
56
"Shape",
67
"Shape0",
78
"Shape0N",
@@ -16,8 +17,10 @@ __all__ = [
1617
"ShapeN",
1718
]
1819

20+
AnyShape = TypeAliasType("AnyShape", tuple[Any, ...])
1921
Shape = TypeAliasType("Shape", tuple[int, ...])
2022

23+
# TODO: remove `| Rank0` once python/mypy#19110 is fixed
2124
Shape0 = TypeAliasType("Shape0", tuple[()])
2225
Shape1 = TypeAliasType("Shape1", tuple[int])
2326
Shape2 = TypeAliasType("Shape2", tuple[int, int])

0 commit comments

Comments
 (0)
0