8000 Foundations for non-linear solver and polymorphic application by ilevkivskyi · Pull Request #15287 · python/mypy · GitHub
[go: up one dir, main page]

Skip to content

Foundations for non-linear solver and polymorphic application #15287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jun 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Start working on generic stuff < 8000 /div>
  • Loading branch information
ilevkivskyi committed May 17, 2023
commit 08a88154567964c4079b2657e7e017a011dbbf4e
69 changes: 67 additions & 2 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
from mypy.state import state
from mypy.subtypes import is_equivalent, is_same_type, is_subtype, non_method_protocol_members
from mypy.traverser import has_await_expression
from mypy.type_visitor import TypeTranslator
from mypy.typeanal import (
check_for_explicit_any,
has_any_from_unimported_type,
Expand All @@ -120,7 +121,7 @@
true_only,
try_expanding_sum_type_to_union,
try_getting_str_literals,
tuple_fallback,
tuple_fallback, get_type_vars,
)
from mypy.types import (
LITERAL_TYPE_NAMES,
Expand Down Expand Up @@ -155,7 +156,7 @@
get_proper_type,
get_proper_types,
has_recursive_types,
is_named_instance,
is_named_instance, TypeVarLikeType,
)
from mypy.types_utils import is_generic_instance, is_optional, is_self_type_like, remove_optional
from mypy.typestate import type_state
Expand Down Expand Up @@ -1789,6 +1790,28 @@ def infer_function_type_arguments(
inferred_args[0] = self.named_type("builtins.str")
elif not first_arg or not is_subtype(self.named_type("builtins.str"), first_arg):
self.chk.fail(message_registry.KEYWORD_ARGUMENT_REQUIRES_STR_KEY_TYPE, context)

# TODO: Filter away ParamSpec
if any(a is None or isinstance(a, UninhabitedType) for a in inferred_args):
poly_inferred_args = infer_function_type_arguments(
callee_type,
arg_types,
arg_kinds,
formal_to_actual,
context=self.argument_infer_context(),
strict=self.chk.in_checked_function(),
allow_polymorphic=True,
)
for i, arg in enumerate(get_proper_types(poly_inferred_args)):
if isinstance(arg, (NoneType, UninhabitedType)) or has_erased_component(arg):
poly_inferred_args[i] = None
poly_callee_type = self.apply_generic_arguments(callee_type, poly_inferred_args, context)
yes_vars = poly_callee_type.variables
no_vars = {v for v in callee_type.variables if v not in poly_callee_type.variables}
if not set(get_type_vars(poly_callee_type)) & no_vars:
applied = apply_poly(poly_callee_type, yes_vars)
if applied is not None:
return applied
else:
# In dynamically typed functions use implicit 'Any' types for
# type variables.
Expand Down Expand Up @@ -5290,6 +5313,48 @@ def replace_callable_return_type(c: CallableType, new_ret_type: Type) -> Callabl
return c.copy_modified(ret_type=new_ret_typ 8000 e)


def apply_poly(tp: CallableType, poly_tvars: list[TypeVarLikeType]) -> Optional[CallableType]:
try:
return tp.copy_modified(
arg_types=[t.accept(PolyTranslator(poly_tvars)) for t in tp.arg_types],
ret_type=tp.ret_type.accept(PolyTranslator(poly_tvars)),
variables=[],
)
except PolyTranslationError:
return None


class PolyTranslationError(TypeError):
pass


class PolyTranslator(TypeTranslator):
def __init__(self, poly_tvars: list[TypeVarLikeType]) -> None:
self.poly_tvars = set(poly_tvars)
self.bound_tvars = set()

def visit_callable_type(self, t: CallableType) -> Type:
found_vars = set()
for arg in t.arg_types:
found_vars |= set(get_type_vars(arg))
found_vars &= self.poly_tvars
found_vars -= self.bound_tvars
self.bound_tvars |= found_vars
result = super().visit_callable_type(t)
self.bound_tvars -= found_vars
assert isinstance(result, CallableType)
result.variables += list(found_vars)
return result

def visit_type_var(self, t: TypeVarType) -> Type:
if t in self.poly_tvars and t not in self.bound_tvars:
raise PolyTranslationError()
return super().visit_type_var(t)

def visit_type_alias_type(self, t: TypeAliasType) -> Type:
return t.copy_modified(args=[a.accept(self) for a in t.args])


class ArgInferSecondPassQuery(types.BoolTypeQuery):
"""Query whether an argument type should be inferred in the second pass.

Expand Down
53 changes: 53 additions & 0 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mypy.erasetype import erase_typevars
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import ARG_OPT, ARG_POS, CONTRAVARIANT, COVARIANT, ArgKind
from mypy.type_visitor import BoolTypeQuery, ANY_STRATEGY
from mypy.types import (
TUPLE_LIKE_INSTANCE_NAMES,
AnyType,
Expand Down Expand Up @@ -63,6 +64,48 @@
SUPERTYPE_OF: Final = 1


def flatten_types(tls: list[list[Type]]) -> list[Type]:
res = []
for tl in tls:
res.extend(tl)
return res


class PolyExtractor(TypeQuery[list[TypeVarLikeType]]):
def __init__(self) -> None:
super().__init__(flatten_types)

def visit_callable_type(self, t: CallableType) -> list[TypeVarLikeType]:
return t.variables + super().visit_callable_type(t)


class PolyLeakDetector(BoolTypeQuery):
def __init__(self, found: set[TypeVarLikeType]) -> None:
super().__init__(ANY_STRATEGY)
self.bound = set()
self.found = found

def visit_callable_type(self, t: CallableType) -> bool:
self.bound |= set(t.variables)
result = super().visit_callable_type(t)
self.bound -= set(t.variables)
return result

def visit_type_var(self, t: TypeVarType) -> bool:
return t in self.found and t not in self.bound


def sanitize_constraints(constraints: list[Constraint], types: list[Type]) -> list[Constraint]:
res = []
found = set()
for tp in types:
found |= set(tp.accept(PolyExtractor()))
for c in constraints:
if not c.target.accept(PolyLeakDetector(found)):
res.append(c)
return res


class Constraint:
"""A representation of a type constraint.

Expand Down Expand Up @@ -168,6 +211,9 @@ def infer_constraints_for_callable(
actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i]
)
c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF)
p_arg = get_proper_type(callee.arg_types[i])
if not isinstance(p_arg, CallableType) or p_arg.param_spec() is None:
c = sanitize_constraints(c, [callee.arg_types[i], actual_type])
constraints.extend(c)

