8000 Various typing improvements to fields module (#2723) · marshmallow-code/marshmallow@447eb86 · GitHub
[go: up one dir, main page]

Skip to content

Commit 447eb86

Browse files
authored
Various typing improvements to fields module (#2723)
* Various typing improvements to fields module * Update changelog
1 parent c188cdb commit 447eb86

File tree

3 files changed

+64
-43
lines changed

3 files changed

+64
-43
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Changelog
66

77
Features:
88

9+
- Typing: Improve typings in `marshmallow.fields` (:pr:`2723`).
910
- Typing: Replace type comments with inline typings (:pr:`2718`).
1011

1112
Bug fixes:

src/marshmallow/base.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@
1616
class FieldABC(ABC):
1717
"""Abstract base class from which all Field classes inherit."""
1818

19-
parent = None
20-
name = None
21-
root = None
22-
2319
@abstractmethod
2420
def serialize(self, attr, obj, accessor=None):
2521
pass

src/marshmallow/fields.py

Lines changed: 63 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,6 @@
7575
"Pluck",
7676
]
7777

78-
_T = typing.TypeVar("_T")
79-
8078

8179
class Field(FieldABC):
8280
"""Basic field from which other fields should extend. It applies no
@@ -132,7 +130,7 @@ class Field(FieldABC):
132130
#: Default error messages for various kinds of errors. The keys in this dictionary
133131
#: are passed to `Field.make_error`. The values are error messages passed to
134132
#: :exc:`marshmallow.exceptions.ValidationError`.
135-
default_error_messages = {
133+
default_error_messages: dict[str, str] = {
136134
"required": "Missing data for required field.",
137135
"null": "Field may not be null.",
138136
"validator_failed": "Invalid value.",
@@ -224,6 +222,10 @@ def __init__(
224222
messages.update(error_messages or {})
225223
self.error_messages = messages
226224

225+
self.parent: Field | Schema | None = None
226+
self.name: str | None = None
227+
self.root: Schema | None = None
228+
227229
def __repr__(self) -> str:
228230
return (
229231
f"<fields.{self.__class__.__name__}(dump_default={self.dump_default!r}, "
@@ -237,7 +239,15 @@ def __repr__(self) -> str:
237239
def __deepcopy__(self, memo):
238240
return copy.copy(self)
239241

240-
def get_value(self, obj, attr, accessor=None, default=missing_):
242+
def get_value(
243+
self,
244+
obj: typing.Any,
245+
attr: str,
246+
accessor: (
247+
typing.Callable[[typing.Any, str, typing.Any], typing.Any] | None
248+
) = None,
249+
default: typing.Any = missing_,
250+
):
241251
"""Return the value for a given key from an object.
242252
243253
:param object obj: The object to get the value from.
@@ -249,14 +259,14 @@ def get_value(self, obj, attr, accessor=None, default=missing_):
249259
check_key = attr if self.attribute is None else self.attribute
250260
return accessor_func(obj, check_key, default)
251261

252-
def _validate(self, value):
262+
def _validate(self, value: typing.Any):
253263
"""Perform validation on ``value``. Raise a :exc:`ValidationError` if validation
254264
does not succeed.
255265
"""
256266
self._validate_all(value)
257267

258268
@property
259-
def _validate_all(self):
269+
def _validate_all(self) -> typing.Callable[[typing.Any], None]:
260270
return And(*self.validators, error=self.error_messages["validator_failed"])
261271

262272
def make_error(self, key: str, **kwargs) -> ValidationError:
@@ -290,7 +300,7 @@ def fail(self, key: str, **kwargs):
290300
)
291301
raise self.make_error(key=key, **kwargs)
292302

