8000 bpo-37685: Fixed __eq__, __lt__ etc implementations in some classes. by serhiy-storchaka · Pull Request #14952 · python/cpython · GitHub
[go: up one dir, main page]

Skip to content

bpo-37685: Fixed __eq__, __lt__ etc implementations in some classes. #14952

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions Lib/asyncio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,20 +119,24 @@ def __hash__(self):
return hash(self._when)

def __lt__(self, other):
return self._when < other._when
if isinstance(other, TimerHandle):
return self._when < other._when
return NotImplemented

def __le__(self, other):
if self._when < other._when:
return True
return self.__eq__(other)
if isinstance(other, TimerHandle):
return self._when < other._when or self.__eq__(other)
return NotImplemented

def __gt__(self, other):
return self._when > other._when
if isinstance(other, TimerHandle):
return self._when > other._when
return NotImplemented

def __ge__(self, other):
if self._when > other._when:
return True
return self.__eq__(other)
if isinstance(other, TimerHandle):
return self._when > other._when or self.__eq__(other)
return NotImplemented

def __eq__(self, other):
if isinstance(other, TimerHandle):
Expand All @@ -142,10 +146,6 @@ def __eq__(self, other):
self._cancelled == other._cancelled)
return NotImplemented

def __ne__(self, other):
equal = self.__eq__(other)
return NotImplemented if equal is NotImplemented else not equal

def cancel(self):
if not self._cancelled:
self._loop._timer_handle_cancelled(self)
Expand Down
16 changes: 16 additions & 0 deletions Lib/distutils/tests/test_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ def test_cmp_strict(self):
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = StrictVersion(v1)._cmp(v2)
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = StrictVersion(v1)._cmp(object())
self.assertIs(res, NotImplemented,
'cmp(%s, %s) should be NotImplemented, got %s' %
(v1, v2, res))


def test_cmp(self):
Expand All @@ -63,6 +71,14 @@ def test_cmp(self):
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = LooseVersion(v1)._cmp(v2)
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = LooseVersion(v1)._cmp(object())
self.assertIs(res, NotImplemented,
'cmp(%s, %s) should be NotImplemented, got %s' %
(v1, v2, res))

def test_suite():
return unittest.makeSuite(VersionTestCase)
Expand Down
4 changes: 4 additions & 0 deletions Lib/distutils/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ def __str__ (self):
def _cmp (self, other):
if isinstance(other, str):
other = StrictVersion(other)
elif not isinstance(other, StrictVersion):
return NotImplemented

if self.version != other.version:
# numeric versions don't match
Expand Down Expand Up @@ -331,6 +333,8 @@ def __repr__ (self):
def _cmp (self, other):
if isinstance(other, str):
other = LooseVersion(other)
elif not isinstance(other, LooseVersion):
return NotImplemented

if self.version == other.version:
return 0
Expand Down
8 changes: 4 additions & 4 deletions Lib/email/headerregistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def __str__(self):
return self.addr_spec

def __eq__(self, other):
if type(other) != type(self):
return False
if not isinstance(other, Address):
return NotImplemented
return (self.display_name == other.display_name and
self.username == other.username and
self.domain == other.domain)
Expand Down Expand Up @@ -150,8 +150,8 @@ def __str__(self):
return "{}:{};".format(disp, adrstr)

def __eq__(self, other):
if type(other) != type(self):
return False
if not isinstance(other, Group):
return NotImplemented
return (self.display_name == other.display_name and
self.addresses == other.addresses)

Expand Down
2 changes: 1 addition & 1 deletion Lib/importlib/_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def __eq__(self, other):
self.cached == other.cached and
self.has_location == other.has_location)
except AttributeError:
return False
return NotImplemented

@property
def cached(self):
Expand Down
23 changes: 23 additions & 0 deletions Lib/test/test_asyncio/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from asyncio import selector_events
from test.test_asyncio import utils as test_utils
from test import support
from test.support import ALWAYS_EQ, LARGEST, SMALLEST


def tearDownModule():
Expand Down Expand Up @@ -2364,6 +2365,28 @@ def callback(*args):
self.assertIs(NotImplemented, h1.__eq__(h3))
self.assertIs(NotImplemented, h1.__ne__(h3))

with self.assertRaises(TypeError):
h1 < ()
with self.assertRaises(TypeError):
h1 > ()
with self.assertRaises(TypeError):
h1 <= ()
with self.assertRaises(TypeError):
h1 >= ()
self.assertFalse(h1 == ())
self.assertTrue(h1 != ())