return constraints
Expand Down Expand Up @@ -887,6 +933,13 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
if isinstance(self.actual, CallableType):
res: list[Constraint] = []
cactual = self.actual.with_unpacked_kwargs()
if cactual.variables and self.direction == SUPERTYPE_OF and template.param_spec() is None:
from mypy.subtypes import unify_generic_callable

unified = unify_generic_callable(cactual, template, ignore_return=True)
if unified is not None:
cactual = unified
res.extend(infer_constraints(cactual, template, neg_op(self.direction)))
param_spec = template.param_spec()
if param_spec is None:
# FIX verify argument counts
Expand Down
12 changes: 9 additions & 3 deletions mypy/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
SUBTYPE_OF,
SUPERTYPE_OF,
infer_constraints,
infer_constraints_for_callable,
infer_constraints_for_callable, Constraint, sanitize_constraints,
)
from mypy.nodes import ArgKind
from mypy.solve import solve_constraints
from mypy.types import CallableType, Instance, Type, TypeVarId
from mypy.type_visitor import TypeQuery
from mypy.typeops import get_type_vars
from mypy.types import CallableType, Instance, Type, TypeVarId, TypeVarLikeType, ParamSpecType, get_proper_type


class ArgumentInferContext(NamedTuple):
Expand All @@ -36,6 +38,7 @@ def infer_function_type_arguments(
formal_to_actual: list[list[int]],
context: ArgumentInferContext,
strict: bool = True,
allow_polymorphic: bool = False,
) -> list[Type | None]:
"""Infer the type arguments of a generic function.

Expand All @@ -57,7 +60,7 @@ def infer_function_type_arguments(

# Solve constraints.
type_vars = callee_type.type_var_ids()
return solve_constraints(type_vars, constraints, strict)
return solve_constraints(type_vars, constraints, strict, allow_polymorphic)


def infer_type_arguments(
Expand All @@ -66,4 +69,7 @@ def infer_type_arguments(
# Like infer_function_type_arguments, but only match a single type
# against a generic type.
constraints = infer_constraints(template, actual, SUPERTYPE_OF if is_supertype else SUBTYPE_OF)
p_template = get_proper_type(template)
if not isinstance(p_template, CallableType) or p_template.param_spec() is None:
constraints = sanitize_constraints(constraints, [template, actual])
return solve_constraints(type_var_ids, constraints)
32 changes: 30 additions & 2 deletions mypy/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

from collections import defaultdict

from mypy.constraints import SUPERTYPE_OF, Constraint
from mypy.constraints import SUPERTYPE_OF, Constraint, neg_op
from mypy.join import join_types
from mypy.meet import meet_types
from mypy.subtypes import is_subtype
from mypy.typeanal import remove_dups
from mypy.typeops import get_type_vars
from mypy.types import (
AnyType,
ProperType,
Expand All @@ -17,12 +19,27 @@
UninhabitedType,
UnionType,
get_proper_type,
ParamSpecType,
TypeVarType,
)
from mypy.typestate import type_state


def remove_mirror(constraints: list[Constraint]) -> list[Constraint]:
seen = set()
result = []
for c in constraints:
if isinstance(c.target, TypeVarType):
if (c.target.id, neg_op(c.op), c.type_var) in seen:
continue
seen.add((c.type_var, c.op, c.target.id))
result.append(c)
return result


def solve_constraints(
vars: list[TypeVarId], constraints: list[Constraint], strict: bool = True
vars: list[TypeVarId], constraints: list[Constraint], strict: bool = True,
allow_polymorphic: bool = False,
) -> list[Type | None]:
"""Solve type constraints.