293-
def _validate_missing(self, value):
303+
def _validate_missing(self, value: typing.Any) -> None:
294304
"""Validate missing values. Raise a :exc:`ValidationError` if
295305
`value` should be considered missing.
296306
"""
@@ -357,7 +367,7 @@ def deserialize(
357367

358368
# Methods for concrete classes to override.
359369

360-
def _bind_to_schema(self, field_name, schema):
370+
def _bind_to_schema(self, field_name: str, schema: Schema) -> None:
361371
"""Update field with values from its parent schema. Called by
362372
:meth:`Schema._bind_field <marshmallow.Schema._bind_field>`.
363373
@@ -372,7 +382,7 @@ def _bind_to_schema(self, field_name, schema):
372382

373383
def _serialize(
374384
self, value: typing.Any, attr: str | None, obj: typing.Any, **kwargs
375-
):
385+
) -> typing.Any:
376386
"""Serializes ``value`` to a basic Python datatype. Noop by default.
377387
Concrete :class:`Field` classes should implement this method.
378388
@@ -398,7 +408,7 @@ def _deserialize(
398408
attr: str | None,
399409
data: typing.Mapping[str, typing.Any] | None,
400410
**kwargs,
401-
):
411+
) -> typing.Any:
402412
"""Deserialize value. Concrete :class:`Field` classes should implement this method.
403413
404414
:param value: The value to be deserialized.
@@ -416,9 +426,11 @@ def _deserialize(
416426
# Properties
417427

418428
@property
419-
def context(self):
429+
def context(self) -> dict | None:
420430
"""The context dictionary for the parent :class:`Schema`."""
421-
return self.parent.context
431+
if self.parent:
432+
return self.parent.context
433+
return None
422434

423435
# the default and missing properties are provided for compatibility and
424436
# emit warnings when they are accessed and set
@@ -630,12 +642,14 @@ def _serialize(self, nested_obj, attr, obj, **kwargs):
630642
many = schema.many or self.many
631643
return schema.dump(nested_obj, many=many)
632644

633-
def _test_collection(self, value):
645+
def _test_collection(self, value: typing.Any) -> None:
634646
many = self.schema.many or self.many
635647
if many and not utils.is_collection(value):
636648
raise self.make_error("type", input=value, type=value.__class__.__name__)
637649

638-
def _load(self, value, data, partial=None):
650+
def _load(
651+
self, value: typing.Any, partial: bool | types.StrSequenceOrSet | None = None
652+
):
639653
try:
640654
valid_data = self.schema.load(value, unknown=self.unknown, partial=partial)
641655
except ValidationError as error:
@@ -644,7 +658,14 @@ def _load(self, value, data, partial=None):
644658
) from error
645659
return valid_data
646660

647-
def _deserialize(self, value, attr, data, partial=None, **kwargs):
661+
def _deserialize(
662+
self,
663+
value: typing.Any,
664+
attr: str | None,
665+
data: typing.Mapping[str, typing.Any] | None = None,
666+
partial: bool | types.StrSequenceOrSet | None = None,
667+
**kwargs,
668+
) -> typing.Any:
648669
"""Same as :meth:`Field._deserialize` with additional ``partial`` argument.
649670
650671
:param bool|tuple partial: For nested schemas, the ``partial``
@@ -654,7 +675,7 @@ def _deserialize(self, value, attr, data, partial=None, **kwargs):
654675
Add ``partial`` parameter.
655676
"""
656677
self._test_collection(value)
657-
return self._load(value, data, partial=partial)
678+
return self._load(value, partial=partial)
658679

659680

660681
class Pluck(Nested):
@@ -694,7 +715,7 @@ def __init__(
694715
self.field_name = field_name
695716

696717
@property
697-
def _field_data_key(self):
718+
def _field_data_key(self) -> str:
698719
only_field = self.schema.fields[self.field_name]
699720
return only_field.data_key or self.field_name
700721

@@ -712,7 +733,7 @@ def _deserialize(self, value, attr, data, partial=None, **kwargs):
712733
value = [{self._field_data_key: v} for v in value]
713734
else:
714735
value = {self._field_data_key: value}
715-
return self._load(value, data, partial=partial)
736+
return self._load(value, partial=partial)
716737

717738

718739
class List(Field):
@@ -746,7 +767,7 @@ def __init__(self, cls_or_instance: Field | type[Field], **kwargs):
746767
self.only = self.inner.only
747768
self.exclude = self.inner.exclude
748769

749-
def _bind_to_schema(self, field_name, schema):
770+
def _bind_to_schema(self, field_name: str, schema: Schema) -> None:
750771
super()._bind_to_schema(field_name, schema)
751772
self.inner = copy.deepcopy(self.inner)
752773
self.inner._bind_to_schema(field_name, self)
@@ -790,7 +811,7 @@ class Tuple(Field):
790811
`typing.NamedTuple`, using a Schema within a Nested field for them is
791812
more appropriate than using a `Tuple` field.
792813
793-
:param Iterable[Field] tuple_fields: An iterable of field classes or
814+
:param tuple_fields: An iterable of field classes or
794815
instances.
795816
:param kwargs: The same keyword arguments that :class:`Field` receives.
796817
@@ -800,7 +821,7 @@ class Tuple(Field):
800821
#: Default error messages.
801822
default_error_messages = {"invalid": "Not a valid tuple."}
802823

803-
def __init__(self, tuple_fields, *args, **kwargs):
824+
def __init__(self, tuple_fields: typing.Iterable[Field], *args, **kwargs):
804825
super().__init__(*args, **kwargs)
805826
if not utils.is_collection(tuple_fields):
806827
raise ValueError(
@@ -820,7 +841,7 @@ def __init__(self, tuple_fields, *args, **kwargs):
820841

821842
self.validate_length = Length(equal=len(self.tuple_fields))
822843

823-
def _bind_to_schema(self, field_name, schema):
844+
def _bind_to_schema(self, field_name: str, schema: Schema) -> None:
824845
super()._bind_to_schema(field_name, schema)
825846
new_tuple_fields = []
826847
for field in self.tuple_fields:
@@ -910,7 +931,10 @@ def _deserialize(self, value, attr, data, **kwargs) -> uuid.UUID | None:
910931
return self._validated(value)
911932

912933

913-
class Number(Field):
934+
_NumType = typing.TypeVar("_NumType")
935+
936+
937+
class Number(Field, typing.Generic[_NumType]):
914938
"""Base class for number fields.
915939
916940
:param bool as_string: If `True`, format the serialized value as a string.
@@ -929,14 +953,12 @@ def __init__(self, *, as_string: bool = False, **kwargs):
929953
self.as_string = as_string
930954
super().__init__(**kwargs)
931955

