Skip to content

Commit

Permalink
bpo-37685: Fixed __eq__, __lt__ etc implementations in some classes. (p…
Browse files Browse the repository at this point in the history
…ythonGH-14952)

They now return NotImplemented for unsupported type of the other operand.
  • Loading branch information
serhiy-storchaka authored Aug 8, 2019
1 parent 4c69be2 commit 662db12
Show file tree
Hide file tree
Showing 23 changed files with 1,292 additions and 1,147 deletions.
24 changes: 12 additions & 12 deletions Lib/asyncio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,20 +119,24 @@ def __hash__(self):
return hash(self._when)

def __lt__(self, other):
return self._when < other._when
if isinstance(other, TimerHandle):
return self._when < other._when
return NotImplemented

def __le__(self, other):
if self._when < other._when:
return True
return self.__eq__(other)
if isinstance(other, TimerHandle):
return self._when < other._when or self.__eq__(other)
return NotImplemented

def __gt__(self, other):
return self._when > other._when
if isinstance(other, TimerHandle):
return self._when > other._when
return NotImplemented

def __ge__(self, other):
if self._when > other._when:
return True
return self.__eq__(other)
if isinstance(other, TimerHandle):
return self._when > other._when or self.__eq__(other)
return NotImplemented

def __eq__(self, other):
if isinstance(other, TimerHandle):
Expand All @@ -142,10 +146,6 @@ def __eq__(self, other):
self._cancelled == other._cancelled)
return NotImplemented

def __ne__(self, other):
equal = self.__eq__(other)
return NotImplemented if equal is NotImplemented else not equal

def cancel(self):
if not self._cancelled:
self._loop._timer_handle_cancelled(self)
Expand Down
16 changes: 16 additions & 0 deletions Lib/distutils/tests/test_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ def test_cmp_strict(self):
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = StrictVersion(v1)._cmp(v2)
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = StrictVersion(v1)._cmp(object())
self.assertIs(res, NotImplemented,
'cmp(%s, %s) should be NotImplemented, got %s' %
(v1, v2, res))


def test_cmp(self):
Expand All @@ -63,6 +71,14 @@ def test_cmp(self):
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = LooseVersion(v1)._cmp(v2)
self.assertEqual(res, wanted,
'cmp(%s, %s) should be %s, got %s' %
(v1, v2, wanted, res))
res = LooseVersion(v1)._cmp(object())
self.assertIs(res, NotImplemented,
'cmp(%s, %s) should be NotImplemented, got %s' %
(v1, v2, res))

def test_suite():
return unittest.makeSuite(VersionTestCase)
Expand Down
4 changes: 4 additions & 0 deletions Lib/distutils/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ def __str__ (self):
def _cmp (self, other):
if isinstance(other, str):
other = StrictVersion(other)
elif not isinstance(other, StrictVersion):
return NotImplemented

if self.version != other.version:
# numeric versions don't match
Expand Down Expand Up @@ -331,6 +333,8 @@ def __repr__ (self):
def _cmp (self, other):
if isinstance(other, str):
other = LooseVersion(other)
elif not isinstance(other, LooseVersion):
return NotImplemented

if self.version == other.version:
return 0
Expand Down
8 changes: 4 additions & 4 deletions Lib/email/headerregistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def __str__(self):
return self.addr_spec

def __eq__(self, other):
if type(other) != type(self):
return False
if not isinstance(other, Address):
return NotImplemented
return (self.display_name == other.display_name and
self.username == other.username and
self.domain == other.domain)
Expand Down Expand Up @@ -150,8 +150,8 @@ def __str__(self):
return "{}:{};".format(disp, adrstr)

def __eq__(self, other):
if type(other) != type(self):
return False
if not isinstance(other, Group):
return NotImplemented
return (self.display_name == other.display_name and
self.addresses == other.addresses)

Expand Down
2 changes: 1 addition & 1 deletion Lib/importlib/_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def __eq__(self, other):
self.cached == other.cached and
self.has_location == other.has_location)
except AttributeError:
return False
return NotImplemented

@property
def cached(self):
Expand Down
23 changes: 23 additions & 0 deletions Lib/test/test_asyncio/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from asyncio import selector_events
from test.test_asyncio import utils as test_utils
from test import support
from test.support import ALWAYS_EQ, LARGEST, SMALLEST


def tearDownModule():
Expand Down Expand Up @@ -2364,6 +2365,28 @@ def callback(*args):
self.assertIs(NotImplemented, h1.__eq__(h3))
self.assertIs(NotImplemented, h1.__ne__(h3))

