8000 Fix explicit type for partial (#17424) · python/mypy@1b116df · GitHub
[go: up one dir, main page]

Skip to content

Commit 1b116df

Browse files
authored
Fix explicit type for partial (#17424)
Fixes #17301
1 parent abdaf6a commit 1b116df

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

mypy/plugins/functools.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import mypy.checker
88
import mypy.plugin
9+
import mypy.semanal
910
from mypy.argmap import map_actuals_to_formals
1011
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, CallExpr, FuncItem, Var
1112
from mypy.plugins.common import add_method_to_class
@@ -24,6 +25,8 @@
2425

2526
_ORDERING_METHODS: Final = {"__lt__", "__le__", "__gt__", "__ge__"}
2627

28+
PARTIAL = "functools.partial"
29+
2730

2831
class _MethodInfo(NamedTuple):
2932
is_static: bool
@@ -142,7 +145,8 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
142145
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
143146
)
144147
for k in fn_type.arg_kinds
145-
]
148+
],
149+
ret_type=ctx.api.named_generic_type(PARTIAL, [fn_type.ret_type]),
146150
)
147151
if defaulted.line < 0:
148152
# Make up a line number if we don't have one
@@ -188,6 +192,13 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
188192
bound = get_proper_type(bound)
189193
if not isinstance(bound, CallableType):
190194
return ctx.default_return_type
195+
wrapped_ret_type = get_proper_type(bound.ret_type)
196+
if not isinstance(wrapped_ret_type, Instance) or wrapped_ret_type.type.fullname != PARTIAL:
197+
return ctx.default_return_type
198+
if not mypy.semanal.refers_to_fullname(ctx.args[0][0], PARTIAL):
199+
# If the first argument is partial, above call will trigger the plugin
200+
# again, in between the wrapping above an unwrapping here.
201+
bound = bound.copy_modified(ret_type=wrapped_ret_type.args[0])
191202

192203
formal_to_actual = map_actuals_to_formals(
193204
actual_kinds=actual_arg_kinds,
@@ -237,7 +248,7 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
237248
ret_type=ret_type,
238249
)
239250

240-
ret = ctx.api.named_generic_type("functools.partial", [ret_type])
251+
ret = ctx.api.named_generic_type(PARTIAL, [ret_type])
241252
ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied)
242253
return ret
243254

@@ -247,7 +258,7 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
247258
if (
248259
not isinstance(ctx.api, mypy.checker.TypeChecker) # use internals
249260
or not isinstance(ctx.type, Instance)
250-
or ctx.type.type.fullname != "functools.partial"
261+
or ctx.type.type.fullname != PARTIAL
251262
or not ctx.type.extra_attrs
252263
or "__mypy_partial" not in ctx.type.extra_attrs.attrs
253264
):

test-data/unit/check-functools.test

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,37 @@ reveal_type(functools.partial(fn3, 2)()) # E: "str" not callable \
347347
# E: Argument 1 to "partial" has incompatible type "Union[Callable[[int], int], str]"; expected "Callable[..., int]"
348348
[builtins fixtures/tuple.pyi]
349349

350+
[case testFunctoolsPartialExplicitType]
351+
from functools import partial
352+
from typing import Type, TypeVar, Callable
353+
354+
T = TypeVar("T")
355+
def generic(string: str, integer: int, resulting_type: Type[T]) -> T: ...
356+
357+
p: partial[str] = partial(generic, resulting_type=str)
358+
q: partial[bool] = partial(generic, resulting_type=str) # E: Argument "resulting_type" to "generic" has incompatible type "Type[str]"; expected "Type[bool]"
359+
360+
pc: Callable[..., str] = partial(generic, resulting_type=str)
361+
qc: Callable[..., bool] = partial(generic, resulting_type=str) # E: Incompatible types in assignment (expression has type "partial[str]", variable has type "Callable[..., bool]") \
362+
# N: "partial[str].__call__" has type "Callable[[VarArg(Any), KwArg(Any)], str]"
363+
[builtins fixtures/tuple.pyi]
364+
365+
[case testFunctoolsPartialNestedPartial]
366+
from functools import partial
367+
from typing import Any
368+
369+
def foo(x: int) -> int: ...
370+
p = partial(partial, foo)
371+
reveal_type(p()(1)) # N: Revealed type is "builtins.int"
372+
p()("no") # E: Argument 1 to "foo" has incompatible type "str"; expected "int"
373+
374+
q = partial(partial, partial, foo)
375+
q()()("no") # E: Argument 1 to "foo" has incompatible type "str"; expected "int"
376+
377+
r = partial(partial, foo, 1)
378+
reveal_type(r()()) # N: Revealed type is "builtins.int"
379+
[builtins fixtures/tuple.pyi]
380+
350381
[case testFunctoolsPartialTypeObject]
351382
import functools
352383
from typing import Type, Generic, TypeVar

0 commit comments

Comments
 (0)
0