8000 bpo-37685: Fixed __eq__, __lt__ etc implementations in some classes. … · python/cpython@662db12 · GitHub
[go: up one dir, main page]

Skip to content

Commit 662db12

Browse files
bpo-37685: Fixed __eq__, __lt__ etc implementations in some classes. (GH-14952)
They now return NotImplemented for unsupported type of the other operand.
1 parent 4c69be2 commit 662db12

File tree

23 files changed

+1292
-1147
lines changed

23 files changed

+1292
-1147
lines changed

Lib/asyncio/events.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -119,20 +119,24 @@ def __hash__(self):
119119
return hash(self._when)
120120

121121
def __lt__(self, other):
122-
return self._when < other._when
122+
if isinstance(other, TimerHandle):
123+
return self._when < other._when
124+
return NotImplemented
123125

124126
def __le__(self, other):
125-
if self._when < other._when:
126-
return True
127-
return self.__eq__(other)
127+
if isinstance(other, TimerHandle):
128+
return self._when < other._when or self.__eq__(other)
129+
return NotImplemented
128130

129131
def __gt__(self, other):
130-
return self._when > other._when
132+
if isinstance(other, TimerHandle):
133+
return self._when > other._when
134+
return NotImplemented
131135

132136
def __ge__(self, other):
133-
if self._when > other._when:
134-
return True
135-
return self.__eq__(other)
137+
if isinstance(other, TimerHandle):
138+
return self._when > other._when or self.__eq__(other)
139+
return NotImplemented
136140

137141
def __eq__(self, other):
138142
if isinstance(other, TimerHandle):
@@ -142,10 +146,6 @@ def __eq__(self, other):
142146
self._cancelled == other._cancelled)
143147
return NotImplemented
144148

145-
def __ne__(self, other):
146-
equal = self.__eq__(other)
147-
return NotImplemented if equal is NotImplemented else not equal
148-
149149
def cancel(self):
150150
if not self._cancelled:
151151
self._loop._timer_handle_cancelled(self)

Lib/distutils/tests/test_version.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,14 @@ def test_cmp_strict(self):
4545
self.assertEqual(res, wanted,
4646
'cmp(%s, %s) should be %s, got %s' %
4747
(v1, v2, wanted, res))
48+
res = StrictVersion(v1)._cmp(v2)
49+
self.assertEqual(res, wanted,
50+
'cmp(%s, %s) should be %s, got %s' %
51+
(v1, v2, wanted, res))
52+
res = StrictVersion(v1)._cmp(object())
53+
self.assertIs(res, NotImplemented,
54+
'cmp(%s, %s) should be NotImplemented, got %s' %
55+
(v1, v2, res))
4856

4957

5058
def test_cmp(self):
@@ -63,6 +71,14 @@ def test_cmp(self):
6371
self.assertEqual(res, wanted,
6472
'cmp(%s, %s) should be %s, got %s' %
6573
(v1, v2, wanted, res))
74+
res = LooseVersion(v1)._cmp(v2)
75+
self.assertEqual(res, wanted,
76+
'cmp(%s, %s) should be %s, got %s' %
77+
(v1, v2, wanted, res))
78+
res = LooseVersion(v1)._cmp(object())
79+
self.assertIs(res, NotImplemented,
80+
'cmp(%s, %s) should be NotImplemented, got %s' %
81+
(v1, v2, res))
6682

6783
def test_suite():
6884
return unittest.makeSuite(VersionTestCase)

Lib/distutils/version.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ def __str__ (self):
166166
def _cmp (self, other):
167167
if isinstance(other, str):
168168
other = StrictVersion(other)
169+
elif not isinstance(other, StrictVersion):
170+
return NotImplemented
169171

170172
if self.version != other.version:
171173
# numeric versions don't match
@@ -331,6 +333,8 @@ def __repr__ (self):
331333
def _cmp (self, other):
332334
if isinstance(other, str):
333335
other = LooseVersion(other)
336+
elif not isinstance(other, LooseVersion):
337+
return NotImplemented
334338

