10000 Use polymorphic inference in unification by ilevkivskyi · Pull Request #17348 · python/mypy · GitHub
[go: up one dir, main page]

Skip to content

Use polymorphic inference in unification #17348

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 10, 2024
Next Next commit
Use polymorphic inference in unification
  • Loading branch information
ilevkivskyi committed Jun 8, 2024
commit fad1094afa288d65aec7664c5adb733191f2b473
40 changes: 33 additions & 7 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,9 @@ def __init__(
# new uses of this, as this may cause leaking `UnboundType`s to type checking.
self.allow_unbound_tvars = False

# Used to pass information about current overload index to visit_func_def().
self.current_overload_item: int | None = None

# mypyc doesn't properly handle implementing an abstractproperty
# with a regular attribute so we make them properties
@property
Expand Down Expand Up @@ -869,6 +872,15 @@ def visit_func_def(self, defn: FuncDef) -> None:
with self.scope.function_scope(defn):
self.analyze_func_def(defn)

def function_fullname(self, fullname: str) -> str:
if self.current_overload_item is None:
return fullname
if self.current_overload_item < 0:
suffix = "impl"
else:
suffix = str(self.current_overload_item)
return f"{fullname}#{suffix}"

def analyze_func_def(self, defn: FuncDef) -> None:
if self.push_type_args(defn.type_args, defn) is None:
self.defer(defn)
Expand All @@ -895,7 +907,8 @@ def analyze_func_def(self, defn: FuncDef) -> None:
self.prepare_method_signature(defn, self.type, has_self_type)

# Analyze function signature
with self.tvar_scope_frame(self.tvar_scope.method_frame(defn.fullname)):
fullname =self.function_fullname(defn.fullname)
with self.tvar_scope_frame(self.tvar_scope.method_frame(fullname)):
if defn.type:
self.check_classvar_in_signature(defn.type)
assert isinstance(defn.type, CallableType)
Expand All @@ -904,7 +917,7 @@ def analyze_func_def(self, defn: FuncDef) -> None:
analyzer = self.type_analyzer()
tag = self.track_incomplete_refs()
result = analyzer.visit_callable_type(
defn.type, nested=False, namespace=defn.fullname
defn.type, nested=False, namespace=fullname
)
# Don't store not ready types (including placeholders).
if self.found_incomplete_ref(tag) or has_placeholder(result):
Expand Down Expand Up @@ -1117,7 +1130,8 @@ def update_function_type_variables(self, fun_type: CallableType, defn: FuncItem)
if defn is generic. Return True, if the signature contains typing.Self
type, or False otherwise.
"""
with self.tvar_scope_frame(self.tvar_scope.method_frame(defn.fullname)):
fullname = self.function_fullname(defn.fullname)
with self.tvar_scope_frame(self.tvar_scope.method_frame(fullname)):
a = self.type_analyzer()
fun_type.variables, has_self_type = a.bind_function_type_variables(fun_type, defn)
if has_self_type and self.type is not None:
Expand Down Expand Up @@ -1175,6 +1189,14 @@ def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
with self.scope.function_scope(defn):
self.analyze_overloaded_func_def(defn)

@contextmanager
def overload_item_set(self, item: int) -> Iterator[None]:
self.current_overload_item = item
try:
yield
finally:
self.current_overload_item = None

def analyze_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
# OverloadedFuncDef refers to any legitimate situation where you have
# more than one declaration for the same function in a row. This occurs
Expand All @@ -1187,7 +1209,8 @@ def analyze_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:

first_item = defn.items[0]
first_item.is_overload = True
first_item.accept(self)
with self.overload_item_set(0):
first_item.accept(self)

if isinstance(first_item, Decorator) and first_item.func.is_property:
# This is a property.
Expand Down Expand Up @@ -1272,7 +1295,8 @@ def analyze_overload_sigs_and_impl(
if i != 0:
# Assume that the first item was already visited
item.is_overload = True
item.accept(self)
with self.overload_item_set(i if i < len(defn.items) - 1 else -1):
item.accept(self)
# TODO: support decorated overloaded functions properly
if isinstance(item, Decorator):
callable = function_type(item.func, self.named_type("builtins.function"))
Expand Down Expand Up @@ -1444,15 +1468,17 @@ def add_function_to_symbol_table(self, func: FuncDef | OverloadedFuncDef) -> Non
self.add_symbol(func.name, func, func)

def analyze_arg_initializers(self, defn: FuncItem) -> None:
with self.tvar_scope_frame(self.tvar_scope.method_frame(defn.fullname)):
fullname = self.function_fullname(defn.fullname)
with self.tvar_scope_frame(self.tvar_scope.method_frame(fullname)):
# Analyze default arguments
for arg in defn.arguments:
if arg.initializer:
arg.initializer.accept(self)

def analyze_function_body(self, defn: FuncItem) -> None:
is_method = self.is_class_scope()
with self.tvar_scope_frame(self.tvar_scope.method_frame(defn.fullname)):
fullname = self.function_fullname(defn.fullname)
with self.tvar_scope_frame(self.tvar_scope.method_frame(fullname)):
# Bind the type variables again to visit the body.
if defn.type:
a = self.type_analyzer()
Expand Down
9 changes: 8 additions & 1 deletion mypy/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,8 @@ def skip_reverse_union_constraints(cs: list[Constraint]) -> list[Constraint]:
is a linear constraint. This is however not true in presence of union types, for example
T :> Union[S, int] vs S <: T. Trying to solve such constraints would be detected ambiguous
as (T, S) form a non-linear SCC. However, simply removing the linear part results in a valid
solution T = Union[S, int], S = <free>.
solution T = Union[S, int], S = <free>. A similar scenario is when we get T <: Union[T, int],
such constraints carry no information, and will equally confuse linearity check.

TODO: a cleaner solution may be to avoid inferring such constraints in first place, but
this would require passing around a flag through all infer_constraints() calls.
Expand All @@ -525,7 +526,13 @@ def skip_reverse_union_constraints(cs: list[Constraint]) -> list[Constraint]:
if isinstance(p_target, UnionType):
for item in p_target.items:
if isinstance(item, TypeVarType):
if item == c.origin_type_var and c.op == SUBTYPE_OF:
reverse_union_cs.add(c)
continue
# These two forms are semantically identical, but are different from
# the point of view of Constraint.__eq__().
reverse_union_cs.add(Constraint(item, neg_op(c.op), c.origin_type_var))
reverse_union_cs.add(Constraint(c.origin_type_var, c.op, item))
return [c for c in cs if c not in reverse_union_cs]


Expand Down
4 changes: 3 additions & 1 deletion mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1879,7 +1879,9 @@ def unify_generic_callable(
constraints = [
c for c in constraints if not isinstance(get_proper_type(c.target), NoneType)
]
inferred_vars, _ = mypy.solve.solve_constraints(type.variables, constraints)
inferred_vars, _ = mypy.solve.solve_constraints(
type.variables, constraints, allow_polymorphic=True
)
if None in inferred_vars:
return None
non_none_inferred_vars = cast(List[Type], inferred_vars)
Expand Down
18 changes: 18 additions & 0 deletions test-data/unit/check-generics.test
Original file line number Diff line number Diff line change
Expand Up @@ -3442,3 +3442,21 @@ reveal_type(dec(g)) # N: Revealed type is "def (builtins.int) -> __main__.Foo[b
h: Callable[[Unpack[Us]], Foo[int]]
reveal_type(dec(h)) # N: Revealed type is "def (builtins.int) -> __main__.Foo[builtins.int]"
[builtins fixtures/list.pyi]

[case testHigherOrderGenericPartial]
from typing import TypeVar, Callable

T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")
def apply(f: Callable[[T], S], x: T) -> S: ...
def id(x: U) -> U: ...

A1 = TypeVar("A1")
A2 = TypeVar("A2")
R = TypeVar("R")
def fake_partial(fun: Callable[[A1, A2], R], arg: A1) -> Callable[[A2], R]: ...

f_pid = fake_partial(apply, id)
reveal_type(f_pid) # N: Revealed type is "def [A2] (A2`2) -> A2`2"
reveal_type(f_pid(1)) # N: Revealed type is "builtins.int"
22 changes: 8 additions & 14 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -371,20 +371,18 @@ def foo(t: List[T], s: T) -> int: ... # E: Overloaded function signatures 1 and
def foo(t: T, s: T) -> str: ...
def foo(t, s): pass

# TODO: examples below are technically unsafe.
class Wrapper(Generic[T]):
@overload
def foo(self, t: List[T], s: T) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def foo(self, t: List[T], s: T) -> int: ...
@overload
def foo(self, t: T, s: T) -> str: ...
def foo(self, t, s): pass

class Dummy(Generic[T]): pass

# Same root issue: why does the additional constraint bound T <: T
# cause the constraint solver to not infer T = object like it did in the
# first example?
@overload
def bar(d: Dummy[T], t: List[T], s: T) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def bar(d: Dummy[T], t: List[T], s: T) -> int: ...
@overload
def bar(d: Dummy[T], t: T, s: T) -> str: ...
def bar(d: Dummy[T], t, s): pass
Expand Down Expand Up @@ -2865,11 +2863,8 @@ class Wrapper(Generic[T]):
def f(self, x: T) -> T: ...
def f(self, x): ...

# TODO: This shouldn't trigger an error message?
# Related to testTypeCheckOverloadImplementationTypeVarDifferingUsage2?
# See https://github.com/python/mypy/issues/5510
@overload
def g(self, x: int) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def g(self, x: int) -> int: ...
@overload
def g(self, x: T) -> T: ...
def g(self, x): ...
Expand All @@ -2892,16 +2887,15 @@ class Wrapper(Generic[T]):
def f2(self, x: List[T]) -> List[T]: ...
def f2(self, x): ...

# TODO: This shouldn't trigger an error message?
# See https://github.com/python/mypy/issues/5510
@overload
def g1(self, x: List[int]) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def g1(self, x: List[int]) -> int: ...
@overload
def g1(self, x: List[T]) -> T: ...
def g1(self, x): ...

# TODO: this is technically unsafe.
@overload
def g2(self, x: List[int]) -> List[int]: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def g2(self, x: List[int]) -> List[int]: ...
@overload
def g2(self, x: List[T]) -> List[T]: ...
def g2(self, x): ...
Expand Down Expand Up @@ -6483,7 +6477,7 @@ P = ParamSpec("P")
R = TypeVar("R")

@overload
def func(x: Callable[Concatenate[Any, P], R]) -> Callable[P, R]: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def func(x: Callable[Concatenate[Any, P], R]) -> Callable[P, R]: ...
@overload
def func(x: Callable[P, R]) -> Callable[Concatenate[str, P], R]: ...
def func(x: Callable[..., R]) -> Callable[..., R]: ...
Expand Down
0