8000 bpo-37685: Fixed __eq__, __lt__ etc implementations in some classes. by serhiy-storchaka · Pull Request #14952 · python/cpython · GitHub
[go: up one dir, main page]

Skip to content

bpo-37685: Fixed __eq__, __lt__ etc implementations in some classes. #14952

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 8, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
bpo-37685: Fixed __eq__, __lt__ etc implementations in some classes.
They now return NotImplemented for unsupported type of the other operand.
  • Loading branch information
serhiy-storchaka committed Jul 25, 2019
commit 4b4b2f1b849e448aabbf406b3092a57df6005bba
16 changes: 8 additions & 8 deletions Lib/asyncio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,22 @@ def __hash__(self):
return hash(self._when)

def __lt__(self, other):
return self._when < other._when
if isinstance(other, TimerHandle):
return self._when < other._when
return NotImplemented

def __le__(self, other):
if self._when < other._when:
if isinstance(other, TimerHandle) and self._when < other._when:
return True
return self.__eq__(other)

def __gt__(self, other):
return self._when > other._when
if isinstance(other, TimerHandle):
return self._when > other._when
return NotImplemented

def __ge__(self, other):
if self._when > other._when:
if isinstance(other, TimerHandle) and self._when > other._when:
return True
return self.__eq__(other)

Expand All @@ -142,10 +146,6 @@ def __eq__(self, other):
self._cancelled == other._cancelled)
return NotImplemented

def __ne__(self, other):
equal = self.__eq__(other)
return NotImplemented if equal is NotImplemented else not equal

def cancel(self):
if not self._cancelled:
self._loop._timer_handle_cancelled(self)
Expand Down
22 changes: 11 additions & 11 deletions Lib/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,25 +739,25 @@ def __le__(self, other):
if isinstance(other, timedelta):
return self._cmp(other) <= 0
else:
_cmperror(self, other)
return NotImplemented
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Heh, we recently changed these and I guess I just assumed that _cmperror returned NotImplemented.

I agree with this change, but I would argue that it precludes backporting to 3.7 (and possibly 3.8), because it is a pretty significant change to the semantics of the language.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change fixes a discrepancy between Python and C implementations. Other changes are discussable, but this change is a bugfix and it should be backported.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, if that is the case, I agree about backporting the subset of these changes that fixes discrepancies between the Python and C implementations. We should figure out what that subset is and break it into two PRs - one for backporting and one for master (and maybe 3.8).


def __lt__(self, other):
if isinstance(other, timedelta):
return self._cmp(other) < 0
else:
_cmperror(self, other)
return NotImplemented

def __ge__(self, other):
if isinstance(other, timedelta):
return self._cmp(other) >= 0
else:
_cmperror(self, other)
return NotImplemented

def __gt__(self, other):
if isinstance(other, timedelta):
return self._cmp(other) > 0
else:
_cmperror(self, other)
return NotImplemented

def _cmp(self, other):
assert isinstance(other, timedelta)
Expand Down Expand Up @@ -1316,25 +1316,25 @@ def __le__(self, other):
if isinstance(other, time):
return self._cmp(other) <= 0
else:
_cmperror(self, other)
return NotImplemented

def __lt__(self, other):
if isinstance(other, time):
return self._cmp(other) < 0
else:
_cmperror(self, other)
return NotImplemented

def __ge__(self, other):
if isinstance(other, time):
return self._cmp(other) >= 0
else:
_cmperror(self, other)
return NotImplemented

def __gt__(self, other):
if isinstance(other, time):
return self._cmp(other) > 0
else:
_cmperror(self, other)
return NotImplemented

def _cmp(self, other, allow_mixed=False):
assert isinstance(other, time)
Expand Down Expand Up @@ -2210,9 +2210,9 @@ def __getinitargs__(self):
return (self._offset, self._name)

def __eq__(self, other):
if type(other) != timezone:
return False
return self._offset == other._offset
if isinstance(other, timezone):
return self._offset == other._offset
return NotImplemented

def __hash__(self):
return hash(self._offset)
Expand Down
16 changes: 16 additions & 0 deletions Lib/distutils/tests/test_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ def test_cmp_strict(self):
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = StrictVersion(v1)._cmp(v2)
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = StrictVersion(v1)._cmp(object())
self.assertIs(res, NotImplemented,
'cmp(%s, %s) should be NotImplemented, got %s' %
(v1, v2, res))


def test_cmp(self):
Expand All @@ -63,6 +71,14 @@ def test_cmp(self):
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = LooseVersion(v1)._cmp(v2)
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = LooseVersion(v1)._cmp(object())
self.assertIs(res, NotImplemented,
'cmp(%s, %s) should be NotImplemented, got %s' %
(v1, v2, res))

def test_suite():
return unittest.makeSuite(VersionTestCase)
Expand Down
4 changes: 4 additions & 0 deletions Lib/distutils/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ def __str__ (self):
def _cmp (self, other):
if isinstance(other, str):
other = StrictVersion(other)
elif not isinstance(other, StrictVersion):
return NotImplemented