335339
if self.version == other.version:
336340
return 0

Lib/email/headerregistry.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ def __str__(self):
9797
return self.addr_spec
9898

9999
def __eq__(self, other):
100-
if type(other) != type(self):
101-
return False
100+
if not isinstance(other, Address):
101+
return NotImplemented
102102
return (self.display_name == other.display_name and
103103
self.username == other.username and
104104
self.domain == other.domain)
@@ -150,8 +150,8 @@ def __str__(self):
150150
return "{}:{};".format(disp, adrstr)
151151

152152
def __eq__(self, other):
153-
if type(other) != type(self):
154-
return False
153+
if not isinstance(other, Group):
154+
return NotImplemented
155155
return (self.display_name == other.display_name and
156156
self.addresses == other.addresses)
157157

Lib/importlib/_bootstrap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ def __eq__(self, other):
371371
self.cached == other.cached and
372372
self.has_location == other.has_location)
373373
except AttributeError:
374-
return False
374+
return NotImplemented
375375

376376
@property
377377
def cached(self):

Lib/test/test_asyncio/test_events.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from asyncio import selector_events
3333
from test.test_asyncio import utils as test_utils
3434
from test import support
35+
from test.support import ALWAYS_EQ, LARGEST, SMALLEST
3536

3637

3738
def tearDownModule():
@@ -2364,6 +2365,28 @@ def callback(*args):
23642365
self.assertIs(NotImplemented, h1.__eq__(h3))
23652366
self.assertIs(NotImplemented, h1.__ne__(h3))
23662367

2368+
with self.assertRaises(TypeError):
2369+
h1 < ()
2370+
with self.assertRaises(TypeError):
2371+
h1 > ()
2372+
with self.assertRaises(TypeError):
2373+
h1 <= ()
2374+
with self.assertRaises(TypeError):
2375+
h1 >= ()
2376+
self.assertFalse(h1 == ())
2377+
self.assertTrue(h1 != ())
2378+
2379+
self.assertTrue(h1 == ALWAYS_EQ)
2380+
self.assertFalse(h1 != ALWAYS_EQ)
2381+
self.assertTrue(h1 < LARGEST)
2382+
self.assertFalse(h1 > LARGEST)
2383+
self.assertTrue(h1 <= LARGEST)
2384+
self.assertFalse(h1 >= LARGEST)
2385+
self.assertFalse(h1 < SMALLEST)
2386+
self.assertTrue(h1 > SMALLEST)
2387+
self.assertFalse(h1 <= SMALLEST)
2388+
self.assertTrue(h1 >= SMALLEST)
2389+
23672390

23682391
class AbstractEventLoopTests(unittest.TestCase):
23692392

Lib/test/test_email/test_headerregistry.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from test.test_email import TestEmailBase, parameterize
88
from email import headerregistry
99
from email.headerregistry import Address, Group
10+
from test.support import ALWAYS_EQ
1011

1112

1213
DITTO = object()
@@ -1525,6 +1526,24 @@ def test_set_message_header_from_group(self):
15251526
self.assertEqual(m['to'], 'foo bar:;')
15261527
self.assertEqual(m['to'].addresses, g.addresses)
15271528

1529+
def test_address_comparison(self):
1530+
a = Address('foo', 'bar', 'example.com')
1531+
self.assertEqual(Address('foo', 'bar', 'example.com'), a)
1532+
self.assertNotEqual(Address('baz', 'bar', 'example.com'), a)
1533+
self.assertNotEqual(Address('foo', 'baz', 'example.com'), a)
1534+
self.assertNotEqual(Address('foo', 'bar', 'baz'), a)
1535+
self.assertFalse(a == object())
1536+
self.assertTrue(a == ALWAYS_EQ)
1537+
1538+
def test_group_comparison(self):
1539+
a = Address('foo', 'bar', 'example.com')
1540+
g = Group('foo bar', [a])
1541+
self.assertEqual(Group('foo bar', (a,)), g)
1542+
self.assertNotEqual(Group('baz', [a]), g)
1543+
self.assertNotEqual(Group('foo bar', []), g)
1544+
self.assertFalse(g == object())
1545+
self.assertTrue(g == ALWAYS_EQ)
1546+
15281547

