8000 Implement `__eq__` for validators (#8925) · encode/django-rest-framework@0d6ef03 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0d6ef03

Browse files
authored
Implement __eq__ for validators (#8925)
* Implement equality operator and add test coverage * Add documentation on implementation
1 parent b1cec51 commit 0d6ef03

File tree

3 files changed

+49
-1
lines changed

3 files changed

+49
-1
lines changed

docs/api-guide/validators.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ If we open up the Django shell using `manage.py shell` we can now
5353

5454
The interesting bit here is the `reference` field. We can see that the uniqueness constraint is being explicitly enforced by a validator on the serializer field.
5555

56-
Because of this more explicit style REST framework includes a few validator classes that are not available in core Django. These classes are detailed below.
56+
Because of this more explicit style REST framework includes a few validator classes that are not available in core Django. These classes are detailed below. REST framework validators, like their Django counterparts, implement the `__eq__` method, allowing you to compare instances for equality.
5757

5858
---
5959

rest_framework/validators.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,15 @@ def __repr__(self):
7979
smart_repr(self.queryset)
8080
)
8181

82+
def __eq__(self, other):
83+
if not isinstance(other, self.__class__):
84+
return NotImplemented
85+
return (self.message == other.message
86+
and self.requires_context == other.requires_context
87+
and self.queryset == other.queryset
88+
and self.lookup == other.lookup
89+
)
90+
8291

8392
class UniqueTogetherValidator:
8493
"""
@@ -166,6 +175,16 @@ def __repr__(self):
166175
smart_repr(self.fields)
167176
)
168177

178+
def __eq__(self, other):
179+
if not isinstance(other, self.__class__):
180+
return NotImplemented
181+
return (self.message == other.message
182+
and self.requires_context == other.requires_context
183+
and self.missing_message == other.missing_message
184+
and self.queryset == other.queryset
185+
and self.fields == other.fields
186+
)
187+
169188

170189
class ProhibitSurrogateCharactersValidator:
171190
message = _('Surrogate characters are not allowed: U+{code_point:X}.')
@@ -177,6 +196,13 @@ def __call__(self, value):
177196
message = self.message.format(code_point=ord(surrogate_character))
178197
raise ValidationError(message, code=self.code)
179198

199+
def __eq__(self, other):
200+
if not isinstance(other, self.__class__):
201+
return NotImplemented
202+
return (self.message == other.message
203+
and self.code == other.code
204+
)
205+
180206

181207
class BaseUniqueForValidator:
182208
message = None
@@ -230,6 +256,17 @@ def __call__(self, attrs, serializer):
230256
self.field: message
231257
}, code='unique')
232258

259+
def __eq__(self, other):
260+
if not isinstance(other, self.__class__):
261+
return NotImplemented
262+
return (self.message == other.message
263+
and self.missing_message == other.missing_message
264+
and self.requires_context == other.requires_context
265+
and self.queryset == other.queryset
266+
and self.field == other.field
267+
and self.date_field == other.date_field
268+
)
269+
233270
def __repr__(self):
234271
return '<%s(queryset=%s, field=%s, date_field=%s)>' % (
235272
self.__class__.__name__,

tests/test_validators.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import datetime
2+
from unittest.mock import MagicMock
23

34
import pytest
45
from django.db import DataError, models
@@ -787,3 +788,13 @@ def test_validator_raises_error_when_abstract_method_called(self):
787788
validator.filter_queryset(
788789
attrs=None, queryset=None, field_name='', date_field_name=''
789790
)
791+
792+
def test_equality_operator(self):
793+
mock_queryset = MagicMock()
794+
validator = BaseUniqueForValidator(queryset=mock_queryset, field='foo',
795+
date_field='bar')
796+
validator2 = BaseUniqueForValidator(queryset=mock_queryset, field='foo',
797+
date_field='bar')
798+
assert validator == validator2
799+
validator2.date_field = "bar2"
800+
assert validator != validator2

0 commit comments

Comments
 (0)
0