8000 make_simplified_union: add caching and reduce allocations · python/mypy@a64490e · GitHub
[go: up one dir, main page]

Skip to content

Commit a64490e

Browse files
committed
make_simplified_union: add caching and reduce allocations
make_simplified_union is used in a lot of places and therefore accounts for a significant share to typechecking time. Based on sample metrics gathered from a large real-world codebase we can see that: 1. the majority of inputs are already as simple as they're going to get, which means we can avoid allocation extra lists and return the input unchanged 2. most of the cost of `make_simplified_union` comes from `is_proper_subtype` 3. `is_proper_subtype` has some caching going on under the hood but it only applies to `Instance`, and cache hit rate is low in this particular case because, as per 1) above, items are in fact rarely subtypes of each other To address 1, refactor `make_simplified_union` with an optimistic fast path that avoid unnecessary allocations. To address 2 & 3, introduce a cache to record the result of union simplification. These changes are observed to yield significant improvements in a real-world codebase: a roughly 10-20% overall speedup, with make_simplified_union/is_proper_subtype no longer showing up as hotspots in the py-spy profile. For #12526
1 parent d1c0616 commit a64490e

File tree

1 file changed

+87
-33
lines changed

1 file changed

+87
-33
lines changed

mypy/typeops.py

Lines changed: 87 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,41 @@ def is_simple_literal(t: ProperType) -> bool:
336336
return False
337337

338338

