Skip to content

Commit 17e5264

Browse files
bpo-37685: Fixed comparisons of datetime.timedelta and datetime.timezone. (GH-14996)
There was a discrepancy between the Python and C implementations. Add singletons ALWAYS_EQ, LARGEST and SMALLEST in test.support to test mixed type comparison.
1 parent 5c72bad commit 17e5264

File tree

7 files changed

+107
-84
lines changed

7 files changed

+107
-84
lines changed

Doc/library/test.rst

+17
Original file line numberDiff line numberDiff line change
@@ -356,11 +356,28 @@ The :mod:`test.support` module defines the following constants:
356356

357357
Check for presence of docstrings.
358358

359+
359360
.. data:: TEST_HTTP_URL
360361

361362
Define the URL of a dedicated HTTP server for the network tests.
362363

363364

365+
.. data:: ALWAYS_EQ
366+
367+
Object that is equal to anything. Used to test mixed type comparison.
368+
369+
370+
.. data:: LARGEST
371+
372+
Object that is greater than anything (except itself).
373+
Used to test mixed type comparison.
374+
375+
376+
.. data:: SMALLEST
377+
378+
Object that is less than anything (except itself).
379+
Used to test mixed type comparison.
380+
364381

365382
The :mod:`test.support` module defines the following functions:
366383

Lib/datetime.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -739,25 +739,25 @@ def __le__(self, other):
739739
if isinstance(other, timedelta):
740740
return self._cmp(other) <= 0
741741
else:
742-
_cmperror(self, other)
742+
return NotImplemented
743743

744744
def __lt__(self, other):
745745
if isinstance(other, timedelta):
746746
return self._cmp(other) < 0
747747
else:
748-
_cmperror(self, other)
748+
return NotImplemented
749749

750750
def __ge__(self, other):
751751
if isinstance(other, timedelta):
752752
return self._cmp(other) >= 0
753753
else:
754-
_cmperror(self, other)
754+
return NotImplemented
755755

756756
def __gt__(self, other):
757757
if isinstance(other, timedelta):
758758
return self._cmp(other) > 0
759759
else:
760-
_cmperror(self, other)
760+
return NotImplemented
761761

762762
def _cmp(self, other):
763763
assert isinstance(other, timedelta)
@@ -1316,25 +1316,25 @@ def __le__(self, other):
13161316
if isinstance(other, time):
13171317
return self._cmp(other) <= 0
13181318
else:
1319-
_cmperror(self, other)
1319+
return NotImplemented
13201320

13211321
def __lt__(self, other):
13221322
if isinstance(other, time):
13231323
return self._cmp(other) < 0
13241324
else:
1325-
_cmperror(self, other)
1325+
return NotImplemented
13261326

13271327
def __ge__(self, other):
13281328
if isinstance(other, time):
13291329
return self._cmp(other) >= 0
13301330
else:
1331-
_cmperror(self, other)
1331+
return NotImplemented
13321332

13331333
def __gt__(self, other):
13341334
if isinstance(other, time):
13351335
return self._cmp(other) > 0
13361336
else:
1337-
_cmperror(self, other)
1337+
return NotImplemented
13381338

13391339
def _cmp(self, other, allow_mixed=False):
13401340
assert isinstance(other, time)
@@ -2210,9 +2210,9 @@ def __getinitargs__(self):
22102210
return (self._offset, self._name)
22112211

22122212
def __eq__(self, other):
2213-
if type(other) != timezone:
2214-
return False
2215-
return self._offset == other._offset
2213+
if isinstance(other, timezone):
2214+
return self._offset == other._offset
2215+
return NotImplemented
22162216

22172217
def __hash__(self):
22182218
return hash(self._offset)

Lib/test/datetimetester.py