self.assertTrue(h1 == ALWAYS_EQ)
self.assertFalse(h1 != ALWAYS_EQ)
self.assertTrue(h1 < LARGEST)
self.assertFalse(h1 > LARGEST)
self.assertTrue(h1 <= LARGEST)
self.assertFalse(h1 >= LARGEST)
self.assertFalse(h1 < SMALLEST)
self.assertTrue(h1 > SMALLEST)
self.assertFalse(h1 <= SMALLEST)
self.assertTrue(h1 >= SMALLEST)


class AbstractEventLoopTests(unittest.TestCase):

Expand Down
19 changes: 19 additions & 0 deletions Lib/test/test_email/test_headerregistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from test.test_email import TestEmailBase, parameterize
from email import headerregistry
from email.headerregistry import Address, Group
from test.support import ALWAYS_EQ


DITTO = object()
Expand Down Expand Up @@ -1525,6 +1526,24 @@ def test_set_message_header_from_group(self):
self.assertEqual(m['to'], 'foo bar:;')
self.assertEqual(m['to'].addresses, g.addresses)

def test_address_comparison(self):
a = Address('foo', 'bar', 'example.com')
self.assertEqual(Address('foo', 'bar', 'example.com'), a)
self.assertNotEqual(Address('baz', 'bar', 'example.com'), a)
self.assertNotEqual(Address('foo', 'baz', 'example.com'), a)
self.assertNotEqual(Address('foo', 'bar', 'baz'), a)
self.assertFalse(a == object())
self.assertTrue(a == ALWAYS_EQ)

def test_group_comparison(self):
a = Address('foo', 'bar', 'example.com')
g = Group('foo bar', [a])
self.assertEqual(Group('foo bar', (a,)), g)
self.assertNotEqual(Group('baz', [a]), g)
self.assertNotEqual(Group('foo bar', []), g)
self.assertFalse(g == object())
self.assertTrue(g == ALWAYS_EQ)


class TestFolding(TestHeaderBase):

Expand Down
16 changes: 15 additions & 1 deletion Lib/test/test_traceback.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import unittest
import re
from test import support
from test.support import TESTFN, Error, captured_output, unlink, cpython_only
from test.support import TESTFN, Error, captured_output, unlink, cpython_only, ALWAYS_EQ
from test.support.script_helper import assert_python_ok
import textwrap

Expand Down Expand Up @@ -887,6 +887,8 @@ def test_basics(self):
# operator fallbacks to FrameSummary.__eq__.
self.assertEqual(tuple(f), f)
self.assertIsNone(f.locals)
self.assertNotEqual(f, object())
self.assertEqual(f, ALWAYS_EQ)

def test_lazy_lines(self):
linecache.clearcache()
Expand Down Expand Up @@ -1083,6 +1085,18 @@ def test_context(self):
self.assertEqual(exc_info[0], exc.exc_type)
self.assertEqual(str(exc_info[1]), str(exc))

def test_comparison(self):
try:
1/0
except Exception:
exc_info = sys.exc_info()
exc = traceback.TracebackException(*exc_info)
exc2 = traceback.TracebackException(*exc_info)
self.assertIsNot(exc, exc2)
self.assertEqual(exc, exc2)
self.assertNotEqual(exc, object())
self.assertEqual(exc, ALWAYS_EQ)

def test_unhashable(self):
class UnhashableException(Exception):
def __eq__(self, other):
Expand Down
9 changes: 8 additions & 1 deletion Lib/test/test_weakref.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import random

from test import support
from test.support import script_helper
from test.support import script_helper, ALWAYS_EQ

# Used in ReferencesTestCase.test_ref_created_during_del() .
ref_from_del = None
Expand Down Expand Up @@ -794,6 +794,10 @@ def test_equality(self):
self.assertTrue(a != c)
self.assertTrue(a == d)
self.assertFalse(a != d)
self.assertFalse(a == x)
self.assertTrue(a != x)
self.assertTrue(a == ALWAYS_EQ)
self.assertFalse(a != ALWAYS_EQ)
del x, y, z
gc.collect()
for r in a, b, c:
Expand Down Expand Up @@ -1102,6 +1106,9 @@ def _ne(a, b):
_ne(a, f)
_ne(b, e)
_ne(b, f)
# Compare with different types
_ne(a, x.some_method)
_eq(a, ALWAYS_EQ)
del x, y, z
gc.collect()
# Dead WeakMethods compare by identity
Expand Down
25 changes: 17 additions & 8 deletions Lib/test/test_xmlrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import io
import contextlib
from test import support
from test.support import ALWAYS_EQ, LARGEST, SMALLEST

