@@ -336,6 +336,41 @@ def is_simple_literal(t: ProperType) -> bool:
336
336
return False
337
337
338
338
339
+ def _get_flattened_proper_types (items : Sequence [Type ]) -> Sequence [ProperType ]:
340
+ """Similar to types.get_proper_types, with flattening of UnionType
341
+
342
+ Optimized to avoid allocating a new list whenever possible"""
343
+ i : int = 0
344
+ n : int = len (items )
345
+
346
+ # optimistic fast path
347
+ while i < n :
348
+ t = items [i ]
349
+ pt = get_proper_type (t )
350
+ if id (t ) != id (pt ) or isinstance (pt , UnionType ):
351
+ # we need to allocate, switch to slow path
352
+ break
353
+ i += 1
354
+
355
+ # optimistic fast path reached end of input, no need to allocate
356
+ if i == n :
357
+ return cast (Sequence [ProperType ], items )
358
+
359
+ all_items = list (cast (Sequence [ProperType ], items [0 :i ]))
360
+
361
+ while i < n :
362
+ pt = get_proper_type (items [i ])
363
+ if isinstance (pt , UnionType ):
364
+ all_items .extend (_get_flattened_proper_types (pt .items ))
365
+ else :
366
+ all_items .append (pt )
367
+ i += 1
368
+ return all_items
369
+
370
+
371
+ _simplified_union_cache : List [Dict [Tuple [ProperType , ...], ProperType ]] = [{} for _ in range (8 )]
372
+
373
+
339
374
def make_simplified_union (items : Sequence [Type ],
340
375
line : int = - 1 , column : int = - 1 ,
341
376
* , keep_erased : bool = False ,
@@ -362,32 +397,47 @@ def make_simplified_union(items: Sequence[Type],
362
397
back into a sum type. Set it to False when called by try_expanding_sum_type_
363
398
to_union().
364
399
"""
365
- items = get_proper_types (items )
366
-
367
400
# Step 1: expand all nested unions
368
- while any (isinstance (typ , UnionType ) for typ in items ):
369
- all_items : List [ProperType ] = []
370
- for typ in items :
371
- if isinstance (typ , UnionType ):
372
- all_items .extend (get_proper_types (typ .items ))
373
- else :
374
- all_items .append (typ )
375
- items = all_items
401
+ items = _get_flattened_proper_types (items )
402
+
403
+ # NB: ideally we would use a frozenset, but that would require normalizing the
404
+ # order of entries in the simplified union, or updating the test harness to
405
+ # treat Unions as equivalent regardless of item ordering (which is particularly
406
+ # tricky when it comes to all tests using string matching on reveal_type output)
407
+ cache_key = tuple (items )
408
+ # NB: we need to maintain separate caches depending on flags that might impact
409
+ # the results of simplification
410
+ cache = _simplified_union_cache [
411
+ int (keep_erased )
412
+ | int (contract_literals ) << 1
413
+ | int (state .strict_optional ) << 2
414
+ ]
415
+ ret = cache .get (cache_key , None )
416
+ if ret is None :
417
+ # Step 2: remove redundant unions
418
+ simplified_set = _remove_redundant_union_items (items , keep_erased )
376
419
377
- # Step 2: remove redundant unions
378
- simplified_set = _remove_redundant_union_items (items , keep_erased )
420
+ # Step 3: If more than one literal exists in the union, try to simplify
421
+ if contract_literals and sum (isinstance (item , LiteralType ) for item in simplified_set ) > 1 :
422
+ simplified_set = try_contracting_literals_in_union (simplified_set )
379
423
380
- # Step 3: If more than one literal exists in the union, try to simplify
381
- if contract_literals and sum (isinstance (item , LiteralType ) for item in simplified_set ) > 1 :
382
- simplified_set = try_contracting_literals_in_union (simplified_set )
424
+ ret = UnionType .make_union (simplified_set , line , column )
383
425
384
- return UnionType .make_union (simplified_set , line , column )
426
+ # cache simplified value
427
+ cache [cache_key ] = ret
385
428
429
+ return ret
386
430
387
- def _remove_redundant_union_items (items : List [ProperType ], keep_erased : bool ) -> List [ProperType ]:
431
+
432
+ def _remove_redundant_union_items (items : Sequence [ProperType ],
433
+ keep_erased : bool ) -> Sequence [ProperType ]:
388
434
from mypy .subtypes import is_proper_subtype
389
435
436
+ if len (items ) <= 1 :
437
+ return items
438
+
390
439
removed : Set [int ] = set ()
440
+ truthed : Set [int ] = set ()
391
441
seen : Set [Tuple [str , ...]] = set ()
392
442
393
443
# NB: having a separate fast path for Union of Literal and slow path for other things
@@ -397,6 +447,7 @@ def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) ->
397
447
for i , item in enumerate (items ):
398
448
if i in removed :
399
449
continue
450
+
400
451
# Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169)
401
452
k = simple_literal_value_key (item )
402
453
if k is not None :
@@ -434,20 +485,34 @@ def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) ->
434
485
continue
435
486
# actual redundancy checks
436
487
if (
437
- is_redundant_literal_instance (item , tj ) # XXX?
438
- and is_proper_subtype (tj , item , keep_erased_types = keep_erased )
488
+ isinstance (tj , UninhabitedType )
489
+ or (
490
+ (
491
+ not isinstance (item , Instance )
492
+ or item .last_known_value is None
493
+ or (
494
+ isinstance (tj , Instance )
495
+ and tj .last_known_value == item .last_known_value
496
+ )
497
+ )
498
+ and is_proper_subtype (tj , item , keep_erased_types = keep_erased )
499
+ )
439
500
):
440
501
# We found a redundant item in the union.
441
502
removed .add (j )
442
503
cbt = cbt or tj .can_be_true
443
504
cbf = cbf or tj .can_be_false
505
+
444
506
# if deleted subtypes had more general truthiness, use that
445
507
if not item .can_be_true and cbt :
446
- items [ i ] = true_or_false ( item )
508
+ truthed . add ( i )
447
509
elif not item .can_be_false and cbf :
448
- items [ i ] = true_or_false ( item )
510
+ truthed . add ( i )
449
511
450
- return [items [i ] for i in range (len (items )) if i not in removed ]
512
+ if not removed and not truthed :
513
+ return items
514
+ return [true_or_false (items [i ]) if i in truthed else items [i ]
515
+ for i in range (len (items )) if i not in removed ]
451
516
452
517
453
518
def _get_type_special_method_bool_ret_type (t : Type ) -> Optional [Type ]:
@@ -889,17 +954,6 @@ def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool
889
954
return False
890
955
891
956
892
- def is_redundant_literal_instance (general : ProperType , specific : ProperType ) -> bool :
893
- if not isinstance (general , Instance ) or general .last_known_value is None :
894
- return True
895
- if isinstance (specific , Instance ) and specific .last_known_value == general .last_known_value :
896
- return True
897
- if isinstance (specific , UninhabitedType ):
898
- return True
899
-
900
- return False
901
-
902
-
903
957
def separate_union_literals (t : UnionType ) -> Tuple [Sequence [LiteralType ], Sequence [Type ]]:
904
958
"""Separate literals from other members in a union type."""
905
959
literal_items = []
0 commit comments