Skip to content

Commit 87939bd

Browse files
authored
gh-117657: Fix itertools.count thread safety (#119268)
Fix itertools.count in free-threading mode
1 parent 77ff28b commit 87939bd

File tree

3 files changed

+54
-11
lines changed

3 files changed

+54
-11
lines changed

Diff for: Lib/test/test_itertools.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ def test_count(self):
546546
#check proper internal error handling for large "step' sizes
547547
count(1, maxsize+5); sys.exc_info()
548548

549-
def test_count_with_stride(self):
549+
def test_count_with_step(self):
550550
self.assertEqual(lzip('abc',count(2,3)), [('a', 2), ('b', 5), ('c', 8)])
551551
self.assertEqual(lzip('abc',count(start=2,step=3)),
552552
[('a', 2), ('b', 5), ('c', 8)])
@@ -590,6 +590,28 @@ def test_count_with_stride(self):
590590
self.assertEqual(type(next(c)), int)
591591
self.assertEqual(type(next(c)), float)
592592

593+
@threading_helper.requires_working_threading()
594+
def test_count_threading(self, step=1):
595+
# this test verifies multithreading consistency, which is
596+
# mostly for testing builds without GIL, but nice to test anyway
597+
count_to = 10_000
598+
num_threads = 10
599+
c = count(step=step)
600+
def counting_thread():
601+
for i in range(count_to):
602+
next(c)
603+
threads = []
604+
for i in range(num_threads):
605+
thread = threading.Thread(target=counting_thread)
606+
thread.start()
607+
threads.append(thread)
608+
for thread in threads:
609+
thread.join()
610+
self.assertEqual(next(c), count_to * num_threads * step)
611+
612+
def test_count_with_step_threading(self):
613+
self.test_count_threading(step=5)
614+
593615
def test_cycle(self):
594616
self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
595617
self.assertEqual(list(cycle('')), [])

Diff for: Modules/itertoolsmodule.c

+31-9
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
#include "Python.h"
2-
#include "pycore_call.h" // _PyObject_CallNoArgs()
3-
#include "pycore_ceval.h" // _PyEval_GetBuiltin()
4-
#include "pycore_long.h" // _PyLong_GetZero()
5-
#include "pycore_moduleobject.h" // _PyModule_GetState()
6-
#include "pycore_typeobject.h" // _PyType_GetModuleState()
7-
#include "pycore_object.h" // _PyObject_GC_TRACK()
8-
#include "pycore_tuple.h" // _PyTuple_ITEMS()
2+
#include "pycore_call.h" // _PyObject_CallNoArgs()
3+
#include "pycore_ceval.h" // _PyEval_GetBuiltin()
4+
#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION()
5+
#include "pycore_long.h" // _PyLong_GetZero()
6+
#include "pycore_moduleobject.h" // _PyModule_GetState()
7+
#include "pycore_typeobject.h" // _PyType_GetModuleState()
8+
#include "pycore_object.h" // _PyObject_GC_TRACK()
9+
#include "pycore_tuple.h" // _PyTuple_ITEMS()
910

10-
#include <stddef.h> // offsetof()
11+
#include <stddef.h> // offsetof()
1112

1213
/* Itertools module written and maintained
1314
by Raymond D. Hettinger <python@rcn.com>
@@ -3254,7 +3255,7 @@ fast_mode: when cnt an integer < PY_SSIZE_T_MAX and no step is specified.
32543255
32553256
assert(cnt != PY_SSIZE_T_MAX && long_cnt == NULL && long_step==PyLong(1));
32563257
Advances with: cnt += 1
3257-
When count hits Y_SSIZE_T_MAX, switch to slow_mode.
3258+
When count hits PY_SSIZE_T_MAX, switch to slow_mode.
32583259
32593260
slow_mode: when cnt == PY_SSIZE_T_MAX, step is not int(1), or cnt is a float.
32603261
@@ -3403,9 +3404,30 @@ count_nextlong(countobject *lz)
34033404
static PyObject *
34043405
count_next(countobject *lz)
34053406
{
3407+
#ifndef Py_GIL_DISABLED
34063408
if (lz->cnt == PY_SSIZE_T_MAX)
34073409
return count_nextlong(lz);
34083410
return PyLong_FromSsize_t(lz->cnt++);
3411+
#else
3412+
// free-threading version
3413+
// fast mode uses compare-exchange loop
3414+
// slow mode uses a critical section
3415+
PyObject *returned;
3416+
Py_ssize_t cnt;
3417+
3418+
cnt = _Py_atomic_load_ssize_relaxed(&lz->cnt);
3419+
for (;;) {
3420+
if (cnt == PY_SSIZE_T_MAX) {
3421+
Py_BEGIN_CRITICAL_SECTION(lz);
3422+
returned = count_nextlong(lz);
3423+
Py_END_CRITICAL_SECTION();
3424+
return returned;
3425+
}
3426+
if (_Py_atomic_compare_exchange_ssize(&lz->cnt, &cnt, cnt + 1)) {
3427+
return PyLong_FromSsize_t(cnt);
3428+
}
3429+
}
3430+
#endif
34093431
}
34103432

34113433
static PyObject *

Diff for: Tools/tsan/suppressions_free_threading.txt

-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ race_top:_Py_dict_lookup_threadsafe
5656
race_top:_imp_release_lock
5757
race_top:_multiprocessing_SemLock_acquire_impl
5858
race_top:builtin_compile_impl
59-
race_top:count_next
6059
race_top:dictiter_new
6160
race_top:dictresize
6261
race_top:insert_to_emptydict

0 commit comments

Comments
 (0)