932-
def _format_num(self, value) -> typing.Any:
956+
def _format_num(self, value) -> _NumType:
933957
"""Return the number value for value, given this field's `num_type`."""
934958
return self.num_type(value)
935959

936-
def _validated(self, value) -> _T | None:
960+
def _validated(self, value: typing.Any) -> _NumType:
937961
"""Format the value or raise a :exc:`ValidationError` if an error occurs."""
938-
if value is None:
939-
return None
940962
# (value is True or value is False) is ~5x faster than isinstance(value, bool)
941963
if value is True 10000 or value is False:
942964
raise self.make_error("invalid", input=value)
@@ -947,21 +969,21 @@ def _validated(self, value) -> _T | None:
947969
except OverflowError as error:
948970
raise self.make_error("too_large", input=value) from error
949971

950-
def _to_string(self, value) -> str:
972+
def _to_string(self, value: _NumType) -> str:
951973
return str(value)
952974

953-
def _serialize(self, value, attr, obj, **kwargs) -> str | _T | None:
975+
def _serialize(self, value, attr, obj, **kwargs) -> str | _NumType | None:
954976
"""Return a string if `self.as_string=True`, otherwise return this field's `num_type`."""
955977
if value is None:
956978
return None
957-
ret: _T = self._format_num(value)
979+
ret: _NumType = self._format_num(value)
958980
return self._to_string(ret) if self.as_string else ret
959981

960-
def _deserialize(self, value, attr, data, **kwargs) -> _T | None:
982+
def _deserialize(self, value, attr, data, **kwargs) -> _NumType | None:
961983
return self._validated(value)
962984

963985

964-
class Integer(Number):
986+
class Integer(Number[int]):
965987
"""An integer field.
966988
967989
:param strict: If `True`, only integer types are valid.
@@ -979,13 +1001,13 @@ def __init__(self, *, strict: bool = False, **kwargs):
9791001
super().__init__(**kwargs)
9801002

9811003
# override Number
982-
def _validated(self, value):
1004+
def _validated(self, value: typing.Any) -> int:
9831005
if self.strict and not isinstance(value, numbers.Integral):
9841006
raise self.make_error("invalid", input=value)
9851007
return super()._validated(value)
9861008

9871009

988-
class Float(Number):
1010+
class Float(Number[float]):
9891011
"""A double as an IEEE-754 double precision string.
9901012
9911013
:param bool allow_nan: If `True`, `NaN`, `Infinity` and `-Infinity` are allowed,
@@ -1005,15 +1027,15 @@ def __init__(self, *, allow_nan: bool = False, as_string: bool = False, **kwargs
10051027
self.allow_nan = allow_nan
10061028
super().__init__(as_string=as_string, **kwargs)
10071029

1008-
def _validated(self, value):
1030+
def _validated(self, value: typing.Any) -> float:
10091031
num = super()._validated(value)
10101032
if self.allow_nan is False:
10111033
if math.isnan(num) or num == float("inf") or num == float("-inf"):
10121034
raise self.make_error("special")
10131035
return num
10141036

10151037

1016-
class Decimal(Number):
1038+
class Decimal(Number[decimal.Decimal]):
10171039
"""A field that (de)serializes to the Python ``decimal.Decimal`` type.
10181040
It's safe to use when dealing with money values, percentages, ratios
10191041
or other numbers where precision is critical.
@@ -1084,7 +1106,7 @@ def _format_num(self, value):
10841106
return num
10851107

10861108
# override Number
1087-
def _validated(self, value):
1109+
def _validated(self, value: typing.Any) -> decimal.Decimal:
10881110
try:
10891111
num = super()._validated(value)
10901112
except decimal.InvalidOperation as error:
@@ -1094,7 +1116,7 @@ def _validated(self, value):
10941116
return num
10951117

10961118
# override Number
1097-
def _to_string(self, value):
1119+
def _to_string(self, value: decimal.Decimal) -> str:
10981120
return format(value, "f")
10991121

11001122

@@ -1168,7 +1190,9 @@ def __init__(
11681190
if falsy is not None:
11691191
self.falsy = set(falsy)
11701192

1171-
def _serialize(self, value, attr, obj, **kwargs):
1193+
def _serialize(
1194+
self, value: typing.Any, attr: str | None, obj: typing.Any, **kwargs
1195+
):
11721196
if value is None:
11731197
return None
11741198

0 commit comments

Comments
 (0)
0