15291548
class TestFolding(TestHeaderBase):
15301549

Lib/test/test_traceback.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import unittest
88
import re
99
from test import support
10-
from test.support import TESTFN, Error, captured_output, unlink, cpython_only
10+
from test.support import TESTFN, Error, captured_output, unlink, cpython_only, ALWAYS_EQ
1111
from test.support.script_helper import assert_python_ok
1212
import textwrap
1313

@@ -887,6 +887,8 @@ def test_basics(self):
887887
# operator fallbacks to FrameSummary.__eq__.
888888
self.assertEqual(tuple(f), f)
889889
self.assertIsNone(f.locals)
890+
self.assertNotEqual(f, object())
891+
self.assertEqual(f, ALWAYS_EQ)
890892

891893
def test_lazy_lines(self):
892894
linecache.clearcache()
@@ -1083,6 +1085,18 @@ def test_context(self):
10831085
self.assertEqual(exc_info[0], exc.exc_type)
10841086
self.assertEqual(str(exc_info[1]), str(exc))
10851087

1088+
def test_comparison(self):
1089+
try:
1090+
1/0
1091+
except Exception:
1092+
exc_info = sys.exc_info()
1093+
exc = traceback.TracebackException(*exc_info)
1094+
exc2 = traceback.TracebackException(*exc_info)
1095+
self.assertIsNot(exc, exc2)
1096+
self.assertEqual(exc, exc2)
1097+
self.assertNotEqual(exc, object())
1098+
self.assertEqual(exc, ALWAYS_EQ)
1099+
10861100
def test_unhashable(self):
10871101
class UnhashableException(Exception):
10881102
def __eq__(self, other):

Lib/test/test_weakref.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import random
1212

1313
from test import support
14-
from test.support import script_helper
14+
from test.support import script_helper, ALWAYS_EQ
1515

1616
# Used in ReferencesTestCase.test_ref_created_during_del() .
1717
ref_from_del = None
@@ -794,6 +794,10 @@ def test_equality(self):
794794
self.assertTrue(a != c)
795795
self.assertTrue(a == d)
796796
self.assertFalse(a != d)
797+
self.assertFalse(a == x)
798+
self.assertTrue(a != x)
799+
self.assertTrue(a == ALWAYS_EQ)
800+
self.assertFalse(a != ALWAYS_EQ)
797801
del x, y, z
798802
gc.collect()
799803
for r in a, b, c:
@@ -1102,6 +1106,9 @@ def _ne(a, b):
11021106
_ne(a, f)
11031107
_ne(b, e)
11041108
_ne(b, f)
1109+
# Compare with different types
1110+
_ne(a, x.some_method)
1111+
_eq(a, ALWAYS_EQ)
11051112
del x, y, z
11061113
gc.collect()
11071114
# Dead WeakMethods compare by identity

Lib/test/test_xmlrpc.py

Lines changed: 17 additions & 8 deletions
10000
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import io
1616
import contextlib
1717
from test import support
18+
from test.support import ALWAYS_EQ, LARGEST, SMALLEST
1819

1920
try:
2021
import gzip
@@ -530,14 +531,10 @@ def test_comparison(self):
530531
# some other types
531532
dbytes = dstr.encode('ascii')
532533
dtuple = now.timetuple()
533-
with self.assertRaises(TypeError):
534-
dtime == 1970
535-
with self.assertRaises(TypeError):
536-
dtime != dbytes
537-
with self.assertRaises(TypeError):
538-
dtime == bytearray(dbytes)
539-
with self.assertRaises(TypeError):
540-
dtime != dtuple
534+
self.assertFalse(dtime == 1970)
535+
self.assertTrue(dtime != dbytes)
536+
self.assertFalse(dtime == bytearray(dbytes))
537+
self.assertTrue(dtime != dtuple)
541538
with self.assertRaises(TypeError):
542539
dtime < float(1970)
543540
with self.assertRaises(TypeError):
@@ -547,6 +544,18 @@ def test_comparison(self):
547544
with self.assertRaises(TypeError):
548545
dtime >= dtuple
549546