with self.assertRaises(TypeError):
h1 < ()
with self.assertRaises(TypeError):
h1 > ()
with self.assertRaises(TypeError):
h1 <= ()
with self.assertRaises(TypeError):
h1 >= ()
self.assertFalse(h1 == ())
self.assertTrue(h1 != ())

self.assertTrue(h1 == ALWAYS_EQ)
self.assertFalse(h1 != ALWAYS_EQ)
self.assertTrue(h1 < LARGEST)
self.assertFalse(h1 > LARGEST)
self.assertTrue(h1 <= LARGEST)
self.assertFalse(h1 >= LARGEST)
self.assertFalse(h1 < SMALLEST)
self.assertTrue(h1 > SMALLEST)
self.assertFalse(h1 <= SMALLEST)
self.assertTrue(h1 >= SMALLEST)


class AbstractEventLoopTests(unittest.TestCase):

Expand Down
19 changes: 19 additions & 0 deletions Lib/test/test_email/test_headerregistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from test.test_email import TestEmailBase, parameterize
from email import headerregistry
from email.headerregistry import Address, Group
from test.support import ALWAYS_EQ


DITTO = object()
Expand Down Expand Up @@ -1525,6 +1526,24 @@ def test_set_message_header_from_group(self):
self.assertEqual(m['to'], 'foo bar:;')
self.assertEqual(m['to'].addresses, g.addresses)

def test_address_comparison(self):
a = Address('foo', 'bar', 'example.com')
self.assertEqual(Address('foo', 'bar', 'example.com'), a)
self.assertNotEqual(Address('baz', 'bar', 'example.com'), a)
self.assertNotEqual(Address('foo', 'baz', 'example.com'), a)
self.assertNotEqual(Address('foo', 'bar', 'baz'), a)
self.assertFalse(a == object())
self.assertTrue(a == ALWAYS_EQ)

def test_group_comparison(self):
a = Address('foo', 'bar', 'example.com')
g = Group('foo bar', [a])
self.assertEqual(Group('foo bar', (a,)), g)
self.assertNotEqual(Group('baz', [a]), g)
self.assertNotEqual(Group('foo bar', []), g)
self.assertFalse(g == object())
self.assertTrue(g == ALWAYS_EQ)


class TestFolding(TestHeaderBase):

Expand Down
16 changes: 15 additions & 1 deletion Lib/test/test_traceback.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import unittest
import re
from test import support
from test.support import TESTFN, Error, captured_output, unlink, cpython_only
from test.support import TESTFN, Error, captured_output, unlink, cpython_only, ALWAYS_EQ
from test.support.script_helper import assert_python_ok
import textwrap

Expand Down Expand Up @@ -887,6 +887,8 @@ def test_basics(self):
# operator fallbacks to FrameSummary.__eq__.
self.assertEqual(tuple(f), f)
self.assertIsNone(f.locals)
self.assertNotEqual(f, object())
self.assertEqual(f, ALWAYS_EQ)

def test_lazy_lines(self):
linecache.clearcache()
Expand Down Expand Up @@ -1083,6 +1085,18 @@ def test_context(self):
self.assertEqual(exc_info[0], exc.exc_type)
self.assertEqual(str(exc_info[1]), str(exc))

def test_comparison(self):
try:
1/0
except Exception:
exc_info = sys.exc_info()
exc = traceback.TracebackException(*exc_info)
exc2 = traceback.TracebackException(*exc_info)
self.assertIsNot(exc, exc2)
self.assertEqual(exc, exc2)
self.assertNotEqual(exc, object())
self.assertEqual(exc, ALWAYS_EQ)

def test_unhashable(self):
class UnhashableException(Exception):
def __eq__(self, other):
Expand Down
9 changes: 8 additions & 1 deletion Lib/test/test_weakref.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import random

from test import support
from test.support import script_helper
from test.support import script_helper, ALWAYS_EQ

# Used in ReferencesTestCase.test_ref_created_during_del() .
ref_from_del = None
Expand Down Expand Up @@ -794,6 +794,10 @@ def test_equality(self):
self.assertTrue(a != c)
self.assertTrue(a == d)
self.assertFalse(a != d)
self.assertFalse(a == x)
self.assertTrue(a != x)
self.assertTrue(a == ALWAYS_EQ)
self.assertFalse(a != ALWAYS_EQ)
del x, y, z
gc.collect()
for r in a, b, c:
Expand Down Expand Up @@ -1102,6 +1106,9 @@ def _ne(a, b):
_ne(a, f)
_ne(b, e)
_ne(b, f)
# Compare with different types
_ne(a, x.some_method)
_eq(a, ALWAYS_EQ)
del x, y, z
gc.collect()
# Dead WeakMethods compare by identity
Expand Down
25 changes: 17 additions & 8 deletions Lib/test/test_xmlrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import io
import contextlib
from test import support
from test.support import ALWAYS_EQ, LARGEST, SMALLEST