Expand All @@ -33,12 +50,19 @@ def solve_constraints(
pick NoneType as the value of the type variable. If strict=False,
pick AnyType.
"""
constraints = remove_dups(constraints)
constraints = remove_mirror(constraints)

# Collect a list of constraints for each type variable.
cmap: dict[TypeVarId, list[Constraint]] = defaultdict(list)
for con in constraints:
cmap[con.type_var].append(con)

res: list[Type | None] = []
if allow_polymorphic:
extra: set[TypeVarId] = set()
else:
extra = set(vars)

# Solve each type variable separately.
for tvar in vars:
Expand All @@ -50,6 +74,10 @@ def solve_constraints(
# bounds based on constraints. Note that we assume that the constraint
# targets do not have constraint references.
for c in cmap.get(tvar, []):
if set(t.id for t in get_type_vars(c.target)) & ({tvar} | extra):
if not isinstance(c.origin_type_var, ParamSpecType):
# TODO: figure out def [U] (U) -> U vs itself
continue
if c.op == SUPERTYPE_OF:
if bottom is None:
bottom = c.target
Expand Down
2 changes: 1 addition & 1 deletion mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import mypy.constraints
import mypy.typeops< 10000 /span>
from mypy.erasetype import erase_type
from mypy.expandtype import expand_self_type, expand_type_by_instance
from mypy.expandtype import expand_self_type, expand_type_by_instance, freshen_function_type_vars
from mypy.maptype import map_instance_to_supertype

# Circular import; done in the function instead.
Expand Down
42 changes: 42 additions & 0 deletions test-data/unit/check-generics.test
Original file line number Diff line number Diff line change
Expand Up @@ -2733,3 +2733,45 @@ dict1: Any
dict2 = {"a": C1(), **{x: C2() for x in dict1}}
reveal_type(dict2) # N: Revealed type is "builtins.dict[Any, __main__.B]"
[builtins fixtures/dict.pyi]

[case testGenericStuff]
from typing import TypeVar, Callable, List

X = TypeVar('X')
T = TypeVar('T')

def foo(x: Callable[[int], X]) -> List[X]:
...
def id(x: T) -> T:
...
y = foo(id)
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int]"
[builtins fixtures/list.pyi]

[case testHardGenericStuff]
from typing import TypeVar, Callable, List, Sequence

S = TypeVar('S')
T = TypeVar('T')
U = TypeVar('U')

def dec(f: Callable[[S], T]) -> Callable[[S], List[T]]:
...
def id(x: U) -> U:
...
g = dec(id)
reveal_type(g) # N:
reveal_type(g(42))

def comb(f: Callable[[T], S], g: Callable[[S], U]) -> Callable[[T], U]: ...
reveal_type(comb(id, id))

def mix(fs: List[Callable[[S], T]]) -> Callable[[S], List[T]]:
def inner(x: S) -> List[T]:
return [f(x) for f in fs]
return inner

fs = [id, id, id]
reveal_type(mix(fs))
reveal_type(mix([id, id, id]))
[builtins fixtures/list.pyi]
14 changes: 14 additions & 0 deletions test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -1520,3 +1520,17 @@ def identity(func: Callable[P, None]) -> Callable[P, None]: ...
@identity
def f(f: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: ...
[builtins fixtures/paramspec.pyi]

[case testParamSpecFoo]
from typing import Callable, List, TypeVar
4FB0 from typing_extensions import ParamSpec

P = ParamSpec("P")
T = TypeVar("T")
U = TypeVar("U")

def dec(f: Callable[P, T]) -> Callable[P, List[T]]: ...
def test(x: U) -> U: ...
reveal_type(dec)
reveal_type(dec(test))
[builtins fixtures/paramspec.pyi]
0