if self.version != other.version:
# numeric versions don't match
Expand Down Expand Up @@ -331,6 +333,8 @@ def __repr__ (self):
def _cmp (self, other):
if isinstance(other, str):
other = LooseVersion(other)
elif not isinstance(other, LooseVersion):
return NotImplemented

if self.version == other.version:
return 0
Expand Down
8 changes: 4 additions & 4 deletions Lib/email/headerregistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def __str__(self):
return self.addr_spec

def __eq__(self, other):
if type(other) != type(self):
return False
if not isinstance(other, Address):
return NotImplemented
return (self.display_name == other.display_name and
self.username == other.username and
self.domain == other.domain)
Expand Down Expand Up @@ -150,8 +150,8 @@ def __str__(self):
return "{}:{};".format(disp, adrstr)

def __eq__(self, other):
if type(other) != type(self):
return False
if not isinstance(other, Group):
return NotImplemented
return (self.display_name == other.display_name and
self.addresses == other.addresses)

Expand Down
2 changes: 1 addition & 1 deletion Lib/importlib/_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def __eq__(self, other):
self.cached == other.cached and
self.has_location == other.has_location)
except AttributeError:
return False
return NotImplemented

@property
def cached(self):
Expand Down
58 changes: 29 additions & 29 deletions Lib/test/datetimetester.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import re
import struct
import unittest
from unittest.mock import ANY
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have a dedicated class for this in datetimetester.py to avoid relying on the implementation details of ANY.

I think we can move it into tests.support if you are going to be adding tests like this across multiple modules.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not think we need yet one ANY in private tests.support if we can use public unittest.mock.ANY.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unittest.mock.ANY only accidentally has the property that it compares equal to things. Even if that were a guaranteed property (it probably won't change, but it's an implementation detail), it's still confusing to see people using the "wild card for arguments" singleton from unittest.mock for one of its incidental properties (comparing equal to things).

