Skip to content

Commit 9af7a20

Browse files
authored
gh-136186: Fix flaky tests in test_external_inspection (#143110)
1 parent fc2f0fe commit 9af7a20

File tree

1 file changed

+135
-98
lines changed

1 file changed

+135
-98
lines changed

Lib/test/test_external_inspection.py

Lines changed: 135 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,31 @@ def get_all_awaited_by(pid):
253253
raise RuntimeError("Failed to get all awaited_by after retries")
254254

255255

256+
def _get_stack_trace_with_retry(unwinder, timeout=SHORT_TIMEOUT, condition=None):
257+
"""Get stack trace from an existing unwinder with retry for transient errors.
258+
259+
This handles the case where we want to reuse an existing RemoteUnwinder
260+
instance but still handle transient failures like "Failed to parse initial
261+
frame in chain" that can occur when sampling at an inopportune moment.
262+
If condition is provided, keeps retrying until condition(traces) is True.
263+
"""
264+
last_error = None
265+
for _ in busy_retry(timeout):
266+
try:
267+
traces = unwinder.get_stack_trace()
268+
if condition is None or condition(traces):
269+
return traces
270+
# Condition not met yet, keep retrying
271+
except TRANSIENT_ERRORS as e:
272+
last_error = e
273+
continue
274+
if last_error:
275+
raise RuntimeError(
276+
f"Failed to get stack trace after retries: {last_error}"
277+
)
278+
raise RuntimeError("Condition never satisfied within timeout")
279+
280+
256281
# ============================================================================
257282
# Base test class with shared infrastructure
258283
# ============================================================================
@@ -1708,16 +1733,16 @@ def main_work():
17081733

17091734
# Get stack trace with all threads
17101735
unwinder_all = RemoteUnwinder(p.pid, all_threads=True)
1711-
for _ in range(MAX_TRIES):
1712-
all_traces = unwinder_all.get_stack_trace()
1713-
found = self._find_frame_in_trace(
1714-
all_traces,
1715-
lambda f: f.funcname == "main_work"
1716-
and f.location.lineno > 12,
1717-
)
1718-
if found:
1719-
break
1720-
time.sleep(RETRY_DELAY)
1736+
for _ in busy_retry(SHORT_TIMEOUT):
1737+
with contextlib.suppress(*TRANSIENT_ERRORS):
1738+
all_traces = unwinder_all.get_stack_trace()
1739+
found = self._find_frame_in_trace(
1740+
all_traces,
1741+
lambda f: f.funcname == "main_work"
1742+
and f.location.lineno > 12,
1743+
)
1744+
if found:
1745+
break
17211746
else:
17221747
self.fail(
17231748
"Main thread did not start its busy work on time"
@@ -1727,7 +1752,7 @@ def main_work():
17271752
unwinder_gil = RemoteUnwinder(
17281753
p.pid, only_active_thread=True
17291754
)
1730-
gil_traces = unwinder_gil.get_stack_trace()
1755+
gil_traces = _get_stack_trace_with_retry(unwinder_gil)
17311756

17321757
# Count threads
17331758
total_threads = sum(
@@ -2002,15 +2027,15 @@ def busy():
20022027
mode=mode,
20032028
skip_non_matching_threads=False,
20042029
)
2005-
for _ in range(MAX_TRIES):
2006-
traces = unwinder.get_stack_trace()
2007-
statuses = self._get_thread_statuses(traces)
2030+
for _ in busy_retry(SHORT_TIMEOUT):
2031+
with contextlib.suppress(*TRANSIENT_ERRORS):
2032+
traces = unwinder.get_stack_trace()
2033+
statuses = self._get_thread_statuses(traces)
20082034

2009-
if check_condition(
2010-
statuses, sleeper_tid, busy_tid
2011-
):
2012-
break
2013-
time.sleep(0.5)
2035+
if check_condition(
2036+
statuses, sleeper_tid, busy_tid
2037+
):
2038+
break
20142039

20152040
return statuses, sleeper_tid, busy_tid
20162041
finally:
@@ -2154,29 +2179,29 @@ def busy_thread():
21542179
mode=PROFILING_MODE_ALL,
21552180
skip_non_matching_threads=False,
21562181
)
2157-
for _ in range(MAX_TRIES):
2158-
traces = unwinder.get_stack_trace()
2159-
statuses = self._get_thread_statuses(traces)
2160-
2161-
# Check ALL mode provides both GIL and CPU info
2162-
if (
2163-
sleeper_tid in statuses
2164-
and busy_tid in statuses
2165-
and not (
2166-
statuses[sleeper_tid]
2167-
& THREAD_STATUS_ON_CPU
2168-
)
2169-
and not (
2170-
statuses[sleeper_tid]
2171-
& THREAD_STATUS_HAS_GIL
2172-
)
2173-
and (statuses[busy_tid] & THREAD_STATUS_ON_CPU)
2174-
and (
2175-
statuses[busy_tid] & THREAD_STATUS_HAS_GIL
2176-
)
2177-
):
2178-
break
2179-
time.sleep(0.5)
2182+
for _ in busy_retry(SHORT_TIMEOUT):
2183+
with contextlib.suppress(*TRANSIENT_ERRORS):
2184+
traces = unwinder.get_stack_trace()
2185+
statuses = self._get_thread_statuses(traces)
2186+
2187+
# Check ALL mode provides both GIL and CPU info
2188+
if (
2189+
sleeper_tid in statuses
2190+
and busy_tid in statuses
2191+
and not (
2192+
statuses[sleeper_tid]
2193+
& THREAD_STATUS_ON_CPU
2194+
)
2195+
and not (
2196+
statuses[sleeper_tid]
2197+
& THREAD_STATUS_HAS_GIL
2198+
)
2199+
and (statuses[busy_tid] & THREAD_STATUS_ON_CPU)
2200+
and (
2201+
statuses[busy_tid] & THREAD_STATUS_HAS_GIL
2202+
)
2203+
):
2204+
break
21802205

21812206
self.assertIsNotNone(
21822207
sleeper_tid, "Sleeper thread id not received"
@@ -2300,18 +2325,18 @@ def test_thread_status_exception_detection(self):
23002325
mode=PROFILING_MODE_ALL,
23012326
skip_non_matching_threads=False,
23022327
)
2303-
for _ in range(MAX_TRIES):
2304-
traces = unwinder.get_stack_trace()
2305-
statuses = self._get_thread_statuses(traces)
2306-
2307-
if (
2308-
exception_tid in statuses
2309-
and normal_tid in statuses
2310-
and (statuses[exception_tid] & THREAD_STATUS_HAS_EXCEPTION)
2311-
and not (statuses[normal_tid] & THREAD_STATUS_HAS_EXCEPTION)
2312-
):
2313-
break
2314-
time.sleep(0.5)
2328+
for _ in busy_retry(SHORT_TIMEOUT):
2329+
with contextlib.suppress(*TRANSIENT_ERRORS):
2330+
traces = unwinder.get_stack_trace()
2331+
statuses = self._get_thread_statuses(traces)
2332+
2333+
if (
2334+
exception_tid in statuses
2335+
and normal_tid in statuses
2336+
and (statuses[exception_tid] & THREAD_STATUS_HAS_EXCEPTION)
2337+
and not (statuses[normal_tid] & THREAD_STATUS_HAS_EXCEPTION)
2338+
):
2339+
break
23152340

23162341
self.assertIn(exception_tid, statuses)
23172342
self.assertIn(normal_tid, statuses)
@@ -2343,18 +2368,18 @@ def test_thread_status_exception_mode_filtering(self):
23432368
mode=PROFILING_MODE_EXCEPTION,
23442369
skip_non_matching_threads=True,
23452370
)
2346-
for _ in range(MAX_TRIES):
2347-
traces = unwinder.get_stack_trace()
2348-
statuses = self._get_thread_statuses(traces)
2349-
2350-
if exception_tid in statuses:
2351-
self.assertNotIn(
2352-
normal_tid,
2353-
statuses,
2354-
"Normal thread should be filtered out in exception mode",
2355-
)
2356-
return
2357-
time.sleep(0.5)
2371+
for _ in busy_retry(SHORT_TIMEOUT):
2372+
with contextlib.suppress(*TRANSIENT_ERRORS):
2373+
traces = unwinder.get_stack_trace()
2374+
statuses = self._get_thread_statuses(traces)
2375+
2376+
if exception_tid in statuses:
2377+
self.assertNotIn(
2378+
normal_tid,
2379+
statuses,
2380+
"Normal thread should be filtered out in exception mode",
2381+
)
2382+
return
23582383