try:
import gzip
Expand Down Expand Up @@ -530,14 +531,10 @@ def test_comparison(self):
# some other types
dbytes = dstr.encode('ascii')
dtuple = now.timetuple()
with self.assertRaises(TypeError):
dtime == 1970
with self.assertRaises(TypeError):
dtime != dbytes
with self.assertRaises(TypeError):
dtime == bytearray(dbytes)
with self.assertRaises(TypeError):
dtime != dtuple
self.assertFalse(dtime == 1970)
self.assertTrue(dtime != dbytes)
self.assertFalse(dtime == bytearray(dbytes))
self.assertTrue(dtime != dtuple)
with self.assertRaises(TypeError):
dtime < float(1970)
with self.assertRaises(TypeError):
Expand All @@ -547,6 +544,18 @@ def test_comparison(self):
with self.assertRaises(TypeError):
dtime >= dtuple

self.assertTrue(dtime == ALWAYS_EQ)
self.assertFalse(dtime != ALWAYS_EQ)
self.assertTrue(dtime < LARGEST)
self.assertFalse(dtime > LARGEST)
self.assertTrue(dtime <= LARGEST)
self.assertFalse(dtime >= LARGEST)
self.assertFalse(dtime < SMALLEST)
self.assertTrue(dtime > SMALLEST)
self.assertFalse(dtime <= SMALLEST)
self.assertTrue(dtime >= SMALLEST)


class BinaryTestCase(unittest.TestCase):

# XXX What should str(Binary(b"\xff")) return? I'm chosing "\xff"
Expand Down
2 changes: 2 additions & 0 deletions Lib/tkinter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,8 @@ def __eq__(self, other):
Note: if the Variable's master matters to behavior
also compare self._master == other._master
"""
if not isinstance(other, Variable):
return NotImplemented
return self.__class__.__name__ == other.__class__.__name__ \
and self._name == other._name

Expand Down
4 changes: 3 additions & 1 deletion Lib/tkinter/font.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def __str__(self):
return self.name

def __eq__(self, other):
return isinstance(other, Font) and self.name == other.name
if not isinstance(other, Font):
return NotImplemented
return self.name == other.name

def __getitem__(self, key):
return self.cget(key)
Expand Down
3 changes: 2 additions & 1 deletion Lib/tkinter/test/test_tkinter/test_font.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import tkinter
from tkinter import font
from test.support import requires, run_unittest, gc_collect
from test.support import requires, run_unittest, gc_collect, ALWAYS_EQ
from tkinter.test.support import AbstractTkTest

requires('gui')
Expand Down Expand Up @@ -70,6 +70,7 @@ def test_eq(self):
self.assertEqual(font1, font2)
self.assertNotEqual(font1, font1.copy())
self.assertNotEqual(font1, 0)
self.assertEqual(font1, ALWAYS_EQ)

def test_measure(self):
self.assertIsInstance(self.font.measure('abc'), int)
Expand Down
13 changes: 10 additions & 3 deletions Lib/tkinter/test/test_tkinter/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import gc
from tkinter import (Variable, StringVar, IntVar, DoubleVar, BooleanVar, Tcl,
TclError)
from test.support import ALWAYS_EQ


class Var(Variable):
Expand Down Expand Up @@ -59,11 +60,17 @@ def test___eq__(self):
# values doesn't matter, only class and name are checked
v1 = Variable(self.root, name="abc")
v2 = Variable(self.root, name="abc")
self.assertIsNot(v1, v2)
self.assertEqual(v1, v2)

v3 = Variable(self.root, name="abc")
v4 = StringVar(self.root, name="abc")
self.assertNotEqual(v3, v4)
v3 = StringVar(self.root, name="abc")
self.assertNotEqual(v1, v3)

V = type('Variable', (), {})
self.assertNotEqual(v1, V())

self.assertNotEqual(v1, object())
self.assertEqual(v1, ALWAYS_EQ)

def test_invalid_name(self):
with self.assertRaises(TypeError):
Expand Down
4 changes: 3 additions & 1 deletion Lib/traceback.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,9 @@ def _load_lines(self):
self.__cause__._load_lines()

def __eq__(self, other):
return self.__dict__ == other.__dict__
if isinstance(other, TracebackException):
return self.__dict__ == other.__dict__
return NotImplemented

def __str__(self):
return self._str
Expand Down
Loading
0