From b87af27846eccc7eaeefccd0d96316ef866f9cc2 Mon Sep 17 00:00:00 2001 From: tdschper Date: Thu, 8 Dec 2022 03:44:38 -0500 Subject: [PATCH] Add error for asdict, astuple, fields, and replace in dataclasses --- mypy/plugins/dataclasses.py | 35 ++++++++++++++++++++++++- mypy/plugins/default.py | 13 +++++++-- test-data/unit/check-dataclasses.test | 26 ++++++++++++++++++ test-data/unit/lib-stub/dataclasses.pyi | 21 ++++++++++++++- 4 files changed, 91 insertions(+), 4 deletions(-) diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 75496d5e56f9..ff89538f00f5 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -30,7 +30,7 @@ TypeVarExpr, Var, ) -from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface +from mypy.plugin import ClassDefContext, FunctionContext, SemanticAnalyzerPluginInterface from mypy.plugins.common import ( _get_decorator_bool_argument, add_attribute_to_class, @@ -631,6 +631,39 @@ def dataclass_class_maker_callback(ctx: ClassDefContext) -> bool: return transformer.transform() +def _dataclass_exclusive_function_callback(ctx: FunctionContext, func_name: str) -> Type: + """Called for functions that should only be called on dataclasses. + + Functions are (from dataclasses module): asdict, astuple, fields, replace + """ + # Each of asdict, astuple, fields, and replace require the first argument + # to be a dataclass + arg_type = get_proper_type(ctx.arg_types[0][0]) + if isinstance(arg_type, Instance) and "dataclass" not in arg_type.type.metadata: + ctx.api.msg.fail(f"{func_name}() should be called on dataclass instances", ctx.context) + return ctx.default_return_type + + +def asdict_callback(ctx: FunctionContext) -> Type: + """Called for dataclasses.asdict.""" + return _dataclass_exclusive_function_callback(ctx, "asdict") + + +def astuple_callback(ctx: FunctionContext) -> Type: + """Called for dataclasses.astuple.""" + return _dataclass_exclusive_function_callback(ctx, "astuple") + + +def fields_callback(ctx: FunctionContext) -> Type: + """Called for dataclasses.fields.""" + return _dataclass_exclusive_function_callback(ctx, "fields") + + +def replace_callback(ctx: FunctionContext) -> Type: + """Called for dataclasses.replace.""" + return _dataclass_exclusive_function_callback(ctx, "replace") + + def _collect_field_args( expr: Expression, ctx: ClassDefContext ) -> tuple[bool, dict[str, Expression]]: diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index 04971868e8f4..5d8052f4bccc 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -37,12 +37,21 @@ class DefaultPlugin(Plugin): """Type checker plugin that is enabled by default.""" def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None: - from mypy.plugins import ctypes, singledispatch + from mypy.plugins import ctypes, dataclasses, singledispatch if fullname == "ctypes.Array": return ctypes.array_constructor_callback - elif fullname == "functools.singledispatch": + if fullname == "functools.singledispatch": return singledispatch.create_singledispatch_function_callback + name_pieces = fullname.split(".") + if len(name_pieces) == 2 and name_pieces[0] == "dataclasses": + callbacks: dict[str, Callable[[FunctionContext], Type]] = { + "asdict": dataclasses.asdict_callback, + "astuple": dataclasses.astuple_callback, + "fields": dataclasses.fields_callback, + "replace": dataclasses.replace_callback, + } + return callbacks.get(name_pieces[1], None) return None def get_method_signature_hook( diff --git a/test-data/unit/check-dataclasses.test b/test-data/unit/check-dataclasses.test index c248f8db8585..68f99ae52767 100644 --- a/test-data/unit/check-dataclasses.test +++ b/test-data/unit/check-dataclasses.test @@ -2001,3 +2001,29 @@ class Bar(Foo): ... e: Element[Bar] reveal_type(e.elements) # N: Revealed type is "typing.Sequence[__main__.Element[__main__.Bar]]" [builtins fixtures/dataclasses.pyi] + +[case testFuncsOnlyTakeDataclassArg] +# flags: --python-version 3.7 +# Ensure asdict, astuple, fields, and replace methods from dataclasses module only accept +# a dataclass instance as the first argument. +# See mypy issue #14215 +from dataclasses import asdict, astuple, dataclass, fields, replace + +class Klass: + pass + +@dataclass +class DClass: + pass + +klass = Klass() +dclass = DClass() +asdict(dclass) +astuple(dclass) +replace(dclass) +fields(dclass) +asdict(klass) # E: asdict() should be called on dataclass instances +astuple(klass) # E: astuple() should be called on dataclass instances +fields(klass) # E: fields() should be called on dataclass instances +replace(klass) # E: replace() should be called on dataclass instances +[builtins fixtures/dataclasses.pyi] diff --git a/test-data/unit/lib-stub/dataclasses.pyi b/test-data/unit/lib-stub/dataclasses.pyi index bd33b459266c..35146fc55480 100644 --- a/test-data/unit/lib-stub/dataclasses.pyi +++ b/test-data/unit/lib-stub/dataclasses.pyi @@ -1,4 +1,7 @@ -from typing import Any, Callable, Generic, Mapping, Optional, TypeVar, overload, Type +from typing import ( + Any, Callable, Generic, Mapping, Optional, TypeVar, overload, Tuple, + Type, +) _T = TypeVar('_T') @@ -32,3 +35,19 @@ def field(*, class Field(Generic[_T]): pass + +@overload +def asdict(obj: Any) -> dict[str, Any]: ... + +@overload +def asdict(obj: Any, *, dict_factory: Callable[[list[tuple[str, Any]]], _T]) -> _T: ... + +@overload +def astuple(obj: Any) -> Tuple[Any, ...]: ... + +@overload +def astuple(obj: Any, *, tuple_factory: Callable[[list[Any]], _T]) -> _T: ... + +def replace(obj: _T, **changes: Any) -> _T: ... + +def fields(class_or_instance: Any) -> Tuple[Field[Any], ...]: ...