Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 135 additions & 98 deletions Lib/test/test_external_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,31 @@ def get_all_awaited_by(pid):
raise RuntimeError("Failed to get all awaited_by after retries")


def _get_stack_trace_with_retry(unwinder, timeout=SHORT_TIMEOUT, condition=None):
    """Get stack trace from an existing unwinder with retry for transient errors.

    This handles the case where we want to reuse an existing RemoteUnwinder
    instance but still handle transient failures like "Failed to parse initial
    frame in chain" that can occur when sampling at an inopportune moment.

    Args:
        unwinder: Object whose get_stack_trace() method is called on every
            attempt.
        timeout: Time budget handed to busy_retry() to drive the retry loop.
        condition: Optional callable(traces) -> bool. When given, sampling
            continues until condition(traces) is true; when None, the first
            successful sample is returned.

    Returns:
        The first result of unwinder.get_stack_trace() that was obtained
        without a transient error and, if a condition was supplied,
        satisfied it.

    Raises:
        RuntimeError: If the retry loop ends without an acceptable sample.
            When the most recent attempt failed with one of TRANSIENT_ERRORS,
            that exception is attached as __cause__.
    """
    last_error = None
    for _ in busy_retry(timeout):
        try:
            traces = unwinder.get_stack_trace()
            # The sample itself succeeded, so any earlier transient failure
            # is stale and must not be blamed if we later run out of time.
            last_error = None
            # condition() stays inside the try block so that a transient
            # error raised while inspecting the traces is retried as well.
            if condition is None or condition(traces):
                return traces
            # Condition not met yet, keep retrying
        except TRANSIENT_ERRORS as e:
            last_error = e

    # NOTE(review): this is only reached if busy_retry() stops iterating
    # rather than raising on timeout -- confirm against busy_retry's behavior.
    if last_error is not None:
        # Chain the cause so the original traceback is kept in the report.
        raise RuntimeError(
            f"Failed to get stack trace after retries: {last_error}"
        ) from last_error
    raise RuntimeError("Condition never satisfied within timeout")


# ============================================================================
# Base test class with shared infrastructure
# ============================================================================
Expand Down Expand Up @@ -1708,16 +1733,16 @@ def main_work():

# Get stack trace with all threads
unwinder_all = RemoteUnwinder(p.pid, all_threads=True)
for _ in range(MAX_TRIES):
all_traces = unwinder_all.get_stack_trace()
found = self._find_frame_in_trace(
all_traces,
lambda f: f.funcname == "main_work"
and f.location.lineno > 12,
)
if found:
break
time.sleep(RETRY_DELAY)
for _ in busy_retry(SHORT_TIMEOUT):
with contextlib.suppress(*TRANSIENT_ERRORS):
all_traces = unwinder_all.get_stack_trace()
found = self._find_frame_in_trace(
all_traces,
lambda f: f.funcname == "main_work"
and f.location.lineno > 12,
)
if found:
break
else:
self.fail(
"Main thread did not start its busy work on time"
Expand All @@ -1727,7 +1752,7 @@ def main_work():
unwinder_gil = RemoteUnwinder(
p.pid, only_active_thread=True
)
gil_traces = unwinder_gil.get_stack_trace()
gil_traces = _get_stack_trace_with_retry(unwinder_gil)

# Count threads
total_threads = sum(
Expand Down Expand Up @@ -2002,15 +2027,15 @@ def busy():
mode=mode,
skip_non_matching_threads=False,
)
for _ in range(MAX_TRIES):
traces = unwinder.get_stack_trace()
statuses = self._get_thread_statuses(traces)
for _ in busy_retry(SHORT_TIMEOUT):
with contextlib.suppress(*TRANSIENT_ERRORS):
traces = unwinder.get_stack_trace()
statuses = self._get_thread_statuses(traces)

if check_condition(
statuses, sleeper_tid, busy_tid
):
break
time.sleep(0.5)
if check_condition(
statuses, sleeper_tid, busy_tid
):
break

return statuses, sleeper_tid, busy_tid
finally:
Expand Down Expand Up @@ -2154,29 +2179,29 @@ def busy_thread():
mode=PROFILING_MODE_ALL,
skip_non_matching_threads=False,
)
for _ in range(MAX_TRIES):
traces = unwinder.get_stack_trace()
statuses = self._get_thread_statuses(traces)

# Check ALL mode provides both GIL and CPU info
if (
sleeper_tid in statuses
and busy_tid in statuses
and not (
statuses[sleeper_tid]
& THREAD_STATUS_ON_CPU
)
and not (
statuses[sleeper_tid]
& THREAD_STATUS_HAS_GIL
)
and (statuses[busy_tid] & THREAD_STATUS_ON_CPU)
and (
statuses[busy_tid] & THREAD_STATUS_HAS_GIL
)
):
break
time.sleep(0.5)
for _ in busy_retry(SHORT_TIMEOUT):
with contextlib.suppress(*TRANSIENT_ERRORS):
traces = unwinder.get_stack_trace()
statuses = self._get_thread_statuses(traces)

# Check ALL mode provides both GIL and CPU info
if (
sleeper_tid in statuses
and busy_tid in statuses
and not (
statuses[sleeper_tid]
& THREAD_STATUS_ON_CPU
)
and not (
statuses[sleeper_tid]
& THREAD_STATUS_HAS_GIL
)
and (statuses[busy_tid] & THREAD_STATUS_ON_CPU)
and (
statuses[busy_tid] & THREAD_STATUS_HAS_GIL
)
):
break