23592384
self.fail("Never found exception thread in exception mode")
23602385

@@ -2497,49 +2522,61 @@ def _run_scenario_process(self, scenario):
24972522
finally:
24982523
_cleanup_sockets(client_socket, server_socket)
24992524

2500-
def _check_exception_status(self, p, thread_tid, expect_exception):
2501-
"""Helper to check if thread has expected exception status."""
2525+
def _check_thread_status(
2526+
self, p, thread_tid, condition, condition_name="condition"
2527+
):
2528+
"""Helper to check thread status with a custom condition.
2529+
2530+
This waits until we see 3 consecutive samples where the condition
2531+
returns True, which confirms the thread has reached and is stable
2532+
in the expected state. Samples that don't match are ignored (the
2533+
thread may not have reached the expected state yet).
2534+
2535+
Args:
2536+
p: Process object with pid attribute
2537+
thread_tid: Thread ID to check
2538+
condition: Callable(statuses, thread_tid) -> bool that returns
2539+
True when the thread is in the expected state
2540+
condition_name: Description of condition for error messages
2541+
"""
25022542
unwinder = RemoteUnwinder(
25032543
p.pid,
25042544
all_threads=True,
25052545
mode=PROFILING_MODE_ALL,
25062546
skip_non_matching_threads=False,
25072547
)
25082548