339+
def _get_flattened_proper_types(items: Sequence[Type]) -> Sequence[ProperType]:
340+
"""Similar to types.get_proper_types, with flattening of UnionType
341+
342+
Optimized to avoid allocating a new list whenever possible"""
343+
i: int = 0
344+
n: int = len(items)
345+
346+
# optimistic fast path
347+
while i < n:
348+
t = items[i]
349+
pt = get_proper_type(t)
350+
if id(t) != id(pt) or isinstance(pt, UnionType):
351+
# we need to allocate, switch to slow path
352+
break
353+
i += 1
354+
355+
# optimistic fast path reached end of input, no need to allocate
356+
if i == n:
357+
return cast(Sequence[ProperType], items)
358+
359+
all_items = list(cast(Sequence[ProperType], items[0:i]))
360+
361+
while i < n:
362+
pt = get_proper_type(items[i])
363+
if isinstance(pt, UnionType):
364+
all_items.extend(_get_flattened_proper_types(pt.items))
365+
else:
366+
all_items.append(pt)
367+
i += 1
368+
return all_items
369+
370+
371+
_simplified_union_cache: List[Dict[Tuple[ProperType, ...], ProperType]] = [{} for _ in range(8)]
372+
373+
339374
def make_simplified_union(items: Sequence[Type],
340375
line: int = -1, column: int = -1,
341376
*, keep_erased: bool = False,
@@ -362,32 +397,47 @@ def make_simplified_union(items: Sequence[Type],
362397
back into a sum type. Set it to False when called by try_expanding_sum_type_
363398
to_union().
364399
"""
365-
items = get_proper_types(items)
366-
367400
# Step 1: expand all nested unions
368-
while any(isinstance(typ, UnionType) for typ in items):
369-
all_items: List[ProperType] = []
370-
for typ in items:
371-
if isinstance(typ, UnionType):
372-
all_items.extend(get_proper_types(typ.items))
373-
else:
374-
all_items.append(typ)
375-
items = all_items
401+
items = _get_flattened_proper_types(items)
402+
403+
# NB: ideally we would use a frozenset, but that would require normalizing the
404+
# order of entries in the simplified union, or updating the test harness to
405+
# treat Unions as equivalent regardless of item ordering (which is particularly
406+
# tricky when it comes to all tests using string matching on reveal_type output)
407+
cache_key = tuple(items)
408+
# NB: we need to maintain separate caches depending on flags that might impact
409+
# the results of simplification
410+
cache = _simplified_union_cache[
411+
int(keep_erased)
412+
| int(contract_literals) << 1
413+
| int(state.strict_optional) << 2
414+
]
415+
ret = cache.get(cache_key, None)
416+
if ret is None:
417+
# Step 2: remove redundant unions
418+
simplified_set = _remove_redundant_union_items(items, keep_erased)
376419

377-
# Step 2: remove redundant unions
378-
simplified_set = _remove_redundant_union_items(items, keep_erased)
420+
# Step 3: If more than one literal exists in the union, try to simplify
421+
if contract_literals and sum(isinstance(item, LiteralType) for item in simplified_set) > 1:
422+
simplified_set = try_contracting_literals_in_union(simplified_set)
379423

380-
# Step 3: If more than one literal exists in the union, try to simplify
381-
if contract_literals and sum(isinstance(item, LiteralType) for item in simplified_set) > 1:
382-
simplified_set = try_contracting_literals_in_union(simplified_set)
424+
ret = UnionType.make_union(simplified_set, line, column)
383425

384-
return UnionType.make_union(simplified_set, line, column)
426+
# cache simplified value
427+
cache[cache_key] = ret
385428

429+
return ret
386430

387-
def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) -> List[ProperType]:
431+
432+
def _remove_redundant_union_items(items: Sequence[ProperType],
433+
keep_erased: bool) -> Sequence[ProperType]:
388434
from mypy.subtypes import is_proper_subtype
389435

436+
if len(items) <= 1:
437+
return items
438+
390439
removed: Set[int] = set()
440+
truthed: Set[int] = set()
391441
seen: Set[Tuple[str, ...]] = set()
392442

393443
# NB: having a separate fast path for Union of Literal and slow path for other things
@@ -397,6 +447,7 @@ def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) ->
397447
for i, item in enumerate(items):
398448
if i in removed:
399449
continue
450+
400451
# Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169)
401452
k = simple_literal_value_key(item)
402453
if k is not None:
@@ -434,20 +485,34 @@ def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) ->
434485
continue
435486
# actual redundancy checks
436487
if (
437-
is_redundant_literal_instance(item, tj) # XXX?
438-
and is_proper_subtype(tj, item, keep_erased_types=keep_erased)
488+
isinstance(tj, UninhabitedType)
489+
or (
490+
(
491+
not isinstance(item, Instance)
492+
or item.last_known_value is None
493+
or (
494+
isinstance(tj, Instance)
495+
and tj.last_known_value == item.last_known_value
496+
)
497+
)
498+
and is_proper_subtype(tj, item, keep_erased_types=keep_erased)
499+
)
439500
):
440501
# We found a redundant item in the union.
441502
removed.add(j)
442503
cbt = cbt or tj.can_be_true
443504
cbf = cbf or tj.can_be_false
505+
444506
# if deleted subtypes had more general truthiness, use that
445507
if not item.can_be_true and cbt:
446-
items[i] = true_or_false(item)
508+
truthed.add(i)
447509
elif not item.can_be_false and cbf:
448-
items[i] = true_or_false(item)
510+
truthed.add(i)
449511

450-
return [items[i] for i in range(len(items)) if i not in removed]
512+
if not removed and not truthed:
513+
return items
514+
return [true_or_false(items[i]) if i in truthed else items[i]
515+
for i in range(len(items)) if i not in removed]
451516

452517

453518
def _get_type_special_method_bool_ret_type(t: Type) -> Optional[Type]:
@@ -889,17 +954,6 @@ def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool
889954
return False
890955

891956

892-
def is_redundant_literal_instance(general: ProperType, specific: ProperType) -> bool:
893-
if not isinstance(general, Instance) or general.last_known_value is None:
894-
return True
895-
if isinstance(specific, Instance) and specific.last_known_value == general.last_known_value:
896-
return True
897-
if isinstance(specific, UninhabitedType):
898-
return True
899-
900-
return False
901-
902-
903957
def separate_union_literals(t: UnionType) -> Tuple[Sequence[LiteralType], Sequence[Type]]:
904958
"""Separate literals from other members in a union type."""
905959
literal_items = []

0 commit comments

Comments
 (0)
0