547+
self.assertTrue(dtime == ALWAYS_EQ)
548+
self.assertFalse(dtime != ALWAYS_EQ)
549+
self.assertTrue(dtime < LARGEST)
550+
self.assertFalse(dtime > LARGEST)
551+
self.assertTrue(dtime <= LARGEST)
552+
self.assertFalse(dtime >= LARGEST)
553+
self.assertFalse(dtime < SMALLEST)
554+
self.assertTrue(dtime > SMALLEST)
555+
self.assertFalse(dtime <= SMALLEST)
556+
self.assertTrue(dtime >= SMALLEST)
557+
558+
550559
class BinaryTestCase(unittest.TestCase):
551560

552561
# XXX What should str(Binary(b"\xff")) return? I'm chosing "\xff"

Lib/tkinter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,8 @@ def __eq__(self, other):
484484
Note: if the Variable's master matters to behavior
485485
also compare self._master == other._master
486486
"""
487+
if not isinstance(other, Variable):
488+
return NotImplemented
487489
return self.__class__.__name__ == other.__class__.__name__ \
488490
and self._name == other._name
489491

Lib/tkinter/font.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ def __str__(self):
101101
return self.name
102102

103103
def __eq__(self, other):
104-
return isinstance(other, Font) and self.name == other.name
104+
if not isinstance(other, Font):
105+
return NotImplemented
106+
return self.name == other.name
105107

106108
def __getitem__(self, key):
107109
return self.cget(key)

Lib/tkinter/test/test_tkinter/test_font.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22
import tkinter
33
from tkinter import font
4-
from test.support import requires, run_unittest, gc_collect
4+
from test.support import requires, run_unittest, gc_collect, ALWAYS_EQ
55
from tkinter.test.support import AbstractTkTest
66

77
requires('gui')
@@ -70,6 +70,7 @@ def test_eq(self):
7070
self.assertEqual(font1, font2)
7171
self.assertNotEqual(font1, font1.copy())
7272
self.assertNotEqual(font1, 0)
73+
self.assertEqual(font1, ALWAYS_EQ)
7374

7475
def test_measure(self):
7576
self.assertIsInstance(self.font.measure('abc'), int)

Lib/tkinter/test/test_tkinter/test_variables.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import gc
33
from tkinter import (Variable, StringVar, IntVar, DoubleVar, BooleanVar, Tcl,
44
TclError)
5+
from test.support import ALWAYS_EQ
56

67

78
class Var(Variable):
@@ -59,11 +60,17 @@ def test___eq__(self):
5960
# values doesn't matter, only class and name are checked
6061
v1 = Variable(self.root, name="abc")
6162
v2 = Variable(self.root, name="abc")
63+
self.assertIsNot(v1, v2)
6264
self.assertEqual(v1, v2)
6365

64-
v3 = Variable(self.root, name="abc")
65-
v4 = StringVar(self.root, name="abc")
66-
self.assertNotEqual(v3, v4)
66+
v3 = StringVar(self.root, name="abc")
67+
self.assertNotEqual(v1, v3)
68+
69+
V = type('Variable', (), {})
70+
self.assertNotEqual(v1, V())
71+
72+
self.assertNotEqual(v1, object())
73+
self.assertEqual(v1, ALWAYS_EQ)
6774

6875
def test_invalid_name(self):
6976
with self.assertRaises(TypeError):

Lib/traceback.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,9 @@ def _load_lines(self):
538538
self.__cause__._load_lines()
539539

540540
def __eq__(self, other):
541-
return self.__dict__ == other.__dict__
541+
if isinstance(other, TracebackException):
542+
return self.__dict__ == other.__dict__
543+
return NotImplemented
542544

543545
def __str__(self):
544546
return self._str

0 commit comments

Comments
 (0)
0