8000 [cherry-pick] Fix mypy plugin for 1.1.0 (#5077) (#5111) · pydantic/pydantic@9d0edbe · GitHub
[go: up one dir, main page]

Skip to content

Commit 9d0edbe

Browse files
cdce8pdmontagu
andauthored
[cherry-pick] Fix mypy plugin for 1.1.0 (#5077) (#5111)
* Fix mypy plugin for 1.1.0 (#5077) * Fix mypy plugin for 1.1.0 * Code review * Add version key to plugin data (cherry picked from commit 6267ae3) * Change file name * Add the changes from #5120 * Update changes file * Remove additional unneeded dataclass import (from #5120) --------- Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com>
1 parent 7f3b754 commit 9d0edbe

File tree

4 files changed

+39
-15
lines changed

4 files changed

+39
-15
lines changed

changes/5111-cdce8p.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix mypy plugin for v1.1.1, and fix `dataclass_transform` decorator for pydantic dataclasses

pydantic/dataclasses.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class M:
3232
validation without altering default `M` behaviour.
3333
"""
3434
import copy
35+
import dataclasses
3536
import sys
3637
from contextlib import contextmanager
3738
from functools import wraps
@@ -93,7 +94,7 @@ def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
9394

9495
if sys.version_info >= (3, 10):
9596

96-
@dataclass_transform(kw_only_default=True, field_specifiers=(Field, FieldInfo))
97+
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
9798
@overload
9899
def dataclass(
99100
*,
@@ -110,7 +111,7 @@ def dataclass(
110111
) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
111112
...
112113

113-
@dataclass_transform(kw_only_default=True, field_specifiers=(Field, FieldInfo))
114+
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
114115
@overload
115116
def dataclass(
116117
_cls: Type[_T],
@@ -130,7 +131,7 @@ def dataclass(
130131

131132
else:
132133

133-
@dataclass_transform(kw_only_default=True, field_specifiers=(Field, FieldInfo))
134+
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
134135
@overload
135136
def dataclass(
136137
*,
@@ -146,7 +147,7 @@ def dataclass(
146147
) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
147148
...
148149

149-
@dataclass_transform(kw_only_default=True, field_specifiers=(Field, FieldInfo))
150+
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
150151
@overload
151152
def dataclass(
152153
_cls: Type[_T],
@@ -164,7 +165,7 @@ def dataclass(
164165
...
165166

166167

167-
@dataclass_transform(kw_only_default=True, field_specifiers=(Field, FieldInfo))
168+
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
168169
def dataclass(
169170
_cls: Optional[Type[_T]] = None,
170171
*,
@@ -188,8 +189,6 @@ def dataclass(
188189
the_config = get_config(config)
189190

190191
def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper':
191-
import dataclasses
192-
193192
should_use_proxy = (
194193
use_proxy
195194
if use_proxy is not None
@@ -328,7 +327,6 @@ def new_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
328327
if hasattr(self, '__post_init_post_parse__'):
329328
# We need to find again the initvars. To do that we use `__dataclass_fields__` instead of
330329
# public method `dataclasses.fields`
331-
import dataclasses
332330

333331
# get all initvars and their default values
334332
initvars_and_values: Dict[str, Any] = {}
@@ -377,8 +375,6 @@ def create_pydantic_model_from_dataclass(
377375
config: Type[Any] = BaseConfig,
378376
dc_cls_doc: Optional[str] = None,
379377
) -> Type['BaseModel']:
380-
import dataclasses
381-
382378
field_definitions: Dict[str, Any] = {}
383379
for field in dataclasses.fields(dc_cls):
384380
default: Any = Undefined
@@ -466,8 +462,6 @@ class B(A):
466462
In this case, when we first check `B`, we make an extra check and look at the annotations ('y'),
467463
which won't be a superset of all the dataclass fields (only the stdlib fields i.e. 'x')
468464
"""
469-
import dataclasses
470-
471465
return (
472466
dataclasses.is_dataclass(_cls)
473467
and not hasattr(_cls, '__pydantic_model__')

pydantic/main.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from .fields import (
3434
MAPPING_LIKE_SHAPES,
3535
Field,
36-
FieldInfo,
3736
ModelField,
3837
ModelPrivateAttr,
3938
PrivateAttr,
@@ -118,7 +117,7 @@ def hash_function(self_: Any) -> int:
118117
_is_base_model_class_defined = False
119118

120119

121-
@dataclass_transform(kw_only_default=True, field_specifiers=(Field, FieldInfo))
120+
@dataclass_transform(kw_only_default=True, field_specifiers=(Field,))
122121
class ModelMetaclass(ABCMeta):
123122
@no_type_check # noqa C901
124123
def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901

pydantic/mypy.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
METADATA_KEY = 'pydantic-mypy-metadata'
7777
BASEMODEL_FULLNAME = 'pydantic.main.BaseModel'
7878
BASESETTINGS_FULLNAME = 'pydantic.env_settings.BaseSettings'
79+
MODEL_METACLASS_FULLNAME = 'pydantic.main.ModelMetaclass'
7980
FIELD_FULLNAME = 'pydantic.fields.Field'
8081
DATACLASS_FULLNAME = 'pydantic.dataclasses.dataclass'
8182

@@ -87,6 +88,9 @@ def parse_mypy_version(version: str) -> Tuple[int, ...]:
8788
MYPY_VERSION_TUPLE = parse_mypy_version(mypy_version)
8889
BUILTINS_NAME = 'builtins' if MYPY_VERSION_TUPLE >= (0, 930) else '__builtins__'
8990

91+
# Increment version if plugin changes and mypy caches should be invalidated
92+
PLUGIN_VERSION = 1
93+
9094

9195
def plugin(version: str) -> 'TypingType[Plugin]':
9296
"""
@@ -102,6 +106,7 @@ class PydanticPlugin(Plugin):
102106
def __init__(self, options: Options) -> None:
103107
self.plugin_config = PydanticPluginConfig(options)
104108
self._plugin_data = self.plugin_config.to_data()
109+
self._plugin_data['version'] = PLUGIN_VERSION
105110
super().__init__(options)
106111

107112
def get_base_class_hook(self, fullname: str) -> 'Optional[Callable[[ClassDefContext], None]]':
@@ -112,6 +117,11 @@ def get_base_class_hook(self, fullname: str) -> 'Optional[Callable[[ClassDefCont
112117
return self._pydantic_model_class_maker_callback
113118
return None
114119

120+
def get_metaclass_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
121+
if fullname == MODEL_METACLASS_FULLNAME:
122+
return self._pydantic_model_metaclass_marker_callback
123+
return None
124+
115125
def get_function_hook(self, fullname: str) -> 'Optional[Callable[[FunctionContext], Type]]':
116126
sym = self.lookup_fully_qualified(fullname)
117127
if sym and sym.fullname == FIELD_FULLNAME:
@@ -139,6 +149,19 @@ def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> None:
139149
transformer = PydanticModelTransformer(ctx, self.plugin_config)
140150
transformer.transform()
141151

152+
def _pydantic_model_metaclass_marker_callback(self, ctx: ClassDefContext) -> None:
153+
"""Reset dataclass_transform_spec attribute of ModelMetaclass.
154+
155+
Let the plugin handle it. This behavior can be disabled
156+
if 'debug_dataclass_transform' is set to True', for testing purposes.
157+
"""
158+
if self.plugin_config.debug_dataclass_transform:
159+
return
160+
info_metaclass = ctx.cls.info.declared_metaclass
161+
assert info_metaclass, "callback not passed from 'get_metaclass_hook'"
162+
if getattr(info_metaclass.type, 'dataclass_transform_spec', None):
163+
info_metaclass.type.dataclass_transform_spec = None # type: ignore[attr-defined]
164+
142165
def _pydantic_field_callback(self, ctx: FunctionContext) -> 'Type':
143166
"""
144167
Extract the type of the `default` argument from the Field function, and use it as the return type.
@@ -194,11 +217,18 @@ def _pydantic_field_callback(self, ctx: FunctionContext) -> 'Type':
194217

195218

196219
class PydanticPluginConfig:
197-
__slots__ = ('init_forbid_extra', 'init_typed', 'warn_required_dynamic_aliases', 'warn_untyped_fields')
220+
__slots__ = (
221+
'init_forbid_extra',
222+
'init_typed',
223+
'warn_required_dynamic_aliases',
224+
'warn_untyped_fields',
225+
'debug_dataclass_transform',
226+
)
198227
init_forbid_extra: bool
199228
init_typed: bool
200229
warn_required_dynamic_aliases: bool
201230
warn_untyped_fields: bool
231+
debug_dataclass_transform: bool # undocumented
202232

203233
def __init__(self, options: Options) -> None:
204234
if options.config_file is None: # pragma: no cover

0 commit comments

Comments
 (0)
0