8000 DRAFT: Returning TypedDict for dataclasses.asdict by syastrov · Pull Request #8339 · python/mypy · GitHub
[go: up one dir, main page]

Skip to content

DRAFT: Returning TypedDict for dataclasses.asdict #8339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
24 changes: 20 additions & 4 deletions mypy/plugins/common.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import List, Optional, Union
from typing import List, Optional, Union, Set

from mypy.nodes import (
ARG_POS, MDEF, Argument, Block, CallExpr, Expression, SYMBOL_FUNCBASE_TYPES,
FuncDef, PassStmt, RefExpr, SymbolTableNode, Var, JsonDict,
)
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface, CheckerPluginInterface
from mypy.semanal import set_callable_name
from mypy.types import (
CallableType, Overloaded, Type, TypeVarDef, deserialize_type, get_proper_type,
)
TypedDictType, Instance, TPDICT_FB_NAMES)
from mypy.typevars import fill_typevars
from mypy.util import get_unique_redefinition_name
from mypy.typeops import try_getting_str_literals # noqa: F401 # Part of public API
Expand Down Expand Up @@ -134,8 +134,24 @@ def add_method(


def deserialize_and_fixup_type(
data: Union[str, JsonDict], api: SemanticAnalyzerPluginInterface
data: Union[str, JsonDict], api: SemanticAnalyzerPluginInterface
) -> Type:
typ = deserialize_type(data)
typ.accept(TypeFixer(api.modules, allow_missing=False))
return typ


def get_anonymous_typeddict_type(api: CheckerPluginInterface) -> Instance:
for type_fullname in TPDICT_FB_NAMES:
try:
anonymous_typeddict_type = api.named_generic_type(type_fullname, [])
if anonymous_typeddict_type is not None:
return anonymous_typeddict_type
except KeyError:
continue
raise RuntimeError("No TypedDict fallback type found")


def make_anonymous_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, Type]',
required_keys: Set[str]) -> TypedDictType:
return TypedDictType(fields, required_keys=required_keys, fallback=get_anonymous_typeddict_type(api))
98 changes: 93 additions & 5 deletions mypy/plugins/dataclasses.py
CD3F
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
"""Plugin that provides support for dataclasses."""

from typing import Dict, List, Set, Tuple, Optional
from collections import OrderedDict
from typing import Dict, List, Set, Tuple, Optional, FrozenSet, Callable

from typing_extensions import Final

from mypy.maptype import map_instance_to_supertype
from mypy.nodes import (
ARG_OPT, ARG_POS, MDEF, Argument, AssignmentStmt, CallExpr,
Context, Expression, JsonDict, NameExpr, RefExpr,
SymbolTableNode, TempNode, TypeInfo, Var, TypeVarExpr, PlaceholderNode
)
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.plugin import ClassDefContext, FunctionContext, CheckerPluginInterface
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import add_method, _get_decorator_bool_argument, make_anonymous_typeddict
from mypy.plugins.common import (
add_method, _get_decorator_bool_argument, deserialize_and_fixup_type,
deserialize_and_fixup_type,
)
from mypy.types import Type, Instance, NoneType, TypeVarDef, TypeVarType, get_proper_type
from mypy.server.trigger import make_wildcard_trigger
from mypy.typeops import tuple_fallback
from mypy.types import Instance, NoneType, TypeVarDef, TypeVarType, get_proper_type, Type, TupleType, UnionType, \
AnyType, TypeOfAny