If someone tried to understand this test and looked it up in the documentation, they would be confused as to what it means. I think it is much clearer to use a dedicated class for this (particularly since we already have one in the datetime tests.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That class was added 15 days ago. Perhaps the author was not aware of mock.ANY. There are explicit tests for this property of mock.ANY, it is not incidental.

It is not very difficult to add a duplicate in test.support, but I am not sure that it is worth.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are explicit tests for this property of mock.ANY, it is not incidental.

I don't know about that, people test implementation details all the time.

That class was added 15 days ago. Perhaps the author was not aware of mock.ANY.

Definitely they were. The original PR used mock.ANY and I asked for it to be replaced with a dedicated class. You can see my comment on the original PR.

If you look at the documentation for ANY, it is clearly documented as a wild card to be used for making assertions about some calls to a function but not others. Nothing in the documentation says that it is implemented with a permissive __eq__, and in fact a permissive __eq__ is not always sufficient, as we learned in #14700. The ANY class does have the property you want, but it's a dedicated singleton with a semantic meaning other than "this compares equal to anything".

For reasons of clarity, I think it makes sense to have dedicated singletons for each of the custom comparison behaviors, even if there are other objects that also have those properties for their own reasons.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change has been extracted into #14996. Please take a look.


from array import array

Expand Down Expand Up @@ -353,6 +354,18 @@ def test_comparison(self):
self.assertTrue(timezone(ZERO) != None)
self.assertFalse(timezone(ZERO) == None)

tz = timezone(ZERO)
largest = support.LargestObject()
self.assertTrue(tz < largest)
self.assertFalse(tz > largest)
self.assertTrue(tz <= largest)
self.assertFalse(tz >= largest)
self.assertFalse(tz == largest)
self.assertTrue(tz != largest)

self.assertTrue(tz == ANY)
self.assertFalse(tz != ANY)

def test_aware_datetime(self):
# test that timezone instances can be used by datetime
t = datetime(1, 1, 1)
Expand Down Expand Up @@ -414,10 +427,20 @@ def test_harmless_mixed_comparison(self):

# Comparison to objects of unsupported types should return
# NotImplemented which falls back to the right hand side's __eq__
# method. In this case, ComparesEqualClass.__eq__ always returns True.
# ComparesEqualClass.__ne__ always returns False.
self.assertTrue(me == ComparesEqualClass())
self.assertFalse(me != ComparesEqualClass())
# method. In this case, ANY.__eq__ always returns True.
# ANY.__ne__ always returns False.
self.assertTrue(me == ANY)
self.assertFalse(me != ANY)

# If the other class explicitly defines ordering
# relative to our class, it is allowed to do so
largest = support.LargestObject()
self.assertTrue(me < largest)
self.assertFalse(me > largest)
self.assertTrue(me <= largest)
self.assertFalse(me >= largest)
self.assertFalse(me == largest)
self.assertTrue(me != largest)

def test_harmful_mixed_comparison(self):
me = self.theclass(1, 1, 1)
Expand Down Expand Up @@ -1582,29 +1605,6 @@ class SomeClass:
self.assertRaises(TypeError, lambda: our < their)
self.assertRaises(TypeError, lambda: their < our)

# However, if the other class explicitly defines ordering
# relative to our class, it is allowed to do so

class LargerThanAnything:
def __lt__(self, other):
return False
def __le__(self, other):
return isinstance(other, LargerThanAnything)
def __eq__(self, other):
return isinstance(other, LargerThanAnything)
def __gt__(self, other):
return not isinstance(other, LargerThanAnything)
def __ge__(self, other):
return True

their = LargerThanAnything()
self.assertEqual(our == their, False)
self.assertEqual(their == our, False)
self.assertEqual(our != their, True)
self.assertEqual(their != our, True)
self.assertEqual(our < their, True)
self.assertEqual(their < our, False)

def test_bool(self):
# All dates are considered true.
self.assertTrue(self.theclass.min)
Expand Down Expand Up @@ -3781,8 +3781,8 @@ def test_replace(self):
self.assertRaises(ValueError, base.replace, microsecond=1000000)

def test_mixed_compare(self):
t1 = time(1, 2, 3)
t2 = time(1, 2, 3)
t1 = self.theclass(1, 2, 3)
t2 = self.theclass(1, 2, 3)
self.assertEqual(t1, t2)
t2 = t2.replace(tzinfo=None)
self.assertEqual(t1, t2)
Expand Down
15 changes: 15 additions & 0 deletions Lib/test/support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3093,6 +3093,21 @@ def __fspath__(self):
return self.path


@functools.total_ordering
class LargestObject:
Copy link
Member
@pganssle pganssle Jul 26, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about making these singletons?

LargestObject = LargestObject() and SmallestObject = SmallestObject().

I don't think anyone needs to subclass these and there's no need to instantiate more than one of them.

Similarly we can move ComparesEqualClass into this module and make it a singleton.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea.

def __eq__(self, other):
return isinstance(other, LargestObject)
def __lt__(self, other):
return False

@functools.total_ordering
class SmallestObject:
def __eq__(self, other):
return isinstance(other, SmallestObject)
def __gt__(self, other):
return False


def maybe_get_event_loop_policy():
"""Return the global event loop policy if one is set, else return None."""
return asyncio.events._event_loop_policy
Expand Down
22 changes: 22 additions & 0 deletions Lib/test/test_asyncio/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -2364,6 +2364,28 @@ def callback(*args):
self.assertIs(NotImplemented, h1.__eq__(h3))
self.assertIs(NotImplemented, h1.__ne__(h3))

with self.assertRaises(TypeError):
h1 < ()
with self.assertRaises(TypeError):
h1 > ()
with self.assertRaises(TypeError):
h1 <= ()
with self.assertRaises(TypeError):
h1 >= ()
self.assertFalse(h1 == ())
self.assertTrue(h1 != ())

largest = support.LargestObject()
self.assertTrue(h1 < largest)
self.assertFalse(h1 > largest)
self.assertTrue(h1 <= largest)
self.assertFalse(h1 >= largest)
self.assertFalse(h1 == largest)
self.assertTrue(h1 != largest)

self.assertTrue(h1 == mock.ANY)
self.assertFalse(h1 != mock.ANY)


class AbstractEventLoopTests(unittest.TestCase):

Expand Down
19 changes: 19 additions & 0 deletions Lib/test/test_email/test_headerregistry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import textwrap
import unittest
from unittest.mock import ANY
from email import errors
from email import policy
from email.message import Message
Expand Down Expand Up @@ -1525,6 +1526,24 @@ def test_set_message_header_from_group(self):
self.assertEqual(m['to'], 'foo bar:;')
self.assertEqual(m['to'].addresses, g.addresses)

def test_address_comparison(self):
a = Address('foo', 'bar', 'example.com')
self.assertEqual(Address('foo', 'bar', 'example.com'), a)
self.assertNotEqual(Address('baz', 'bar', 'example.com'), a)
self.assertNotEqual(Address('foo', 'baz', 'example.com'), a)
self.assertNotEqual(Address('foo', 'bar', 'baz'), a)
self.assertFalse(a == object())
self.assertTrue(a == ANY)

def test_group_comparison(self):
a = Address('foo', 'bar', 'example.com')
g = Group('foo bar', [a])
self.assertEqual(Group('foo bar', (a,)), g)
self.assertNotEqual(Group('baz', [a]), g)
self.assertNotEqual(Group('foo bar', []), g)
self.assertFalse(g == object())
self.assertTrue(g == ANY)


class TestFolding(TestHeaderBase):

Expand Down
15 changes: 1 addition & 14 deletions Lib/test/test_ipaddress.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pickle
import ipaddress
import weakref
from test.support import LargestObject, SmallestObject


class BaseTestCase(unittest.TestCase):
Expand Down Expand Up @@ -673,20 +674,6 @@ def test_ip_network(self):
self.assertFactoryError(ipaddress.ip_network, "network")


@functools.total_ordering
class LargestObject:
def __eq__(self, other):
return isinstance(other, LargestObject)
def __lt__(self, other):
return False

@functools.total_ordering
class SmallestObject:
def __eq__(self, other):
return isinstance(other, SmallestObject)
def __gt__(self, other):
return False

class ComparisonTests(unittest.TestCase):

v4addr = ipaddress.IPv4Address(1)
Expand Down
Loading
0