2509-
# Collect multiple samples for reliability
2510-
results = []
2511-
for _ in range(MAX_TRIES):
2512-
try:
2549+
# Wait for 3 consecutive samples matching expected state
2550+
matching_samples = 0
2551+
for _ in busy_retry(SHORT_TIMEOUT):
2552+
with contextlib.suppress(*TRANSIENT_ERRORS):
25132553
traces = unwinder.get_stack_trace()
2514-
except TRANSIENT_ERRORS:
2515-
time.sleep(RETRY_DELAY)
2516-
continue
2517-
statuses = self._get_thread_statuses(traces)
2518-
2519-
if thread_tid in statuses:
2520-
has_exc = bool(statuses[thread_tid] & THREAD_STATUS_HAS_EXCEPTION)
2521-
results.append(has_exc)
2554+
statuses = self._get_thread_statuses(traces)
25222555

2523-
if len(results) >= 3:
2524-
break
2556+
if thread_tid in statuses:
2557+
if condition(statuses, thread_tid):
2558+
matching_samples += 1
2559+
if matching_samples >= 3:
2560+
return # Success - confirmed stable in expected state
2561+
else:
2562+
# Thread not yet in expected state, reset counter
2563+
matching_samples = 0
25252564

2526-
time.sleep(RETRY_DELAY)
2565+
self.fail(
2566+
f"Thread did not stabilize in expected state "
2567+
f"({condition_name}) within timeout"
2568+
)
25272569

2528-
# Check majority of samples match expected
2529-
if not results:
2530-
self.fail("Never found target thread in stack traces")
2570+
def _check_exception_status(self, p, thread_tid, expect_exception):
2571+
"""Helper to check if thread has expected exception status."""
2572+
def condition(statuses, tid):
2573+
has_exc = bool(statuses[tid] & THREAD_STATUS_HAS_EXCEPTION)
2574+
return has_exc == expect_exception
25312575

2532-
majority = sum(results) > len(results) // 2
2533-
if expect_exception:
2534-
self.assertTrue(
2535-
majority,
2536-
f"Thread should have HAS_EXCEPTION flag, got {results}"
2537-
)
2538-
else:
2539-
self.assertFalse(
2540-
majority,
2541-
f"Thread should NOT have HAS_EXCEPTION flag, got {results}"
2542-
)
2576+
self._check_thread_status(
2577+
p, thread_tid, condition,
2578+
condition_name=f"expect_exception={expect_exception}"
2579+
)
25432580

25442581
@unittest.skipIf(
25452582
sys.platform not in ("linux", "darwin", "win32"),
@@ -3445,7 +3482,7 @@ def test_get_stats(self):
34453482
_wait_for_signal(client_socket, b"ready")
34463483

34473484
# Take a sample
3448-
unwinder.get_stack_trace()
3485+
_get_stack_trace_with_retry(unwinder)
34493486

34503487
stats = unwinder.get_stats()
34513488
client_socket.sendall(b"done")

0 commit comments

Comments
 (0)