8000 Foundations for non-linear solver and polymorphic application by ilevkivskyi · Pull Request #15287 · python/mypy · GitHub
[go: up one dir, main page]

Skip to content

Foundations for non-linear solver and polymorphic application #15287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jun 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix bugs
  • Loading branch information
ilevkivskyi committed Jun 5, 2023
commit 0bc41b06ed8878e2dcfbed194c82f87b3551f098
212 changes: 142 additions & 70 deletions mypy/solve.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

from typing import Iterable

from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, neg_op
from mypy.expandtype import expand_type
from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
Expand Down Expand Up @@ -55,7 +57,13 @@ def solve_constraints(
if allow_polymorphic:
solutions = solve_non_linear(vars, constraints, cmap)
else:
solutions = solve_iteratively([vars], cmap, vars)
solutions = {}
for tv, cs in cmap.items():
if not cs:
continue
lowers = [c.target for c in cs if c.op == SUPERTYPE_OF]
uppers = [c.target for c in cs if c.op == SUBTYPE_OF]
solutions[tv] = solve_one(lowers, uppers, [])

res: list[Type | None] = []
for v in vars:
Expand All @@ -81,14 +89,12 @@ def solve_non_linear(
The whole algorithm consists of five steps:
* Propagate via linear constraints to get all possible constraints for each variable
* Find dependencies between type variables, group them in SCCs, and sor topologically
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: 'sor'

* Check all SCC are intrinsically linear, it is impossible to solve T <: List[T]
* Check all SCC are intrinsically linear, we can't solve (express) T <: List[T]
* Variables in leaf SCCs that don't have constant bounds are free (choose one per SCC)
* Solve constraints iteratively starting from leafs, updating targets after each step.
"""
extra_constraints = []
for tvar in vars:
# TODO: support iteratively inferring secondary constraints like
# Sequence[T] <: S <: Sequence[U] => T <: U
extra_constraints.extend(propagate_constraints_for(tvar, SUBTYPE_OF, cmap))
extra_constraints.extend(propagate_constraints_for(tvar, SUPERTYPE_OF, cmap))
constraints += remove_dups(extra_constraints)
Expand All @@ -111,6 +117,7 @@ def solve_non_linear(
for tv in scc
for c in cmap[tv]
):
# TODO: be careful about upper bounds (or values) when introducing free vars.
free_vars.append(next(tv for tv in scc))

# Flatten the SCCs that are independent, we can solve them together,
Expand All @@ -122,9 +129,13 @@ def solve_non_linear(
next_bc.extend(list(scc))
batches.append(next_bc)

solutions = solve_iteratively(batches, cmap, free_vars)
solutions: dict[TypeVarId, Type | None] = {}
for flat_batch in batches:
solutions.update(solve_iteratively(flat_batch, cmap, free_vars))
# We remove the solutions like T = T for free variables. This will indicate
# to the apply function, that they should not be touched.
# TODO: return list of free type variables explicitly, this logic is fragile
# (but if we do, we need to be careful everything works in incremental modes).
for tv in free_vars:
if tv in solutions:
del solutions[tv]
Expand All @@ -133,83 +144,97 @@ def solve_non_linear(


def solve_iteratively(
batches: list[list[TypeVarId]],
cmap: dict[TypeVarId, list[Constraint]],
free_vars: list[TypeVarId],
batch: list[TypeVarId], cmap: dict[TypeVarId, list[Constraint]], free_vars: list[TypeVarId]
) -> dict[TypeVarId, Type | None]:
"""Solve constraints for type variables sequentially, updating targets after each step."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It wasn't clear to me what 'targets' means here without studying the implementation. Clarify?

Document that we are solving for batch (not free_vars), and document free_vars.

It could be useful to eventually have some unit tests for this (no need to do it in this PR, but before the new type inference logic is enabled by default).

solutions: dict[TypeVarId, Type | None] = {}
for batch in batches:
tmap = solve_once(batch, cmap, free_vars)
if not tmap:
solutions = {}
relevant_constraints = []
for tv in batch:
relevant_constraints.extend(cmap.get(tv, []))
lowers, uppers = transitive_closure(batch, relevant_constraints)
s_batch = set(batch)
not_allowed_vars = [v for v in batch if v not in free_vars]
while s_batch:
for tv in s_batch:
if any(not get_vars(l, not_allowed_vars) for l in lowers[tv]) or any(
not get_vars(u, not_allowed_vars) for u in uppers[tv]
):
solvable_tv = tv
break
else:
break
# Solve each solvable type variable separately.
s_batch.remove(solvable_tv)
result = solve_one(lowers[solvable_tv], uppers[solvable_tv], not_allowed_vars)
solutions[solvable_tv] = result
if result is None:
# TODO: support backtracking lower/upper bound choices
# (will require switching this function from iterative to recursive).
continue
# Update the (transitive) constraints if there is a solution.
subs = {solvable_tv: result}
lowers = {tv: {expand_type(l, subs) for l in lowers[tv]} for tv in lowers}
uppers = {tv: {expand_type(u, subs) for u in uppers[tv]} for tv in uppers}
for v in cmap:
for c in cmap[v]:
c.target = expand_type(
c.target, {k: v for (k, v) in tmap.items() if v is not None}
)
# TODO: support backtracking lower/upper bound choices
# (will require switching this function from iterative to recursive).
solutions.update(tmap)
c.target = expand_type(c.target, subs)
return solutions


def solve_once(
vars: list[TypeVarId], cmap: dict[TypeVarId, list[Constraint]], free_vars: list[TypeVarId]
) -> dict[TypeVarId, Type | None]:
def solve_one(
lowers: Iterable[Type], uppers: Iterable[Type], not_allowed_vars: list[TypeVarId]
) -> Type | None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly some unit tests would be nice at some point.

"""Solve constraints by finding by using meets of upper bounds, and joins of lower bounds."""
res: dict[TypeVarId, Type | None] = {}
# Solve each type variable separately.
for tvar in vars:
bottom: Type | None = None
top: Type | None = None
candidate: Type | None = None

# Process each constraint separately, and calculate the lower and upper
# bounds based on constraints. Note that we assume that the constraint
# targets do not have constraint references.
for c in cmap.get(tvar, []):
# There may be multiple steps needed to solve all vars within a
# (linear) SCC. We ignore targets pointing to not yet solved vars.
if get_vars(c.target, [v for v in vars if v not in free_vars]):
continue
if c.op == SUPERTYPE_OF:
if bottom is None:
bottom = c.target
else:
if type_state.infer_unions:
# This deviates from the general mypy semantics because
# recursive types are union-heavy in 95% of cases.
bottom = UnionType.make_union([bottom, c.target])
else:
bottom = join_types(bottom, c.target)
else:
if top is None:
top = c.target
else:
top = meet_types(top, c.target)

p_top = get_proper_type(top)
p_bottom = get_proper_type(bottom)
if isinstance(p_top, AnyType) or isinstance(p_bottom, AnyType):
source_any = top if isinstance(p_top, AnyType) else bottom
assert isinstance(source_any, ProperType) and isinstance(source_any, AnyType)
res[tvar] = AnyType(TypeOfAny.from_another_any, source_any=source_any)
bottom: Type | None = None
top: Type | None = None
candidate: Type | None = None

# Process each bound separately, and calculate the lower and upper
# bounds based on constraints. Note that we assume that the constraint
# targets do not have constraint references.
for target in lowers:
# There may be multiple steps needed to solve all vars within a
# (linear) SCC. We ignore targets pointing to not yet solved vars.
if get_vars(target, not_allowed_vars):
continue
elif bottom is None:
if top:
candidate = top
if bottom is None:
bottom = target
else:
if type_state.infer_unions:
# This deviates from the general mypy semantics because
# recursive types are union-heavy in 95% of cases.
bottom = UnionType.make_union([bottom, target])
else:
# No constraints for type variable
continue
elif top is None:
candidate = bottom
elif is_subtype(bottom, top):
candidate = bottom
bottom = join_types(bottom, target)

for target in uppers:
# Same as above.
if get_vars(target, not_allowed_vars):
continue
if top is None:
top = target
else:
candidate = None
res[tvar] = candidate
return res
top = meet_types(top, target)

p_top = get_proper_type(top)
p_bottom = get_proper_type(bottom)
if isinstance(p_top, AnyType) or isinstance(p_bottom, AnyType):
source_any = top if isinstance(p_top, AnyType) else bottom
assert isinstance(source_any, ProperType) and isinstance(source_any, AnyType)
return AnyType(TypeOfAny.from_another_any, source_any=source_any)
elif bottom is None:
if top:
candidate = top
else:
# No constraints for type variable
return None
elif top is None:
candidate = bottom
elif is_subtype(bottom, top):
candidate = bottom
else:
candidate = None
return candidate


def normalize_constraints(
Expand Down Expand Up @@ -263,6 +288,53 @@ def propagate_constraints_for(
return extra_constraints


def transitive_closure(
tvars: list[TypeVarId], constraints: list[Constraint]
) -> tuple[dict[TypeVarId, set[Type]], dict[TypeVarId, set[Type]]]:
"""Find transitive closure for given constraints on type variables.

Transitive closure gives maximal set of lower/upper bounds for each type variable, such
we cannot deduce any further bounds by chaining other existing bounds.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Give an example or two that illustrates what this means in the docstring? Again, it would be good to have some unit tests to validate the logic and avoid regressions in the future.

"""
# TODO: merge propagate_constraints_for() into this function.
# TODO: add secondary constraints here to make the algorithm complete.
uppers: dict[TypeVarId, set[Type]] = {tv: set() for tv in tvars}
lowers: dict[TypeVarId, set[Type]] = {tv: set() for tv in tvars}
graph: set[tuple[TypeVarId, TypeVarId]] = set()

# Prime the closure with the initial trivial values.
for c in constraints:
if isinstance(c.target, TypeVarType) and c.target.id in tvars:
if c.op == SUBTYPE_OF:
graph.add((c.type_var, c.target.id))
else:
graph.add((c.target.id, c.type_var))
if c.op == SUBTYPE_OF:
uppers[c.type_var].add(c.target)
else:
lowers[c.type_var].add(c.target)

# At this stage we know that constant bounds have been propagated already, so we
# only need to propagate linear constraints.
for c in constraints:
if isinstance(c.target, TypeVarType) and c.target.id in tvars:
if c.op == SUBTYPE_OF:
lower, upper = c.type_var, c.target.id
else:
lower, upper = c.target.id, c.type_var
extras = {
(l, u) for l in tvars for u in tvars if (l, lower) in graph and (upper, u) in graph
}
graph |= extras
for u in tvars:
if (upper, u) in graph:
lowers[u] |= lowers[lower]
for l in tvars:
if (l, lower) in graph:
uppers[l] |= uppers[upper]
return lowers, uppers


def compute_dependencies(
cmap: dict[TypeVarId, list[Constraint]]
) -> dict[TypeVarId, list[TypeVarId F438 ]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm repeating myself, but again unit tests would be nice :-)

Expand Down
43 changes: 37 additions & 6 deletions test-data/unit/check-generics.test
Original file line number Diff line number Diff line change
Expand Up @@ -2893,14 +2893,45 @@ reveal_type(dec(id)) # N: Revealed type is "def [T] (T`2) -> builtins.list[T`2]
# flags: --strict-optional
from typing import TypeVar, Protocol, Generic, Optional

_T = TypeVar('_T')
T = TypeVar('T')

class _F(Protocol[_T]):
def __call__(self, __x: _T) -> _T: ...
class F(Protocol[T]):
def __call__(self, __x: T) -> T: ...

def lift(f: _F[_T]) -> _F[Optional[_T]]: ...
def g(x: _T) -> _T:
def lift(f: F[T]) -> F[Optional[T]]: ...
def g(x: T) -> T:
return x

reveal_type(lift(g)) # N: Revealed type is "def [_T] (Union[_T`1, None]) -> Union[_T`1, None]"
reveal_type(lift(g)) # N: Revealed type is "def [T] (Union[T`1, None]) -> Union[T`1, None]"
[builtins fixtures/list.pyi]

[case testInferenceAgainstGenericSplitOrder]
# flags: --strict-optional
from typing import TypeVar, Callable, List

S = TypeVar('S')
T = TypeVar('T')
U = TypeVar('U')

def dec(f: Callable[[T], S], g: Callable[[T], int]) -> Callable[[T], List[S]]: ...
def id(x: U) -> U:
...

reveal_type(dec(id, id)) # N: Revealed type is "def (builtins.int) -> builtins.list[builtins.int]"
[builtins fixtures/list.pyi]

[case testInferenceAgainstGenericSplitOrderGeneric]
# flags: --strict-optional
from typing import TypeVar, Callable, Tuple

S = TypeVar('S')
T = TypeVar('T')
U = TypeVar('U')
V = TypeVar('V')

def dec(f: Callable[[T], S], g: Callable[[T], U]) -> Callable[[T], Tuple[S, U]]: ...
def id(x: V) -> V:
...

reveal_type(dec(id, id)) # N: Revealed type is "def [S] (S`2) -> Tuple[S`2, S`2]"
[builtins fixtures/tuple.pyi]
0