8000 Foundations for non-linear solver and polymorphic application by ilevkivskyi · Pull Request #15287 · python/mypy · GitHub
[go: up one dir, main page]

Skip to content

Foundations for non-linear solver and polymorphic application #15287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jun 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Address CR
  • Loading branch information
ilevkivskyi committed Jun 18, 2023
commit 66b4567aa36ae9d4429003e4106db1284d1178e4
31 changes: 27 additions & 4 deletions mypy/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def solve_non_linear(

The whole algorithm consists of five steps:
* Propagate via linear constraints to get all possible constraints for each variable
* Find dependencies between type variables, group them in SCCs, and sor topologically
* Find dependencies between type variables, group them in SCCs, and sort topologically
* Check all SCC are intrinsically linear, we can't solve (express) T <: List[T]
* Variables in leaf SCCs that don't have constant bounds are free (choose one per SCC)
* Solve constraints iteratively starting from leafs, updating targets after each step.
Expand All @@ -112,11 +112,18 @@ def solve_non_linear(
leafs = raw_batches[0]
free_vars = []
for scc in leafs:
# If all constrain targets in this SCC are type variables within the
# same SCC then the only meaningful solution we can express, is that
# each variable is equal to a new free variable. For example if we
# have T <: S, S <: U, we deduce: T = S = U = <free>.
if all(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add comment here about the purpose of this if statement (I think I figured it out but it wasn't obvious).

isinstance(c.target, TypeVarType) and c.target.id in vars
for tv in scc
for c in cmap[tv]
):
# For convenience with current type application machinery, we randomly
# choose one of the existing type variables in SCC and designate it as free
# instead of defining a new type variable as a common solution.
# TODO: be careful about upper bounds (or values) when introducing free vars.
free_vars.append(sorted(scc, key=lambda x: x.raw_id)[0])

Expand Down Expand Up @@ -146,7 +153,17 @@ def solve_non_linear(
def solve_iteratively(
batch: list[TypeVarId], cmap: dict[TypeVarId, list[Constraint]], free_vars: list[TypeVarId]
) -> dict[TypeVarId, Type | None]:
"""Solve constraints for type variables sequentially, updating targets after each step."""
"""Solve constraints sequentially, updating constraint targets after each step.

We solve for type variables that appear in `batch`. If a constraint target is not constant
(i.e. constraint looks like T :> F[S, ...]), we substitute solutions found so far in
the target F[S, ...]. This way we can gradually solve for all variables in the batch taking
one solvable variable at a time (i.e. such a variable that has at least one constant bound).

Importantly, variables in free_vars are considered constants, so for example if we have just
one initial constraint T <: List[S], we will have two SCCs {T} and {S}, then we first
designate S as free, and therefore T = List[S] is a valid solution for T.
"""
solutions = {}
relevant_constraints = []
for tv in batch:
Expand Down Expand Up @@ -293,8 +310,14 @@ def transitive_closure(
) -> tuple[dict[TypeVarId, set[Type]], dict[TypeVarId, set[Type]]]:
"""Find transitive closure for given constraints on type variables.

Transitive closure gives maximal set of lower/upper bounds for each type variable, such
we cannot deduce any further bounds by chaining other existing bounds.
Transitive closure gives maximal set of lower/upper bounds for each type variable,
such that we cannot deduce any further bounds by chaining other existing bounds.

For example if we have initial constraints [T <: S, S <: U, U <: int], the transitive
closure is given by:
* {} <: T <: {S, U, int}
* {T} <: S <: {U, int}
* {T, S} <: U <: {int}
"""
# TODO: merge propagate_constraints_for() into this function.
# TODO: add secondary constraints here to make the algorithm complete.
Expand Down
21 changes: 21 additions & 0 deletions test-data/unit/check-generics.test
Original file line number Diff line number Diff line change
Expand Up @@ -2734,6 +2734,9 @@ dict2 = {"a": C1(), **{x: C2() for x in dict1}}
reveal_type(dict2) # N: Revealed type is "builtins.dict[Any, __main__.B]"
[builtins fixtures/dict.pyi]

-- Type inference for generic decorators applied to generic callables
-- ------------------------------------------------------------------

[case testInferenceAgainstGenericCallable]
# flags: --new-type-inference
from typing import TypeVar, Callable, List
Expand Down Expand Up @@ -2794,6 +2797,12 @@ def dec(f: Callable[[S], T]) -> Callable[[S], List[T]]:
def id(x: U) -> U:
...
reveal_type(dec(id)) # N: Revealed type is "def [S] (S`1) -> builtins.list[S`1]"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add an actual use of the decorator to at least some of the decorator-like use cases, for more end-to-end testing? E.g. something like this:

@dec 
def f(...) -> ...: ...  # Generic function
reveal_type(f(...))


@dec
def same(x: U) -> U:
...
reveal_type(same) # N: Revealed type is "def [S] (S`3) -> builtins.list[S`3]"
reveal_type(same(42)) # N: Revealed type is "builtins.list[builtins.int]"
[builtins fixtures/list.pyi]

[case testInferenceAgainstGenericCallableGenericReverse]
Expand All @@ -2809,6 +2818,12 @@ def dec(f: Callable[[S], List[T]]) -> Callable[[S], T]:
def id(x: U) -> U:
...
reveal_type(dec(id)) # N: Revealed type is "def [T] (builtins.list[T`2]) -> T`2"

@dec
def same(x: U) -> U:
...
reveal_type(same) # N: Revealed type is "def [T] (builtins.list[T`4]) -> T`4"
reveal_type(same([42])) # N: Revealed type is "builtins.int"
[builtins fixtures/list.pyi]

[case testInferenceAgainstGenericCallableGenericArg]
Expand All @@ -2824,6 +2839,12 @@ def dec(f: Callable[[S], T]) -> Callable[[S], T]:
def test(x: U) -> List[U]:
...
reveal_type(dec(test)) # N: Revealed type is "def [S] (S`1) -> builtins.list[S`1]"

@dec
def single(x: U) -> List[U]:
...
reveal_type(single) # N: Revealed type is "def [S] (S`3) -> builtins.list[S`3]"
reveal_type(single(42)) # N: Revealed type is "builtins.list[builtins.int]"
[builtins fixtures/list.pyi]

[case testInferenceAgainstGenericCallableGenericChain]
Expand Down
0