@@ -1784,6 +1784,30 @@ def Elem(*args):
1784
1784
1785
1785
Union [Elem , str ] # Nor should this
1786
1786
1787
+ def test_union_of_literals (self ):
1788
+ self .assertEqual (Union [Literal [1 ], Literal [2 ]].__args__ ,
1789
+ (Literal [1 ], Literal [2 ]))
1790
+ self .assertEqual (Union [Literal [1 ], Literal [1 ]],
1791
+ Literal [1 ])
1792
+
1793
+ self .assertEqual (Union [Literal [False ], Literal [0 ]].__args__ ,
1794
+ (Literal [False ], Literal [0 ]))
1795
+ self .assertEqual (Union [Literal [True ], Literal [1 ]].__args__ ,
1796
+ (Literal [True ], Literal [1 ]))
1797
+
1798
+ import enum
1799
+ class Ints (enum .IntEnum ):
1800
+ A = 0
1801
+ B = 1
1802
+
1803
+ self .assertEqual (Union [Literal [Ints .A ], Literal [Ints .B ]].__args__ ,
1804
+ (Literal [Ints .A ], Literal [Ints .B ]))
1805
+
1806
+ self .assertEqual (Union [Literal [0 ], Literal [Ints .A ], Literal [False ]].__args__ ,
1807
+ (Literal [0 ], Literal [Ints .A ], Literal [False ]))
1808
+ self .assertEqual (Union [Literal [1 ], Literal [Ints .B ], Literal [True ]].__args__ ,
1809
+ (Literal [1 ], Literal [Ints .B ], Literal [True ]))
1810
+
1787
1811
1788
1812
class TupleTests (BaseTestCase ):
1789
1813
@@ -2151,6 +2175,13 @@ def test_basics(self):
2151
2175
Literal [Literal [1 , 2 ], Literal [4 , 5 ]]
2152
2176
Literal [b"foo" , u"bar" ]
2153
2177
2178
+ def test_enum (self ):
2179
+ import enum
2180
+ class My (enum .Enum ):
2181
+ A = 'A'
2182
+
2183
+ self .assertEqual (Literal [My .A ].__args__ , (My .A ,))
2184
+
2154
2185
def test_illegal_parameters_do_not_raise_runtime_errors (self ):
2155
2186
# Type checkers should reject these types, but we do not
2156
2187
# raise errors at runtime to maintain maximum flexibility.
@@ -2240,6 +2271,20 @@ def test_flatten(self):
2240
2271
self .assertEqual (l , Literal [1 , 2 , 3 ])
2241
2272
self .assertEqual (l .__args__ , (1 , 2 , 3 ))
2242
2273
2274
+ def test_does_not_flatten_enum (self ):
2275
+ import enum
2276
+ class Ints (enum .IntEnum ):
2277
+ A = 1
2278
+ B = 2
2279
+
2280
+ l = Literal [
2281
+ Literal [Ints .A ],
2282
+ Literal [Ints .B ],
2283
+ Literal [1 ],
2284
+ Literal [2 ],
2285
+ ]
2286
+ self .assertEqual (l .__args__ , (Ints .A , Ints .B , 1 , 2 ))
2287
+
2243
2288
2244
2289
XK = TypeVar ('XK' , str , bytes )
2245
2290
XV = TypeVar ('XV' )
0 commit comments