# The set of decorators that generate dataclasses.
dataclass_makers = {
Expand All @@ -24,6 +31,10 @@
SELF_TVAR_NAME = '_DT' # type: Final


def is_type_dataclass(info: TypeInfo) -> bool:
return 'dataclass' in info.metadata


class DataclassAttribute:
def __init__(
self,
Expand Down Expand Up @@ -297,7 +308,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
# we'll have unmodified attrs laying around.
all_attrs = attrs.copy()
for info in cls.info.mro[1:-1]:
if 'dataclass' not in info.metadata:
if not is_type_dataclass(info):
continue

super_attrs = []
Expand Down Expand Up @@ -386,3 +397,80 @@ def _collect_field_args(expr: Expression) -> Tuple[bool, Dict[str, Expression]]:
args[name] = arg
return True, args
return False, {}


def asdict_callback(ctx: FunctionContext) -> Type:
positional_arg_types = ctx.arg_types[0]

if positional_arg_types:
if len(ctx.arg_types) == 2:
# We can't infer a more precise for calls where dict_factory is set.
# At least for now, typeshed stubs for asdict don't allow you to pass in `dict` as dict_factory,
# so we can't special-case that.
return ctx.default_return_type
dataclass_instance = positional_arg_types[0]
if isinstance(dataclass_instance, Instance):
info = dataclass_instance.type
if not is_type_dataclass(info):
ctx.api.fail('asdict() should be called on dataclass instances', dataclass_instance)
return _type_asdict(ctx.api, ctx.context, dataclass_instance)
return ctx.default_return_type


def _transform_type_args(*, typ: Instance, transform: Callable[[Instance], Type]) -> \
List[Type]:
"""For each type arg used in the Instance, call transform function on it if the arg is an Instance."""
return [transform(arg) if isinstance(arg, Instance) else arg for arg in typ.args]


def _type_asdict(api: CheckerPluginInterface, context: Context, typ: Type) -> Type:
"""Convert dataclasses into TypedDicts, recursively looking into built-in containers.

It will look for dataclasses inside of tuples, lists, and dicts and convert them to TypedDicts.
"""

def _type_asdict_inner(typ: Type, seen_dataclasses: FrozenSet[str]) -> Type:
if isinstance(typ, UnionType):
return UnionType([_type_asdict_inner(item, seen_dataclasses) for item in typ.items])
if isinstance(typ, Instance):
info = typ.type
if is_type_dataclass(info):
if info.fullname in seen_dataclasses:
api.fail("Recursive types are not supported in call to asdict, so falling back to Dict[str, Any]",
context)
# Note: Would be nicer to fallback to default_return_type, but that is Any (due to overloads?)
return api.named_generic_type('builtins.dict', [api.named_generic_type('builtins.str', []),
AnyType(TypeOfAny.implementation_artifact)])
seen_dataclasses |= {info.fullname}
attrs = info.metadata['dataclass']['attributes']
fields = OrderedDict() # type: OrderedDict[str, Type]
for data in attrs:
# TODO: DataclassAttribute.deserialize takes SemanticAnalyzerPluginInterface but we have
# CheckerPluginInterface here.
attr = DataclassAttribute.deserialize(info, data, api)
sym_node = info.names[attr.name]
typ = sym_node.type
assert typ is not None
fields[attr.name] = _type_asdict_inner(typ, seen_dataclasses)
return make_anonymous_typeddict(api, fields=fields, required_keys=set(fields.keys()))
elif info.has_base('builtins.list'):
supertype_instance = map_instance_to_supertype(typ, api.named_generic_type('builtins.list', []).type)
new_args = _transform_type_args(
typ=supertype_instance,
transform=lambda arg: _type_asdict_inner(arg, seen_dataclasses)
)
return api.named_generic_type('builtins.list', new_args)
elif info.has_base('builtins.dict'):
supertype_instance = map_instance_to_supertype(typ, api.named_generic_type('builtins.dict', []).type)
new_args = _transform_type_args(
typ=supertype_instance,
transform=lambda arg: _type_asdict_inner(arg, seen_dataclasses)
)
return api.named_generic_type('builtins.dict', new_args)
elif isinstance(typ, TupleType):
# TODO: Support subclasses/namedtuples properly
return TupleType([_type_asdict_inner(item, seen_dataclasses) for item in typ.items],
tuple_fallback(typ), implicit=typ.implicit)
return typ

return _type_asdict_inner(typ, seen_dataclasses=frozenset())
3 changes: 3 additions & 0 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@ class DefaultPlugin(Plugin):
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
from mypy.plugins import ctypes
from mypy.plugins import dataclasses

if fullname == 'contextlib.contextmanager':
return contextmanager_callback
elif fullname == 'builtins.open' and self.python_version[0] == 3:
return open_callback
elif fullname == 'ctypes.Array':
return ctypes.array_constructor_callback
elif fullname == 'dataclasses.asdict':
return dataclasses.asdict_callback
return None

def get_method_signature_hook(self, fullname: str
Expand Down
2 changes: 1 addition & 1 deletion mypy/typeshed
Submodule typeshed updated 783 files
Loading
0