1
1
from typing import Any , Generic , Protocol , Self , TypeAlias , final , type_check_only
2
- from typing_extensions import TypeAliasType , TypeVar
2
+ from typing_extensions import TypeAliasType , TypeVar , TypeVarTuple , override
3
3
4
- from ._shape import Shape , Shape0 , Shape0N , Shape1 , Shape1N , Shape2 , Shape2N , Shape3 , Shape3N , Shape4 , Shape4N
4
+ from ._shape import AnyShape , Shape , Shape0 , Shape1 , Shape1N , Shape2 , Shape2N , Shape3 , Shape3N , Shape4 , Shape4N
5
5
6
6
__all__ = [
7
+ "HasInnerShape" ,
7
8
"HasRankGE" ,
8
9
"HasRankLE" ,
9
10
"Rank" ,
@@ -21,56 +22,71 @@ __all__ = [
21
22
22
23
###
23
24
24
- _Shape00 : TypeAlias = Shape0
25
- _Shape01 : TypeAlias = _Shape00 | Shape1
25
+ _Shape01 : TypeAlias = Shape0 | Shape1
26
26
_Shape02 : TypeAlias = _Shape01 | Shape2
27
27
_Shape03 : TypeAlias = _Shape02 | Shape3
28
28
_Shape04 : TypeAlias = _Shape03 | Shape4
29
29
30
30
###
31
31
32
- _UpperT = TypeVar ("_UpperT" , bound = Shape )
33
- _LowerT = TypeVar ("_LowerT" , bound = Shape )
32
+ # TODO(jorenham): remove `| Rank0 | Rank` once python/mypy#19110 is fixed
33
+ _UpperT = TypeVar ("_UpperT" , bound = Shape | Rank0 | Rank )
34
+ _LowerT = TypeVar ("_LowerT" , bound = Shape | Rank0 | Rank )
34
35
_RankT = TypeVar ("_RankT" , bound = Shape , default = Any )
35
36
37
+ # TODO(jorenham): remove `| Rank0 | Rank` once python/mypy#19110 is fixed
38
+ _RankLE : TypeAlias = _CanBroadcast [Any , _UpperT , _RankT ] | Shape0 | Rank0 | Rank
39
+ # TODO(jorenham): remove `| Rank` once python/mypy#19110 is fixed
40
+ _RankGE : TypeAlias = _CanBroadcast [_LowerT , Any , _RankT ] | _LowerT | Rank
41
+
36
42
HasRankLE = TypeAliasType (
37
43
"HasRankLE" ,
38
- _HasShape [ Shape0 | _HasOwnShape [ _UpperT ] | _CanBroadcast [ Any , _UpperT , _RankT ]],
44
+ _HasInnerShape [ _RankLE [ _UpperT , _RankT ]],
39
45
type_params = (_UpperT , _RankT ),
40
46
)
41
47
HasRankGE = TypeAliasType (
42
48
"HasRankGE" ,
43
- _HasShape [ _LowerT | _CanBroadcast [_LowerT , Any , _RankT ]],
49
+ _HasInnerShape [ _RankGE [_LowerT , _RankT ]],
44
50
type_params = (_LowerT , _RankT ),
45
51
)
46
52
47
- ###
53
+ _ShapeT = TypeVar ( "_ShapeT" , bound = Shape )
48
54
49
- _ShapeT_co = TypeVar ("_ShapeT_co" , bound = Shape | _HasOwnShape | _CanBroadcast , covariant = True )
55
+ # for unwrapping potential rank types as shape tuples
56
+ HasInnerShape = TypeAliasType (
57
+ "HasInnerShape" ,
58
+ _HasInnerShape [_HasOwnShape [Any , _ShapeT ]],
59
+ type_params = (_ShapeT ,),
60
+ )
50
61
51
- @type_check_only
52
- class _HasShape (Protocol [_ShapeT_co ]):
53
- @property
54
- def shape (self , / ) -> _ShapeT_co : ...
62
+ ###
63
+
64
+ _ShapeLikeT_co = TypeVar ("_ShapeLikeT_co" , bound = Shape | _HasOwnShape | _CanBroadcast [Any , Any ], covariant = True )
55
65
56
- _FromT_contra = TypeVar ("_FromT_contra" , default = Any , contravariant = True )
57
- _ToT_contra = TypeVar ("_ToT_contra" , bound = Shape , default = Any , contravariant = True )
66
+ _FromT_contra = TypeVar ("_FromT_contra" , contravariant = True )
67
+ _ToT_contra = TypeVar ("_ToT_contra" , bound = tuple [ Any , ...] , contravariant = True )
58
68
_EquivT_co = TypeVar ("_EquivT_co" , bound = Shape , default = Any , covariant = True )
59
69
70
+ # __broadcast__ is the type-check-only interface order of ranks
60
71
@final
61
72
@type_check_only
62
73
class _CanBroadcast (Protocol [_FromT_contra , _ToT_contra , _EquivT_co ]):
63
74
def __broadcast__ (self , from_ : _FromT_contra , to : _ToT_contra , / ) -> _EquivT_co : ...
64
75
76
+ # __inner_shape__ is similar to `shape`, but directly exposes the `Rank` type.
77
+ @final
78
+ @type_check_only
79
+ class _HasInnerShape (Protocol [_ShapeLikeT_co ]):
80
+ @property
81
+ def __inner_shape__ (self , / ) -> _ShapeLikeT_co : ...
82
+
83
+ _OwnShapeT_contra = TypeVar ("_OwnShapeT_contra" , bound = tuple [Any , ...], default = Any , contravariant = True )
84
+ _OwnShapeT_co = TypeVar ("_OwnShapeT_co" , bound = Shape , default = _OwnShapeT_contra , covariant = True )
85
+
65
86
# This double shape-type parameter is a sneaky way to annotate a doubly-bound nominal type range,
66
87
# e.g. `_HasOwnShape[Shape2N, Shape0N]` accepts `Shape2N`, `Shape1N`, and `Shape0N`, but
67
88
# rejects `Shape3N` and `Shape1`. Besides brevity, it also works around several mypy bugs that
68
89
# are related to "unions vs joins".
69
-
70
- _OwnShapeT_contra = TypeVar ("_OwnShapeT_contra" , bound = Shape , default = Any , contravariant = True )
71
- _OwnShapeT_co = TypeVar ("_OwnShapeT_co" , bound = Shape , default = _OwnShapeT_contra , covariant = True )
72
- _OwnShapeT = TypeVar ("_OwnShapeT" , bound = tuple [Any , ...], default = Any )
73
-
74
90
@final
75
91
@type_check_only
76
92
class _HasOwnShape (Protocol [_OwnShapeT_contra , _OwnShapeT_co ]):
@@ -79,59 +95,74 @@ class _HasOwnShape(Protocol[_OwnShapeT_contra, _OwnShapeT_co]):
79
95
###
80
96
# TODO(jorenham): embed the array-like types, e.g. `Sequence[Sequence[T]]`
81
97
82
- @type_check_only
83
- class _BaseRank (Generic [_FromT_contra , _OwnShapeT , _ToT_contra ]):
84
- def __broadcast__ (self , from_ : _FromT_contra , to : _ToT_contra , / ) -> Self : ...
85
- def __own_shape__ (self , shape : _OwnShapeT , / ) -> _OwnShapeT : ...
98
+ _Ts = TypeVarTuple ("_Ts" ) # should only contain `int`s
86
99
100
+ # https://github.com/python/mypy/issues/19093
87
101
@type_check_only
88
- class _BaseRankM (
89
- _BaseRank [_FromT_contra | _HasOwnShape [_ToT_contra , Shape ], _OwnShapeT , _ToT_contra ],
90
- Generic [_FromT_contra , _OwnShapeT , _ToT_contra ],
91
- ): ...
102
+ class BaseRank (tuple [* _Ts ], Generic [* _Ts ]):
103
+ def __broadcast__ (self , from_ : tuple [* _Ts ], to : tuple [* _Ts ], / ) -> Self : ...
104
+ def __own_shape__ (self , shape : tuple [* _Ts ], / ) -> tuple [* _Ts ]: ...
92
105
93
106
@final
94
107
@type_check_only
95
- class Rank0 (_BaseRankM [_Shape00 , Shape0 , Shape0N ], tuple [()]): ...
108
+ class Rank0 (BaseRank [()]):
109
+ @override
110
+ def __broadcast__ (self , from_ : Shape0 | _HasOwnShape [Shape , Any ], to : Shape , / ) -> Self : ...
96
111
97
112
@final
98
113
@type_check_only
99
- class Rank1 (_BaseRankM [_Shape01 , Shape1 , Shape1N ], tuple [int ]): ...
114
+ class Rank1 (BaseRank [int ]):
115
+ @override
116
+ def __broadcast__ (self , from_ : _Shape01 | _HasOwnShape [Shape1N , Any ], to : Shape1N , / ) -> Self : ...
100
117
101
118
@final
102
119
@type_check_only
103
- class Rank2 (_BaseRankM [_Shape02 , Shape2 , Shape2N ], tuple [int , int ]): ...
120
+ class Rank2 (BaseRank [int , int ]):
121
+ @override
122
+ def __broadcast__ (self , from_ : _Shape02 | _HasOwnShape [Shape2N , Any ], to : Shape2N , / ) -> Self : ...
104
123
105
124
@final
106
125
@type_check_only
107
- class Rank3 (_BaseRankM [_Shape03 , Shape3 , Shape3N ], tuple [int , int , int ]): ...
126
+ class Rank3 (BaseRank [int , int , int ]):
127
+ @override
128
+ def __broadcast__ (self , from_ : _Shape03 | _HasOwnShape [Shape3N , Any ], to : Shape3N , / ) -> Self : ...
108
129
109
130
@final
110
131
@type_check_only
111
- class Rank4 (_BaseRankM [_Shape04 , Shape4 , Shape4N ], tuple [int , int , int , int ]): ...
132
+ class Rank4 (BaseRank [int , int , int , int ]):
133
+ @override
134
+ def __broadcast__ (self , from_ : _Shape04 | _HasOwnShape [Shape4N , Any ], to : Shape4N , / ) -> Self : ...
112
135
113
- # this emulates `AnyOf`, rather than a `Union`.
114
- @type_check_only
115
- class _BaseRankMToN (_BaseRank [Shape0N , _OwnShapeT , _OwnShapeT ], Generic [_OwnShapeT ]): ...
136
+ # these emulates `AnyOf` (gradual union), rather than a `Union`.
116
137
117
138
@final
118
139
@type_check_only
119
- class Rank (_BaseRankMToN [Shape0N ], tuple [int , ...]): ...
140
+ class Rank (BaseRank [* tuple [int , ...]]):
141
+ @override
142
+ def __broadcast__ (self , from_ : AnyShape , to : tuple [* _Ts ], / ) -> Self : ...
120
143
121
144
@final
122
145
@type_check_only
123
- class Rank1N (_BaseRankMToN [Shape1N ], tuple [int , * tuple [int , ...]]): ...
146
+ class Rank1N (BaseRank [int , * tuple [int , ...]]):
147
+ @override
148
+ def __broadcast__ (self , from_ : AnyShape , to : Shape1N , / ) -> Self : ...
124
149
125
150
@final
126
151
@type_check_only
127
- class Rank2N (_BaseRankMToN [Shape2N ], tuple [int , int , * tuple [int , ...]]): ...
152
+ class Rank2N (BaseRank [int , int , * tuple [int , ...]]):
153
+ @override
154
+ def __broadcast__ (self , from_ : AnyShape , to : Shape2N , / ) -> Self : ...
128
155
129
156
@final
130
157
@type_check_only
131
- class Rank3N (_BaseRankMToN [Shape3N ], tuple [int , int , int , * tuple [int , ...]]): ...
158
+ class Rank3N (BaseRank [int , int , int , * tuple [int , ...]]):
159
+ @override
160
+ def __broadcast__ (self , from_ : AnyShape , to : Shape3N , / ) -> Self : ...
132
161
133
162
@final
134
163
@type_check_only
135
- class Rank4N (_BaseRankMToN [Shape4N ], tuple [int , int , int , int , * tuple [int , ...]]): ...
164
+ class Rank4N (BaseRank [int , int , int , int , * tuple [int , ...]]):
165
+ @override
166
+ def __broadcast__ (self , from_ : AnyShape , to : Shape4N , / ) -> Self : ...
136
167
137
168
Rank0N : TypeAlias = Rank
0 commit comments