-
-
Notifications
You must be signed in to change notification settings - Fork 3k
Foundations for non-linear solver and polymorphic application #15287
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
08a8815
eb3a1e1
52209c9
a64183d
9191761
ba0f252
8eb04a9
ec8b695
d5eb5fa
4c41c67
163720c
0bc41b0
fefe27e
4aca3ba
42cc4cf
d0a3d0d
47db859
96d0f39
66b4567
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,8 @@ | |
|
||
from __future__ import annotations | ||
|
||
from typing import Iterable | ||
|
||
from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, neg_op | ||
from mypy.expandtype import expand_type | ||
from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort | ||
|
@@ -55,7 +57,13 @@ def solve_constraints( | |
if allow_polymorphic: | ||
solutions = solve_non_linear(vars, constraints, cmap) | ||
else: | ||
solutions = solve_iteratively([vars], cmap, vars) | ||
solutions = {} | ||
for tv, cs in cmap.items(): | ||
if not cs: | ||
continue | ||
lowers = [c.target for c in cs if c.op == SUPERTYPE_OF] | ||
uppers = [c.target for c in cs if c.op == SUBTYPE_OF] | ||
solutions[tv] = solve_one(lowers, uppers, []) | ||
|
||
res: list[Type | None] = [] | ||
for v in vars: | ||
|
@@ -81,14 +89,12 @@ def solve_non_linear( | |
The whole algorithm consists of five steps: | ||
* Propagate via linear constraints to get all possible constraints for each variable | ||
* Find dependencies between type variables, group them in SCCs, and sort topologically | ||
* Check all SCC are intrinsically linear, it is impossible to solve T <: List[T] | ||
* Check all SCC are intrinsically linear, we can't solve (express) T <: List[T] | ||
* Variables in leaf SCCs that don't have constant bounds are free (choose one per SCC) | ||
* Solve constraints iteratively starting from leafs, updating targets after each step. | ||
""" | ||
extra_constraints = [] | ||
for tvar in vars: | ||
# TODO: support iteratively inferring secondary constraints like | ||
# Sequence[T] <: S <: Sequence[U] => T <: U | ||
extra_constraints.extend(propagate_constraints_for(tvar, SUBTYPE_OF, cmap)) | ||
extra_constraints.extend(propagate_constraints_for(tvar, SUPERTYPE_OF, cmap)) | ||
constraints += remove_dups(extra_constraints) | ||
|
@@ -111,6 +117,7 @@ def solve_non_linear( | |
for tv in scc | ||
for c in cmap[tv] | ||
): | ||
# TODO: be careful about upper bounds (or values) when introducing free vars. | ||
free_vars.append(next(tv for tv in scc)) | ||
|
||
# Flatten the SCCs that are independent, we can solve them together, | ||
|
@@ -122,9 +129,13 @@ def solve_non_linear( | |
next_bc.extend(list(scc)) | ||
batches.append(next_bc) | ||
|
||
solutions = solve_iteratively(batches, cmap, free_vars) | ||
solutions: dict[TypeVarId, Type | None] = {} | ||
for flat_batch in batches: | ||
solutions.update(solve_iteratively(flat_batch, cmap, free_vars)) | ||
# We remove the solutions like T = T for free variables. This will indicate | ||
# to the apply function, that they should not be touched. | ||
# TODO: return list of free type variables explicitly, this logic is fragile | ||
# (but if we do, we need to be careful everything works in incremental modes). | ||
for tv in free_vars: | ||
if tv in solutions: | ||
del solutions[tv] | ||
|
@@ -133,83 +144,97 @@ def solve_non_linear( | |
|
||
|
||
def solve_iteratively( | ||
batches: list[list[TypeVarId]], | ||
cmap: dict[TypeVarId, list[Constraint]], | ||
free_vars: list[TypeVarId], | ||
batch: list[TypeVarId], cmap: dict[TypeVarId, list[Constraint]], free_vars: list[TypeVarId] | ||
) -> dict[TypeVarId, Type | None]: | ||
"""Solve constraints for type variables sequentially, updating targets after each step.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It wasn't clear to me what 'targets' means here without studying the implementation. Clarify? Document that we are solving for batch (not free_vars), and document free_vars. It could be useful to eventually have some unit tests for this (no need to do it in this PR, but before the new type inference logic is enabled by default). |
||
solutions: dict[TypeVarId, Type | None] = {} | ||
for batch in batches: | ||
tmap = solve_once(batch, cmap, free_vars) | ||
if not tmap: | ||
solutions = {} | ||
relevant_constraints = [] | ||
for tv in batch: | ||
relevant_constraints.extend(cmap.get(tv, [])) | ||
lowers, uppers = transitive_closure(batch, relevant_constraints) | ||
s_batch = set(batch) | ||
not_allowed_vars = [v for v in batch if v not in free_vars] | ||
while s_batch: | ||
for tv in s_batch: | ||
if any(not get_vars(l, not_allowed_vars) for l in lowers[tv]) or any( | ||
not get_vars(u, not_allowed_vars) for u in uppers[tv] | ||
): | ||
solvable_tv = tv | ||
break | ||
else: | ||
break | ||
# Solve each solvable type variable separately. | ||
s_batch.remove(solvable_tv) | ||
result = solve_one(lowers[solvable_tv], uppers[solvable_tv], not_allowed_vars) | ||
solutions[solvable_tv] = result | ||
if result is None: | ||
# TODO: support backtracking lower/upper bound choices | ||
# (will require switching this function from iterative to recursive). | ||
continue | ||
# Update the (transitive) constraints if there is a solution. | ||
subs = {solvable_tv: result} | ||
lowers = {tv: {expand_type(l, subs) for l in lowers[tv]} for tv in lowers} | ||
uppers = {tv: {expand_type(u, subs) for u in uppers[tv]} for tv in uppers} | ||
for v in cmap: | ||
for c in cmap[v]: | ||
c.target = expand_type( | ||
c.target, {k: v for (k, v) in tmap.items() if v is not None} | ||
) | ||
# TODO: support backtracking lower/upper bound choices | ||
# (will require switching this function from iterative to recursive). | ||
solutions.update(tmap) | ||
c.target = expand_type(c.target, subs) | ||
return solutions | ||
|
||
|
||
def solve_once( | ||
vars: list[TypeVarId], cmap: dict[TypeVarId, list[Constraint]], free_vars: list[TypeVarId] | ||
) -> dict[TypeVarId, Type | None]: | ||
def solve_one( | ||
lowers: Iterable[Type], uppers: Iterable[Type], not_allowed_vars: list[TypeVarId] | ||
) -> Type | None: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similarly some unit tests would be nice at some point. |
||
"""Solve constraints by using meets of upper bounds, and joins of lower bounds.""" | ||
res: dict[TypeVarId, Type | None] = {} | ||
# Solve each type variable separately. | ||
for tvar in vars: | ||
bottom: Type | None = None | ||
top: Type | None = None | ||
candidate: Type | None = None | ||
|
||
# Process each constraint separately, and calculate the lower and upper | ||
# bounds based on constraints. Note that we assume that the constraint | ||
# targets do not have constraint references. | ||
for c in cmap.get(tvar, []): | ||
# There may be multiple steps needed to solve all vars within a | ||
# (linear) SCC. We ignore targets pointing to not yet solved vars. | ||
if get_vars(c.target, [v for v in vars if v not in free_vars]): | ||
continue | ||
if c.op == SUPERTYPE_OF: | ||
if bottom is None: | ||
bottom = c.target | ||
else: | ||
if type_state.infer_unions: | ||
# This deviates from the general mypy semantics because | ||
# recursive types are union-heavy in 95% of cases. | ||
bottom = UnionType.make_union([bottom, c.target]) | ||
else: | ||
bottom = join_types(bottom, c.target) | ||
else: | ||
if top is None: | ||
top = c.target | ||
else: | ||
top = meet_types(top, c.target) | ||
|
||
p_top = get_proper_type(top) | ||
p_bottom = get_proper_type(bottom) | ||
if isinstance(p_top, AnyType) or isinstance(p_bottom, AnyType): | ||
source_any = top if isinstance(p_top, AnyType) else bottom | ||
assert isinstance(source_any, ProperType) and isinstance(source_any, AnyType) | ||
res[tvar] = AnyType(TypeOfAny.from_another_any, source_any=source_any) | ||
bottom: Type | None = None | ||
top: Type | None = None | ||
candidate: Type | None = None | ||
|
||
# Process each bound separately, and calculate the lower and upper | ||
# bounds based on constraints. Note that we assume that the constraint | ||
# targets do not have constraint references. | ||
for target in lowers: | ||
# There may be multiple steps needed to solve all vars within a | ||
# (linear) SCC. We ignore targets pointing to not yet solved vars. | ||
if get_vars(target, not_allowed_vars): | ||
continue | ||
elif bottom is None: | ||
if top: | ||
candidate = top | ||
if bottom is None: | ||
bottom = target | ||
else: | ||
if type_state.infer_unions: | ||
# This deviates from the general mypy semantics because | ||
# recursive types are union-heavy in 95% of cases. | ||
bottom = UnionType.make_union([bottom, target]) | ||
else: | ||
# No constraints for type variable | ||
continue | ||
elif top is None: | ||
candidate = bottom | ||
elif is_subtype(bottom, top): | ||
candidate = bottom | ||
bottom = join_types(bottom, target) | ||
|
||
for target in uppers: | ||
# Same as above. | ||
if get_vars(target, not_allowed_vars): | ||
continue | ||
if top is None: | ||
top = target | ||
else: | ||
candidate = None | ||
res[tvar] = candidate | ||
return res | ||
top = meet_types(top, target) | ||
|
||
p_top = get_proper_type(top) | ||
p_bottom = get_proper_type(bottom) | ||
if isinstance(p_top, AnyType) or isinstance(p_bottom, AnyType): | ||
source_any = top if isinstance(p_top, AnyType) else bottom | ||
assert isinstance(source_any, ProperType) and isinstance(source_any, AnyType) | ||
return AnyType(TypeOfAny.from_another_any, source_any=source_any) | ||
elif bottom is None: | ||
if top: | ||
candidate = top | ||
else: | ||
# No constraints for type variable | ||
return None | ||
elif top is None: | ||
candidate = bottom | ||
elif is_subtype(bottom, top): | ||
candidate = bottom | ||
else: | ||
candidate = None | ||
return candidate | ||
|
||
|
||
def normalize_constraints( | ||
|
@@ -263,6 +288,53 @@ def propagate_constraints_for( | |
return extra_constraints | ||
|
||
|
||
def transitive_closure( | ||
tvars: list[TypeVarId], constraints: list[Constraint] | ||
) -> tuple[dict[TypeVarId, set[Type]], dict[TypeVarId, set[Type]]]: | ||
"""Find transitive closure for given constraints on type variables. | ||
|
||
Transitive closure gives the maximal set of lower/upper bounds for each type variable, such | ||
that we cannot deduce any further bounds by chaining other existing bounds. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Give an example or two that illustrates what this means in the docstring? Again, it would be good to have some unit tests to validate the logic and avoid regressions in the future. |
||
""" | ||
# TODO: merge propagate_constraints_for() into this function. | ||
# TODO: add secondary constraints here to make the algorithm complete. | ||
uppers: dict[TypeVarId, set[Type]] = {tv: set() for tv in tvars} | ||
lowers: dict[TypeVarId, set[Type]] = {tv: set() for tv in tvars} | ||
graph: set[tuple[TypeVarId, TypeVarId]] = set() | ||
|
||
# Prime the closure with the initial trivial values. | ||
for c in constraints: | ||
if isinstance(c.target, TypeVarType) and c.target.id in tvars: | ||
if c.op == SUBTYPE_OF: | ||
graph.add((c.type_var, c.target.id)) | ||
else: | ||
graph.add((c.target.id, c.type_var)) | ||
if c.op == SUBTYPE_OF: | ||
uppers[c.type_var].add(c.target) | ||
else: | ||
lowers[c.type_var].add(c.target) | ||
|
||
# At this stage we know that constant bounds have been propagated already, so we | ||
# only need to propagate linear constraints. | ||
for c in constraints: | ||
if isinstance(c.target, TypeVarType) and c.target.id in tvars: | ||
if c.op == SUBTYPE_OF: | ||
lower, upper = c.type_var, c.target.id | ||
else: | ||
lower, upper = c.target.id, c.type_var | ||
extras = { | ||
(l, u) for l in tvars for u in tvars if (l, lower) in graph and (upper, u) in graph | ||
} | ||
graph |= extras | ||
for u in tvars: | ||
if (upper, u) in graph: | ||
lowers[u] |= lowers[lower] | ||
for l in tvars: | ||
if (l, lower) in graph: | ||
uppers[l] |= uppers[upper] | ||
return lowers, uppers | ||
|
||
|
||
def compute_dependencies( | ||
cmap: dict[TypeVarId, list[Constraint]] | ||
) -> dict[TypeVarId, list[TypeVarId]]: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm repeating myself, but again unit tests would be nice :-) |
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo: 'sor'