8000 Allow using TypedDict for more precise typing of **kwds by ilevkivskyi · Pull Request #13471 · python/mypy · GitHub
[go: up one dir, main page]

Skip to content

Allow using TypedDict for more precise typing of **kwds #13471

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 22, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add main functionality
  • Loading branch information
ilevkivskyi committed Aug 21, 2022
commit 8d107392cb1e66a94b9b2e8ef579c01a7ff94c4f
7 changes: 4 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,8 +730,9 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
# needs to be compatible in.
if impl_type.variables:
impl = unify_generic_callable(
impl_type,
sig1,
# Normalize both before unifying
impl_type.with_unpacked_kwargs(),
sig1.with_unpacked_kwargs(),
ignore_return=False,
return_constraint_direction=SUPERTYPE_OF,
)
Expand Down Expand Up @@ -1166,7 +1167,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: str | None) ->
# builtins.tuple[T] is typing.Tuple[T, ...]
arg_type = self.named_generic_type("builtins.tuple", [arg_type])
elif typ.arg_kinds[i] == nodes.ARG_STAR2:
if not isinstance(arg_type, ParamSpecType):
if not isinstance(arg_type, ParamSpecType) and not typ.unpack_kwargs:
arg_type = self.named_generic_type(
"builtins.dict", [self.str_type(), arg_type]
)
Expand Down
4 changes: 4 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,6 +1322,8 @@ def check_callable_call(

See the docstring of check_call for more information.
"""
# Always unpack **kwargs before checking a call.
callee = callee.with_unpacked_kwargs()
if callable_name is None and callee.name:
callable_name = callee.name
ret_type = get_proper_type(callee.ret_type)
Expand Down Expand Up @@ -2057,6 +2059,8 @@ def check_overload_call(
context: Context,
) -> tuple[Type, Type]:
"""Checks a call to an overloaded function."""
# Normalize unpacked kwargs before checking the call.
callee = callee.with_unpacked_kwargs()
arg_types = self.infer_arg_types_in_empty_context(args)
# Step 1: Filter call targets to remove ones where the argument counts don't match
plausible_targets = self.plausible_overload_call_targets(
Expand Down
6 changes: 5 additions & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,9 +735,13 @@ def infer_constraints_from_protocol_members(
return res

def visit_callable_type(self, template: CallableType) -> list[Constraint]:
# Normalize callables before matching against each other.
# Note that non-normalized callables can be created in annotations
# using e.g. callback protocols.
template = template.with_unpacked_kwargs()
if isinstance(self.actual, CallableType):
res: list[Constraint] = []
cactual = self.actual
cactual = self.actual.with_unpacked_kwargs()
param_spec = template.param_spec()
if param_spec is None:
# FIX verify argument counts
Expand Down
18 changes: 17 additions & 1 deletion mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

from typing import Tuple

import mypy.typeops
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT
Expand Down Expand Up @@ -141,7 +143,7 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType:

def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
"""Return a simple least upper bound given the declared type."""
# TODO: check infinite recursion for aliases here.
# TODO: check infinite recursion for aliases here?
declaration = get_proper_type(declaration)
s = get_proper_type(s)
t = get_proper_type(t)
Expand Down Expand Up @@ -172,6 +174,9 @@ def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
if isinstance(s, UninhabitedType) and not isinstance(t, UninhabitedType):
s, t = t, s

# Meets/joins require callable type normalization.
s, t = normalize_callables(s, t)

value = t.accept(TypeJoinVisitor(s))
if declaration is None or is_subtype(value, declaration):
return value
Expand Down 8000 Expand Up @@ -229,6 +234,9 @@ def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None)
elif isinstance(t, PlaceholderType):
return AnyType(TypeOfAny.from_error)

# Meets/joins require callable type normalization.
s, t = normalize_callables(s, t)

# Use a visitor to handle non-trivial cases.
return t.accept(TypeJoinVisitor(s, instance_joiner))

Expand Down Expand Up @@ -528,6 +536,14 @@ def is_better(t: Type, s: Type) -> bool:
return False


def normalize_callables(s: ProperType, t: ProperType) -> Tuple[ProperType, ProperType]:
if isinstance(s, (CallableType, Overloaded)):
s = s.with_unpacked_kwargs()
if isinstance(t, (CallableType, Overloaded)):
t = t.with_unpacked_kwargs()
return s, t


def is_similar_callables(t: CallableType, s: CallableType) -> bool:
"""Return True if t and s have identical numbers of
arguments, default arguments and varargs.
Expand Down
4 changes: 4 additions & 0 deletions mypy/meet.py
67E6
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def meet_types(s: Type, t: Type) -> ProperType:
return t
if isinstance(s, UnionType) and not isinstance(t, UnionType):
s, t = t, s

# Meets/joins require callable type normalization.
s, t = join.normalize_callables(s, t)

return t.accept(TypeMeetVisitor(s))


Expand Down
5 changes: 4 additions & 1 deletion mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2391,7 +2391,10 @@ def [T <: int] f(self, x: int, y: T) -> None
name = tp.arg_names[i]
if name:
s += name + ": "
s += format_type_bare(tp.arg_types[i])
type_str = format_type_bare(tp.arg_types[i])
if tp.arg_kinds[i] == ARG_STAR2 and tp.unpack_kwargs:
type_str = f"Unpack[{type_str}]"
s += type_str
if tp.arg_kinds[i].is_optional():
s += " = ..."

Expand Down
26 changes: 26 additions & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@
get_proper_types,
invalid_recursive_alias,
is_named_instance,
UnpackType,
)
from mypy.typevars import fill_typevars
from mypy.util import (
Expand Down Expand Up @@ -830,6 +831,8 @@ def analyze_func_def(self, defn: FuncDef) -> None:
self.defer(defn)
return
assert isinstance(result, ProperType)
if isinstance(result, CallableType):
result = self.remove_unpack_kwargs(defn, result)
defn.type = result
self.add_type_alias_deps(analyzer.aliases_used)
self.check_function_signature(defn)
Expand Down Expand Up @@ -872,6 +875,29 @@ def analyze_func_def(self, defn: FuncDef) -> None:
defn.type = defn.type.copy_modified(ret_type=ret_type)
self.wrapped_coro_return_types[defn] = defn.type

def remove_unpack_kwargs(self, defn: FuncDef, typ: CallableType) -> CallableType:
if not typ.arg_kinds or typ.arg_kinds[-1] is not ArgKind.ARG_STAR2:
return typ
last_type = get_proper_type(typ.arg_types[-1])
if not isinstance(last_type, UnpackType):
return typ
last_type = get_proper_type(last_type.type)
if not isinstance(last_type, TypedDictType):
self.fail("Unpack item in ** argument must be a TypedDict", defn)
new_arg_types = typ.arg_types[:-1] + [AnyType(TypeOfAny.from_error)]
return typ.copy_modified(arg_types=new_arg_types)
overlap = set(typ.arg_names) & set(last_type.items)
# It is OK for TypedDict to have a key named 'kwargs'.
overlap.discard(typ.arg_names[-1])
if overlap:
overlapped = ", ".join([f'"{name}"' for name in overlap])
self.fail(f"Overlap between argument names and ** TypedDict items: {overlapped}", defn)
new_arg_types = typ.arg_types[:-1] + [AnyType(TypeOfAny.from_error)]
return typ.copy_modified(arg_types=new_arg_types)
# OK, everything looks right now, mark the callable type as using unpack.
new_arg_types = typ.arg_types[:-1] + [last_type]
return typ.copy_modified(arg_types=new_arg_types, unpack_kwargs=True)

def prepare_method_signature(self, func: FuncDef, info: TypeInfo) -> None:
"""Check basic signature validity and tweak annotation of self/cls argument."""
# Only non-static methods are special.
Expand Down
4 changes: 4 additions & 0 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,10 @@ def g(x: int) -> int: ...
If the 'some_check' function is also symmetric, the two calls would be equivalent
whether or not we check the args covariantly.
"""
# Normalize both types before comparing them.
left = left.with_unpacked_kwargs()
right = right.with_unpacked_kwargs()

if is_compat_return is None:
is_compat_return = is_compat

Expand Down
34 changes: 33 additions & 1 deletion mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1590,6 +1590,7 @@ class CallableType(FunctionLike):
"type_guard", # T, if -> TypeGuard[T] (ret_type is bool in this case).
"from_concatenate", # whether this callable is from a concatenate object
# (this is used for error messages)
"unpack_kwargs", # Was an Unpack[...] with **kwargs used to define this callable?
)

def __init__(
Expand All @@ -1613,6 +1614,7 @@ def __ 10000 init__(
def_extras: dict[str, Any] | None = None,
type_guard: Type | None = None,
from_concatenate: bool = False,
unpack_kwargs: bool = False,
) -> None:
super().__init__(line, column)
assert len(arg_types) == len(arg_kinds) == len(arg_names)
Expand Down Expand Up @@ -1653,6 +1655,7 @@ def __init__(
else:
self.def_extras = {}
self.type_guard = type_guard
self.unpack_kwargs = unpack_kwargs

def copy_modified(
self,
Expand All @@ -1674,6 +1677,7 @@ def copy_modified(
def_extras: Bogus[dict[str, Any]] = _dummy,
type_guard: Bogus[Type | None] = _dummy,
from_concatenate: Bogus[bool] = _dummy,
unpack_kwargs: Bogus[bool] = _dummy,
) -> CallableType:
return CallableType(
arg_types=arg_types if arg_types is not _dummy else self.arg_types,
Expand All @@ -1698,6 +1702,7 @@ def copy_modified(
from_concatenate=(
from_concatenate if from_concatenate is not _dummy else self.from_concatenate
),
unpack_kwargs=unpack_kwargs if unpack_kwargs is not _dummy else self.unpack_kwargs,
)

def var_arg(self) -> FormalArgument | None:
Expand Down Expand Up @@ -1889,6 +1894,25 @@ def expand_param_spec(
variables=[*variables, *self.variables],
)

def with_unpacked_kwargs(self) -> CallableType:
if not self.unpack_kwargs:
return self.copy_modified()
last_type = get_proper_type(self.arg_types[-1])
assert isinstance(last_type, ProperType) and isinstance(last_type, TypedDictType)
extra_kinds = [
ArgKind.ARG_NAMED if name in last_type.required_keys else ArgKind.ARG_NAMED_OPT
for name in last_type.items
]
new_arg_kinds = self.arg_kinds[:-1] + extra_kinds
new_arg_names = self.arg_names[:-1] + list(last_type.items)
new_arg_types = self.arg_types[:-1] + list(last_type.items.values())
return self.copy_modified(
arg_kinds=new_arg_kinds,
arg_names=new_arg_names,
arg_types=new_arg_types,
unpack_kwargs=False,
)

def __hash__(self) -> int:
# self.is_type_obj() will fail if self.fallback.type is a FakeInfo
if isinstance(self.fallback.type, FakeInfo):
Expand Down Expand Up @@ -1940,6 +1964,7 @@ def serialize(self) -> JsonDict:
"def_extras": dict(self.def_extras),
"type_guard": self.type_guard.serialize() if self.type_guard is not None else None,
"from_concatenate": self.from_concatenate,
"unpack_kwargs": self.unpack_kwargs,
}

@classmethod
Expand All @@ -1962,6 +1987,7 @@ def deserialize(cls, data: JsonDict) -> CallableType:
deserialize_type(data["type_guard"]) if data["type_guard"] is not None else None
),
from_concatenate=data["from_concatenate"],
unpack_kwargs=data["unpack_kwargs"],
)


Expand Down Expand Up @@ -2009,6 +2035,9 @@ def with_name(self, name: str) -> Overloaded:
def get_name(self) -> str | None:
return self._items[0].name

def with_unpacked_kwargs(self) -> Overloaded:
return Overloaded([i.with_unpacked_kwargs() for i in self.items])

def accept(self, visitor: TypeVisitor[T]) -> T:
return visitor.visit_overloaded(self)

Expand Down Expand Up @@ -2917,7 +2946,10 @@ def visit_callable_type(self, t: CallableType) -> str:
name = t.arg_names[i]
if name:
s += name + ": "
s += t.arg_types[i].accept(self)
type_str = t.arg_types[i].accept(self)
if t.arg_kinds[i] == ARG_STAR2 and t.unpack_kwargs:
type_str = f"Unpack[{type_str}]"
s += type_str
if t.arg_kinds[i].is_optional():
s += " ="

Expand Down
Loading
0