8000 Foundations for non-linear solver and polymorphic application by ilevkivskyi · Pull Request #15287 · python/mypy · GitHub
[go: up one dir, main page]

Skip to content

Foundations for non-linear solver and polymorphic application #15287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jun 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Make some progress
  • Loading branch information
ilevkivskyi committed May 19, 2023
commit eb3a1e109732191c0750493f9e69db96dd0d1da9
106 changes: 2 additions & 104 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,12 @@
Callable,
ClassVar,
Dict,
Iterable,
Iterator,
Mapping,
NamedTuple,
NoReturn,
Sequence,
TextIO,
TypeVar,
)
from typing_extensions import Final, TypeAlias as _TypeAlias

Expand All @@ -47,6 +45,7 @@
import mypy.semanal_main
from mypy.checker import TypeChecker
from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error
from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
from mypy.indirection import TypeIndirectionVisitor
from mypy.messages import MessageBuilder
from mypy.nodes import Import, ImportAll, ImportBase, ImportFrom, MypyFile, SymbolTable, TypeInfo
Expand Down Expand Up @@ -3465,15 +3464,8 @@ def sorted_components(
edges = {id: deps_filtered(graph, vertices, id, pri_max) for id in vertices}
sccs = list(strongly_connected_components(vertices, edges))
# Topsort.
sccsmap = {id: frozenset(scc) for scc in sccs for id in scc}
data: dict[AbstractSet[str], set[AbstractSet[str]]] = {}
for scc in sccs:
deps: set[AbstractSet[str]] = set()
for id in scc:
deps.update(sccsmap[x] for x in deps_filtered(graph, vertices, id, pri_max))
data[frozenset(scc)] = deps
res = []
for ready in topsort(data):
for ready in topsort(prepare_sccs(sccs, edges)):
# Sort the sets in ready by reversed smallest State.order. Examples:
#
# - If ready is [{x}, {y}], x.order == 1, y.order == 2, we get
Expand All @@ -3498,100 +3490,6 @@ def deps_filtered(graph: Graph, vertices: AbstractSet[str], id: str, pri_max: in
]


def strongly_connected_components(
vertices: AbstractSet[str], edges: dict[str, list[str]]
) -> Iterator[set[str]]:
"""Compute Strongly Connected Components of a directed graph.

Args:
vertices: the labels for the vertices
edges: for each vertex, gives the target vertices of its outgoing edges

Returns:
An iterator yielding strongly connected components, each
represented as a set of vertices. Each input vertex will occur
exactly once; vertices not part of a SCC are returned as
singleton sets.

From https://code.activestate.com/recipes/578507/.
"""
identified: set[str] = set()
stack: list[str] = []
index: dict[str, int] = {}
boundaries: list[int] = []

def dfs(v: str) -> Iterator[set[str]]:
index[v] = len(stack)
stack.append(v)
boundaries.append(index[v])

for w in edges[v]:
if w not in index:
yield from dfs(w)
elif w not in identified:
while index[w] < boundaries[-1]:
boundaries.pop()

if boundaries[-1] == index[v]:
boundaries.pop()
scc = set(stack[index[v] :])
del stack[index[v] :]
identified.update(scc)
yield scc

for v in vertices:
if v not in index:
yield from dfs(v)


T = TypeVar("T")


def topsort(data: dict[T, set[T]]) -> Iterable[set[T]]:
"""Topological sort.

Args:
data: A map from vertices to all vertices that it has an edge
connecting it to. NOTE: This data structure
is modified in place -- for normalization purposes,
self-dependencies are removed and entries representing
orphans are added.

Returns:
An iterator yielding sets of vertices that have an equivalent
ordering.

Example:
Suppose the input has the following structure:

{A: {B, C}, B: {D}, C: {D}}

This is normalized to:

{A: {B, C}, B: {D}, C: {D}, D: {}}

The algorithm will yield the following values:

{D}
{B, C}
{A}

From https://code.activestate.com/recipes/577413/.
"""
# TODO: Use a faster algorithm?
for k, v in data.items():
v.discard(k) # Ignore self dependencies.
for item in set.union(*data.values()) - set(data.keys()):
data[item] = set()
while True:
ready = {item for item, dep in data.items() if not dep}
if not ready:
break
yield ready
data = {item: (dep - ready) for item, dep in data.items() if item not in ready}
assert not data, f"A cyclic dependency exists amongst {data!r}"


def missing_stubs_file(cache_dir: str) -> str:
return os.path.join(cache_dir, "missing_stubs")

Expand Down
41 changes: 28 additions & 13 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import mypy.errorcodes as codes
from mypy import applytype, erasetype, join, message_registry, nodes, operators, types
from mypy.argmap import ArgTypeExpander, map_actuals_to_formals, map_formals_to_actuals
from mypy.checkmember import analyze_member_access, type_object_type
from mypy.checkmember import analyze_member_access, freeze_all_type_vars, type_object_type
from mypy.checkstrformat import StringFormatterChecker
from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars
from mypy.errors import ErrorWatcher, report_internal_error
Expand Down Expand Up @@ -115,13 +115,14 @@
false_only,
fixup_partial_type,
function_type,
get_type_vars,
is_literal_type_like,
make_simplified_union,
simple_literal_type,
true_only,
try_expanding_sum_type_to_union,
try_getting_str_literals,
tuple_fallback, get_type_vars,
tuple_fallback,
)
from mypy.types import (
LITERAL_TYPE_NAMES,
Expand All @@ -147,6 +148,7 @@
TypedDictType,
TypeOfAny,
TypeType,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
UninhabitedType,
Expand All @@ -156,7 +158,7 @@
get_proper_type,
get_proper_types,
has_recursive_types,
is_named_instance, TypeVarLikeType,
is_named_instance,
)
from mypy.types_utils import is_generic_instance, is_optional, is_self_type_like, remove_optional
from mypy.typestate import type_state
Expand Down Expand Up @@ -1791,8 +1793,10 @@ def infer_function_type_arguments(
elif not first_arg or not is_subtype(self.named_type("builtins.str"), first_arg):
self.chk.fail(message_registry.KEYWORD_ARGUMENT_REQUIRES_STR_KEY_TYPE, context)

# TODO: Filter away ParamSpec
if any(a is None or isinstance(a, UninhabitedType) for a in inferred_args):
# TODO: Filter away (or handle) ParamSpec?
if any(
a is None or isinstance(get_proper_type(a), UninhabitedType) for a in inferred_args
):
poly_inferred_args = infer_function_type_arguments(
callee_type,
arg_types,
Expand All @@ -1802,15 +1806,21 @@ def infer_function_type_arguments(
strict=self.chk.in_checked_function(),
allow_polymorphic=True,
)
for i, arg in enumerate(get_proper_types(poly_inferred_args)):
if isinstance(arg, (NoneType, UninhabitedType)) or has_erased_component(arg):
for i, pa in enumerate(get_proper_types(poly_inferred_args)):
# TODO: can we be more principled here?
if isinstance(pa, (NoneType, UninhabitedType)) or has_erased_component(pa):
poly_inferred_args[i] = None
poly_callee_type = self.apply_generic_arguments(callee_type, poly_inferred_args, context)
poly_callee_type = self.apply_generic_arguments(
callee_type, poly_inferred_args, context
)
yes_vars = poly_callee_type.variables
no_vars = {v for v in callee_type.variables if v not in poly_callee_type.variables}
if not set(get_type_vars(poly_callee_type)) & no_vars:
applied = apply_poly(poly_callee_type, yes_vars)
if applied is not None:
if applied is not None and poly_inferred_args != [None] * len(
poly_inferred_args
):
freeze_all_type_vars(applied)
return applied
else:
# In dynamically typed functions use implicit 'Any' types for
Expand Down Expand Up @@ -5313,7 +5323,7 @@ def replace_callable_return_type(c: CallableType, new_ret_type: Type) -> Callabl
return c.copy_modified(ret_type=new_ret_type)


def apply_poly(tp: CallableType, poly_tvars: list[TypeVarLikeType]) -> Optional[CallableType]:
def apply_poly(tp: CallableType, poly_tvars: Sequence[TypeVarLikeType]) -> Optional[CallableType]:
try:
return tp.copy_modified(
arg_types=[t.accept(PolyTranslator(poly_tvars)) for t in tp.arg_types],
Expand All @@ -5329,9 +5339,9 @@ class PolyTranslationError(TypeError):


class PolyTranslator(TypeTranslator):
def __init__(self, poly_tvars: list[TypeVarLikeType]) -> None:
def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None:
self.poly_tvars = set(poly_tvars)
self.bound_tvars = set()
self.bound_tvars: set[TypeVarLikeType] = set()

def visit_callable_type(self, t: CallableType) -> Type:
found_vars = set()
Expand All @@ -5342,15 +5352,20 @@ def visit_callable_type(self, t: CallableType) -> Type:
self.bound_tvars |= found_vars
result = super().visit_callable_type(t)
self.bound_tvars -= found_vars
assert isinstance(result, ProperType)
assert isinstance(result, CallableType)
result.variables += list(found_vars)
result.variables = list(result.variables) + list(found_vars)
return result

def visit_type_var(self, t: TypeVarType) -> Type:
if t in self.poly_tvars and t not in self.bound_tvars:
raise PolyTranslationError()
return super().visit_type_var(t)

def visit_param_spec(self, t: ParamSpecType) -> Type:
# TODO: more careful here (also handle TypeVarTupleType)
raise PolyTranslationError()

def visit_type_alias_type(self, t: TypeAliasType) -> Type:
return t.copy_modified(args=[a.accept(self) for a in t.args])

Expand Down
66 changes: 12 additions & 54 deletions mypy/constraints.py
8E6B
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from mypy.erasetype import erase_typevars
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import ARG_OPT, ARG_POS, CONTRAVARIANT, COVARIANT, ArgKind
from mypy.type_visitor import BoolTypeQuery, ANY_STRATEGY
from mypy.types import (
TUPLE_LIKE_INSTANCE_NAMES,
AnyType,
Expand Down Expand Up @@ -64,48 +63,6 @@
SUPERTYPE_OF: Final = 1


def flatten_types(tls: list[list[Type]]) -> list[Type]:
res = []
for tl in tls:
res.extend(tl)
return res


class PolyExtractor(TypeQuery[list[TypeVarLikeType]]):
def __init__(self) -> None:
super().__init__(flatten_types)

def visit_callable_type(self, t: CallableType) -> list[TypeVarLikeType]:
return t.variables + super().visit_callable_type(t)


class PolyLeakDetector(BoolTypeQuery):
def __init__(self, found: set[TypeVarLikeType]) -> None:
super().__init__(ANY_STRATEGY)
self.bound = set()
self.found = found

def visit_callable_type(self, t: CallableType) -> bool:
self.bound |= set(t.variables)
result = super().visit_callable_type(t)
self.bound -= set(t.variables)
return result

def visit_type_var(self, t: TypeVarType) -> bool:
return t in self.found and t not in self.bound


def sanitize_constraints(constraints: list[Constraint], types: list[Type]) -> list[Constraint]:
res = []
found = set()
for tp in types:
found |= set(tp.accept(PolyExtractor()))
for c in constraints:
if not c.target.accept(PolyLeakDetector(found)):
res.append(c)
return res


class Constraint:
"""A representation of a type constraint.

Expand Down Expand Up @@ -211,9 +168,6 @@ def infer_constraints_for_callable(
actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i]
)
c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF)
p_arg = get_proper_type(callee.arg_types[i])
if not isinstance(p_arg, CallableType) or p_arg.param_spec() is None:
c = sanitize_constraints(c, [callee.arg_types[i], actual_type])
constraints.extend(c)

return constraints
Expand Down Expand Up @@ -933,17 +887,21 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
if isinstance(self.actual, CallableType):
res: list[Constraint] = []
cactual = self.actual.with_unpacked_kwargs()
if cactual.variables and self.direction == SUPERTYPE_OF and template.param_spec() is None:
from mypy.subtypes import unify_generic_callable

unified = unify_generic_callable(cactual, template, ignore_return=True)
if unified is not None:
cactual = unified
res.extend(infer_constraints(cactual, template, neg_op(self.direction)))
param_spec = template.param_spec()
if param_spec is None:
# FIX verify argument counts
# FIX what if one of the functions is generic
# TODO: Erase template vars if generic?
if (
cactual.variables
and self.direction == SUPERTYPE_OF
and cactual.param_spec() is None
):
from mypy.subtypes import unify_generic_callable

unified = unify_generic_callable(cactual, template, ignore_return=True)
if unified is not None:
cactual = unified
res.extend(infer_constraints(cactual, template, neg_op(self.direction)))

# We can't infer constraints from arguments if the template is Callable[..., T]
# (with literal '...').
Expand Down
Loading
0