self.assertIsNotNone(
sleeper_tid, "Sleeper thread id not received"
Expand Down Expand Up @@ -2300,18 +2325,18 @@ def test_thread_status_exception_detection(self):
mode=PROFILING_MODE_ALL,
skip_non_matching_threads=False,
)
for _ in range(MAX_TRIES):
traces = unwinder.get_stack_trace()
statuses = self._get_thread_statuses(traces)

if (
exception_tid in statuses
and normal_tid in statuses
and (statuses[exception_tid] & THREAD_STATUS_HAS_EXCEPTION)
and not (statuses[normal_tid] & THREAD_STATUS_HAS_EXCEPTION)
):
break
time.sleep(0.5)
for _ in busy_retry(SHORT_TIMEOUT):
with contextlib.suppress(*TRANSIENT_ERRORS):
traces = unwinder.get_stack_trace()
statuses = self._get_thread_statuses(traces)

if (
exception_tid in statuses
and normal_tid in statuses
and (statuses[exception_tid] & THREAD_STATUS_HAS_EXCEPTION)
and not (statuses[normal_tid] & THREAD_STATUS_HAS_EXCEPTION)
):
break

self.assertIn(exception_tid, statuses)
self.assertIn(normal_tid, statuses)
Expand Down Expand Up @@ -2343,18 +2368,18 @@ def test_thread_status_exception_mode_filtering(self):
mode=PROFILING_MODE_EXCEPTION,
skip_non_matching_threads=True,
)
for _ in range(MAX_TRIES):
traces = unwinder.get_stack_trace()
statuses = self._get_thread_statuses(traces)

if exception_tid in statuses:
self.assertNotIn(
normal_tid,
statuses,
"Normal thread should be filtered out in exception mode",
)
return
time.sleep(0.5)
for _ in busy_retry(SHORT_TIMEOUT):
with contextlib.suppress(*TRANSIENT_ERRORS):
traces = unwinder.get_stack_trace()
statuses = self._get_thread_statuses(traces)

if exception_tid in statuses:
self.assertNotIn(
normal_tid,
statuses,
"Normal thread should be filtered out in exception mode",
)
return

self.fail("Never found exception thread in exception mode")

Expand Down Expand Up @@ -2497,49 +2522,61 @@ def _run_scenario_process(self, scenario):
finally:
_cleanup_sockets(client_socket, server_socket)

def _check_exception_status(self, p, thread_tid, expect_exception):
"""Helper to check if thread has expected exception status."""
def _check_thread_status(
self, p, thread_tid, condition, condition_name="condition"
):
"""Helper to check thread status with a custom condition.

This waits until we see 3 consecutive samples where the condition
returns True, which confirms the thread has reached and is stable
in the expected state. Samples that don't match are ignored (the
thread may not have reached the expected state yet).

Args:
p: Process object with pid attribute
thread_tid: Thread ID to check
condition: Callable(statuses, thread_tid) -> bool that returns
True when the thread is in the expected state
condition_name: Description of condition for error messages
"""
unwinder = RemoteUnwinder(
p.pid,
all_threads=True,
mode=PROFILING_MODE_ALL,
skip_non_matching_threads=False,
)

# Collect multiple samples for reliability
results = []
for _ in range(MAX_TRIES):
try:
# Wait for 3 consecutive samples matching expected state
matching_samples = 0
for _ in busy_retry(SHORT_TIMEOUT):
with contextlib.suppress(*TRANSIENT_ERRORS):
traces = unwinder.get_stack_trace()
except TRANSIENT_ERRORS:
time.sleep(RETRY_DELAY)
continue
statuses = self._get_thread_statuses(traces)

if thread_tid in statuses:
has_exc = bool(statuses[thread_tid] & THREAD_STATUS_HAS_EXCEPTION)
results.append(has_exc)
statuses = self._get_thread_statuses(traces)

if len(results) >= 3:
break
if thread_tid in statuses:
if condition(statuses, thread_tid):
matching_samples += 1
if matching_samples >= 3:
return # Success - confirmed stable in expected state
else:
# Thread not yet in expected state, reset counter
matching_samples = 0

time.sleep(RETRY_DELAY)
self.fail(
f"Thread did not stabilize in expected state "
f"({condition_name}) within timeout"
)

# Check majority of samples match expected
if not results:
self.fail("Never found target thread in stack traces")
def _check_exception_status(self, p, thread_tid, expect_exception):
"""Helper to check if thread has expected exception status."""
def condition(statuses, tid):
has_exc = bool(statuses[tid] & THREAD_STATUS_HAS_EXCEPTION)
return has_exc == expect_exception

majority = sum(results) > len(results) // 2
if expect_exception:
self.assertTrue(
majority,
f"Thread should have HAS_EXCEPTION flag, got {results}"
)
else:
self.assertFalse(
majority,
f"Thread should NOT have HAS_EXCEPTION flag, got {results}"
)
self._check_thread_status(
p, thread_tid, condition,
condition_name=f"expect_exception={expect_exception}"
)

@unittest.skipIf(
sys.platform not in ("linux", "darwin", "win32"),
Expand Down Expand Up @@ -3445,7 +3482,7 @@ def test_get_stats(self):
_wait_for_signal(client_socket, b"ready")

# Take a sample
unwinder.get_stack_trace()
_get_stack_trace_with_retry(unwinder)

stats = unwinder.get_stats()
client_socket.sendall(b"done")
Expand Down
Loading