+30-44
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,8 @@
22
33
See http://www.zope.org/Members/fdrake/DateTimeWiki/TestCases
44
"""
5-
from test.support import is_resource_enabled
6-
75
import itertools
86
import bisect
9-
107
import copy
118
import decimal
129
import sys
@@ -22,6 +19,7 @@
2219
from operator import lt, le, gt, ge, eq, ne, truediv, floordiv, mod
2320

2421
from test import support
22+
from test.support import is_resource_enabled, ALWAYS_EQ, LARGEST, SMALLEST
2523

2624
import datetime as datetime_module
2725
from datetime import MINYEAR, MAXYEAR
@@ -54,18 +52,6 @@
5452
NAN = float("nan")
5553

5654

57-
class ComparesEqualClass(object):
58-
"""
59-
A class that is always equal to whatever you compare it to.
60-
"""
61-
62-
def __eq__(self, other):
63-
return True
64-
65-
def __ne__(self, other):
66-
return False
67-
68-
6955
#############################################################################
7056
# module tests
7157

@@ -353,6 +339,18 @@ def test_comparison(self):
353339
self.assertTrue(timezone(ZERO) != None)
354340
self.assertFalse(timezone(ZERO) == None)
355341

342+
tz = timezone(ZERO)
343+
self.assertTrue(tz == ALWAYS_EQ)
344+
self.assertFalse(tz != ALWAYS_EQ)
345+
self.assertTrue(tz < LARGEST)
346+
self.assertFalse(tz > LARGEST)
347+
self.assertTrue(tz <= LARGEST)
348+
self.assertFalse(tz >= LARGEST)
349+
self.assertFalse(tz < SMALLEST)
350+
self.assertTrue(tz > SMALLEST)
351+
self.assertFalse(tz <= SMALLEST)
352+
self.assertTrue(tz >= SMALLEST)
353+
356354
def test_aware_datetime(self):
357355
# test that timezone instances can be used by datetime
358356
t = datetime(1, 1, 1)
@@ -414,10 +412,21 @@ def test_harmless_mixed_comparison(self):
414412

415413
# Comparison to objects of unsupported types should return
416414
# NotImplemented which falls back to the right hand side's __eq__
417-
# method. In this case, ComparesEqualClass.__eq__ always returns True.
418-
# ComparesEqualClass.__ne__ always returns False.
419-
self.assertTrue(me == ComparesEqualClass())
420-
self.assertFalse(me != ComparesEqualClass())
415+
# method. In this case, ALWAYS_EQ.__eq__ always returns True.
416+
# ALWAYS_EQ.__ne__ always returns False.
417+
self.assertTrue(me == ALWAYS_EQ)
418+
self.assertFalse(me != ALWAYS_EQ)
419+
420+
# If the other class explicitly defines ordering
421+
# relative to our class, it is allowed to do so
422+
self.assertTrue(me < LARGEST)
423+
self.assertFalse(me > LARGEST)
424+
self.assertTrue(me <= LARGEST)
425+
self.assertFalse(me >= LARGEST)
426+
self.assertFalse(me < SMALLEST)
427+
self.assertTrue(me > SMALLEST)
428+
self.assertFalse(me <= SMALLEST)
429+
self.assertTrue(me >= SMALLEST)
421430

422431
def test_harmful_mixed_comparison(self):
423432
me = self.theclass(1, 1, 1)
@@ -1582,29 +1591,6 @@ class SomeClass:
15821591
self.assertRaises(TypeError, lambda: our < their)
15831592
self.assertRaises(TypeError, lambda: their < our)
15841593

1585-
# However, if the other class explicitly defines ordering
1586-
# relative to our class, it is allowed to do so
1587-
1588-
class LargerThanAnything:
1589-
def __lt__(self, other):
1590-
return False
1591-
def __le__(self, other):
1592-
return isinstance(other, LargerThanAnything)
1593-
def __eq__(self, other):
1594-
return isinstance(other, LargerThanAnything)
1595-
def __gt__(self, other):
1596-
return not isinstance(other, LargerThanAnything)
1597-
def __ge__(self, other):
1598-
return True
1599-
1600-
their = LargerThanAnything()
1601-
self.assertEqual(our == their, False)
1602-
self.assertEqual(their == our, False)
1603-
self.assertEqual(our != their, True)
1604-
self.assertEqual(their != our, True)
1605-
self.assertEqual(our < their, True)
1606-
self.assertEqual(their < our, False)
1607-
16081594
def test_bool(self):
16091595
# All dates are considered true.
16101596
self.assertTrue(self.theclass.min)
@@ -3781,8 +3767,8 @@ def test_replace(self):
37813767
self.assertRaises(ValueError, base.replace, microsecond=1000000)
37823768

37833769
def test_mixed_compare(self):
3784-
t1 = time(1, 2, 3)
3785-
t2 = time(1, 2, 3)
3770+
t1 = self.theclass(1, 2, 3)
3771+
t2 = self.theclass(1, 2, 3)
37863772
self.assertEqual(t1, t2)
37873773
t2 = t2.replace(tzinfo=None)
37883774
self.assertEqual(t1, t2)

Lib/test/support/__init__.py

+36
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@
113113
"run_with_locale", "swap_item",
114114
"swap_attr", "Matcher", "set_memlimit", "SuppressCrashReport", "sortdict",
115115
"run_with_tz", "PGO", "missing_compiler_executable", "fd_count",
116+
"ALWAYS_EQ", "LARGEST", "SMALLEST"
116117
]
117118

118119
class Error(Exception):
@@ -3103,6 +3104,41 @@ def __fspath__(self):
31033104
return self.path
31043105

31053106

3107+
class _ALWAYS_EQ:
3108+
"""
3109+
Object that is equal to anything.
3110+
"""
3111+
def __eq__(self, other):
3112+
return True
3113+
def __ne__(self, other):
3114+
return False
3115+
3116+
ALWAYS_EQ = _ALWAYS_EQ()
3117+
3118+
@functools.total_ordering
3119+
class _LARGEST:
3120+
"""
3121+
Object that is greater than anything (except itself).
3122+
"""
3123+
def __eq__(self, other):
3124+
return isinstance(other, _LARGEST)
3125+
def __lt__(self, other):
3126+
return False
3127+
3128+
LARGEST = _LARGEST()
3129+
3130+
@functools.total_ordering
3131+
class _SMALLEST:
3132+
"""
3133+
Object that is less than anything (except itself).
3134+
"""
3135+
def __eq__(self, other):
3136+
return isinstance(other, _SMALLEST)
3137+
def __gt__(self, other):
3138+
return False
3139+
3140+
SMALLEST = _SMALLEST()
3141+
31063142
def maybe_get_event_loop_policy():
31073143
"""Return the global event loop policy if one is set, else return None."""
31083144
return asyncio.events._event_loop_policy

Lib/test/test_ipaddress.py

+9-24
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import pickle
1313
import ipaddress
1414
import weakref
15+
from test.support import LARGEST, SMALLEST
1516

1617

1718
class BaseTestCase(unittest.TestCase):
@@ -673,20 +674,6 @@ def test_ip_network(self):
673674
self.assertFactoryError(ipaddress.ip_network, "network")
674675

675676

676-
@functools.total_ordering
677-
class LargestObject:
678-
def __eq__(self, other):
679-
return isinstance(other, LargestObject)
680-
def __lt__(self, other):
681-
return False
682-
683-
@functools.total_ordering
684-
class SmallestObject:
685-
def __eq__(self, other):
686-
return isinstance(other, SmallestObject)
687-
def __gt__(self, other):
688-
return False
689-
690677
class ComparisonTests(unittest.TestCase):
691678

692679
v4addr = ipaddress.IPv4Address(1)
@@ -775,8 +762,6 @@ def test_mixed_type_ordering(self):
775762

776763
def test_foreign_type_ordering(self):
777764
other = object()
778-
smallest = SmallestObject()
779-
largest = LargestObject()
780765
for obj in self.objects:
781766
with self.assertRaises(TypeError):
782767
obj < other
@@ -786,14 +771,14 @@ def test_foreign_type_ordering(self):
786771
obj <= other
787772
with self.assertRaises(TypeError):
788773
obj >= other
789-
self.assertTrue(obj < largest)
790-
self.assertFalse(obj > largest)
791-
self.assertTrue(obj <= largest)
792-
self.assertFalse(obj >= largest)
793-
self.assertFalse(obj < smallest)
794-
self.assertTrue(obj > smallest)
795-
self.assertFalse(obj <= smallest)
796-
self.assertTrue(obj >= smallest)
774+
self.assertTrue(obj < LARGEST)
775+
self.assertFalse(obj > LARGEST)
776+
self.assertTrue(obj <= LARGEST)
777+
self.assertFalse(obj >= LARGEST)
778+
self.assertFalse(obj < SMALLEST)
779+
self.assertTrue(obj > SMALLEST)
780+
self.assertFalse(obj <= SMALLEST)
781+
self.assertTrue(obj >= SMALLEST)
797782

798783
def test_mixed_type_key(self):
799784
# with get_mixed_type_key, you can sort addresses and network.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fixed comparisons of :class:`datetime.timedelta` and
2+
:class:`datetime.timezone`.

Modules/_datetimemodule.c

+2-5
Original file line numberDiff line numberDiff line change
@@ -3741,11 +3741,8 @@ timezone_richcompare(PyDateTime_TimeZone *self,
37413741
{
37423742
if (op != Py_EQ && op != Py_NE)
37433743
Py_RETURN_NOTIMPLEMENTED;
3744-
if (Py_TYPE(other) != &PyDateTime_TimeZoneType) {
3745-
if (op == Py_EQ)
3746-
Py_RETURN_FALSE;
3747-
else
3748-
Py_RETURN_TRUE;
3744+
if (!PyTZInfo_Check(other)) {
3745+
Py_RETURN_NOTIMPLEMENTED;
37493746
}
37503747
return delta_richcompare(self->offset, other->offset, op);
37513748
}

0 commit comments

Comments
 (0)