try:
import gzip
Expand Down Expand Up @@ -530,14 +531,10 @@ def test_comparison(self):
# some other types
dbytes = dstr.encode('ascii')
dtuple = now.timetuple()
with self.assertRaises(TypeError):
dtime == 1970
with self.assertRaises(TypeError):
dtime != dbytes
with self.assertRaises(TypeError):
dtime == bytearray(dbytes)
with self.assertRaises(TypeError):
dtime != dtuple
self.assertFalse(dtime == 1970)
self.assertTrue(dtime != dbytes)
self.assertFalse(dtime == bytearray(dbytes))
self.assertTrue(dtime != dtuple)
with self.assertRaises(TypeError):
dtime < float(1970)
with self.assertRaises(TypeError):
Expand All @@ -547,6 +544,18 @@ def test_comparison(self):
with self.assertRaises(TypeError):
dtime >= dtuple

self.assertTrue(dtime == ALWAYS_EQ)
self.assertFalse(dtime != ALWAYS_EQ)
self.assertTrue(dtime < LARGEST)
self.assertFalse(dtime > LARGEST)
self.assertTrue(dtime <= LARGEST)
self.assertFalse(dtime >= LARGEST)
self.assertFalse(dtime < SMALLEST)
self.assertTrue(dtime > SMALLEST)
self.assertFalse(dtime <= SMALLEST)
self.assertTrue(dtime >= SMALLEST)


class BinaryTestCase(unittest.TestCase):

# XXX What should str(Binary(b"\xff")) return? I'm chosing "\xff"
Expand Down
2 changes: 2 additions & 0 deletions Lib/tkinter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,8 @@ def __eq__(self, other):
Note: if the Variable's master matters to behavior
also compare self._master == other._master
"""
if not isinstance(other, Variable):
return NotImplemented
return self.__class__.__name__ == other.__class__.__name__ \
and self._name == other._name

Expand Down
4 changes: 3 additions & 1 deletion Lib/tkinter/font.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def __str__(self):
return self.name

def __eq__(self, other):
return isinstance(other, Font) and self.name == other.name
if not isinstance(other, Font):
return NotImplemented
return self.name == other.name

def __getitem__(self, key):
return self.cget(key)
Expand Down
3 changes: 2 additions & 1 deletion Lib/tkinter/test/test_tkinter/test_font.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import tkinter
from tkinter import font
from test.support import requires, run_unittest, gc_collect
from test.support import requires, run_unittest, gc_collect, ALWAYS_EQ
from tkinter.test.support import AbstractTkTest

requires('gui')
Expand Down Expand Up @@ -70,6 +70,7 @@ def test_eq(self):
self.assertEqual(font1, font2)
self.assertNotEqual(font1, font1.copy())
self.assertNotEqual(font1, 0)
self.assertEqual(font1, ALWAYS_EQ)

def test_measure(self):
self.assertIsInstance(self.font.measure('abc'), int)
Expand Down
13 changes: 10 additions & 3 deletions Lib/tkinter/test/test_tkinter/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import gc
from tkinter import (Variable, StringVar, IntVar, DoubleVar, BooleanVar, Tcl,
TclError)
from test.support import ALWAYS_EQ


class Var(Variable):
Expand Down Expand Up @@ -59,11 +60,17 @@ def test___eq__(self):
# values doesn't matter, only class and name are checked
v1 = Variable(self.root, name="abc")
v2 = Variable(self.root, name="abc")
self.assertIsNot(v1, v2)
self.assertEqual(v1, v2)

v3 = Variable(self.root, name="abc")
v4 = StringVar(self.root, name="abc")
self.assertNotEqual(v3, v4)
v3 = StringVar(self.root, name="abc")
self.assertNotEqual(v1, v3)

V = type('Variable', (), {})
self.assertNotEqual(v1, V())

self.assertNotEqual(v1, object())
self.assertEqual(v1, ALWAYS_EQ)

def test_invalid_name(self):
with self.assertRaises(TypeError):
Expand Down
4 changes: 3 additions & 1 deletion Lib/traceback.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,9 @@ def _load_lines(self):
self.__cause__._load_lines()

def __eq__(self, other):
return self.__dict__ == other.__dict__
if isinstance(other, TracebackException):
return self.__dict__ == other.__dict__
return NotImplemented

def __str__(self):
return self._str
Expand Down
Loading

0 comments on commit 662db12

Please sign in to comment.