@@ -253,6 +253,31 @@ def get_all_awaited_by(pid):
253253 raise RuntimeError ("Failed to get all awaited_by after retries" )
254254
255255
def _get_stack_trace_with_retry(unwinder, timeout=SHORT_TIMEOUT, condition=None):
    """Get stack trace from an existing unwinder with retry for transient errors.

    This handles the case where we want to reuse an existing RemoteUnwinder
    instance but still handle transient failures like "Failed to parse initial
    frame in chain" that can occur when sampling at an inopportune moment.
    If condition is provided, keeps retrying until condition(traces) is True.

    Args:
        unwinder: Object exposing ``get_stack_trace()``.
        timeout: Retry budget handed to ``busy_retry()``.
        condition: Optional callable taking the sampled traces; when given,
            a sample is only accepted once it returns a true value.

    Returns:
        The first sample that was obtained without a transient error and,
        if ``condition`` is given, satisfied it.

    Raises:
        RuntimeError: If no acceptable sample was obtained before the retry
            loop ended.  When a transient error was seen, it is included in
            the message and chained as ``__cause__``.
    """
    last_error = None
    # error=False makes the retry loop simply stop at the deadline instead of
    # raising its own generic AssertionError, so the descriptive RuntimeErrors
    # below are actually reached.
    # NOTE(review): assumes busy_retry is test.support.busy_retry (its import
    # is not visible in this excerpt) -- confirm.
    for _ in busy_retry(timeout, error=False):
        try:
            traces = unwinder.get_stack_trace()
            if condition is None or condition(traces):
                return traces
            # Condition not met yet, keep retrying
        except TRANSIENT_ERRORS as e:
            last_error = e
            continue
    if last_error is not None:
        raise RuntimeError(
            f"Failed to get stack trace after retries: {last_error}"
        ) from last_error
    raise RuntimeError("Condition never satisfied within timeout")
279+
280+
256281# ============================================================================
257282# Base test class with shared infrastructure
258283# ============================================================================
@@ -1708,16 +1733,16 @@ def main_work():
17081733
17091734 # Get stack trace with all threads
17101735 unwinder_all = RemoteUnwinder (p .pid , all_threads = True )
1711- for _ in range ( MAX_TRIES ):
1712- all_traces = unwinder_all . get_stack_trace ()
1713- found = self . _find_frame_in_trace (
1714- all_traces ,
1715- lambda f : f . funcname == "main_work"
1716- and f . location . lineno > 12 ,
1717- )
1718- if found :
1719- break
1720- time . sleep ( RETRY_DELAY )
1736+ for _ in busy_retry ( SHORT_TIMEOUT ):
1737+ with contextlib . suppress ( * TRANSIENT_ERRORS ):
1738+ all_traces = unwinder_all . get_stack_trace ()
1739+ found = self . _find_frame_in_trace (
1740+ all_traces ,
1741+ lambda f : f . funcname == "main_work"
1742+ and f . location . lineno > 12 ,
1743+ )
1744+ if found :
1745+ break
17211746 else :
17221747 self .fail (
17231748 "Main thread did not start its busy work on time"
@@ -1727,7 +1752,7 @@ def main_work():
17271752 unwinder_gil = RemoteUnwinder (
17281753 p .pid , only_active_thread = True
17291754 )
1730- gil_traces = unwinder_gil . get_stack_trace ( )
1755+ gil_traces = _get_stack_trace_with_retry ( unwinder_gil )
17311756
17321757 # Count threads
17331758 total_threads = sum (
@@ -2002,15 +2027,15 @@ def busy():
20022027 mode = mode ,
20032028 skip_non_matching_threads = False ,
20042029 )
2005- for _ in range (MAX_TRIES ):
2006- traces = unwinder .get_stack_trace ()
2007- statuses = self ._get_thread_statuses (traces )
2030+ for _ in busy_retry (SHORT_TIMEOUT ):
2031+ with contextlib .suppress (* TRANSIENT_ERRORS ):
2032+ traces = unwinder .get_stack_trace ()
2033+ statuses = self ._get_thread_statuses (traces )
20082034
2009- if check_condition (
2010- statuses , sleeper_tid , busy_tid
2011- ):
2012- break
2013- time .sleep (0.5 )
2035+ if check_condition (
2036+ statuses , sleeper_tid , busy_tid
2037+ ):
2038+ break
20142039
20152040 return statuses , sleeper_tid , busy_tid
20162041 finally :
@@ -2154,29 +2179,29 @@ def busy_thread():
21542179 mode = PROFILING_MODE_ALL ,
21552180 skip_non_matching_threads = False ,
21562181 )
2157- for _ in range ( MAX_TRIES ):
2158- traces = unwinder . get_stack_trace ()
2159- statuses = self . _get_thread_statuses ( traces )
2160-
2161- # Check ALL mode provides both GIL and CPU info
2162- if (
2163- sleeper_tid in statuses
2164- and busy_tid in statuses
2165- and not (
2166- statuses [ sleeper_tid ]
2167- & THREAD_STATUS_ON_CPU
2168- )
2169- and not (
2170- statuses [ sleeper_tid ]
2171- & THREAD_STATUS_HAS_GIL
2172- )
2173- and ( statuses [ busy_tid ] & THREAD_STATUS_ON_CPU )
2174- and (
2175- statuses [ busy_tid ] & THREAD_STATUS_HAS_GIL
2176- )
2177- ):
2178- break
2179- time . sleep ( 0.5 )
2182+ for _ in busy_retry ( SHORT_TIMEOUT ):
2183+ with contextlib . suppress ( * TRANSIENT_ERRORS ):
2184+ traces = unwinder . get_stack_trace ( )
2185+ statuses = self . _get_thread_statuses ( traces )
2186+
2187+ # Check ALL mode provides both GIL and CPU info
2188+ if (
2189+ sleeper_tid in statuses
2190+ and busy_tid in statuses
2191+ and not (
2192+ statuses [ sleeper_tid ]
2193+ & THREAD_STATUS_ON_CPU
2194+ )
2195+ and not (
2196+ statuses [ sleeper_tid ]
2197+ & THREAD_STATUS_HAS_GIL
2198+ )
2199+ and (statuses [ busy_tid ] & THREAD_STATUS_ON_CPU )
2200+ and (
2201+ statuses [ busy_tid ] & THREAD_STATUS_HAS_GIL
2202+ )
2203+ ):
2204+ break
21802205
21812206 self .assertIsNotNone (
21822207 sleeper_tid , "Sleeper thread id not received"
@@ -2300,18 +2325,18 @@ def test_thread_status_exception_detection(self):
23002325 mode = PROFILING_MODE_ALL ,
23012326 skip_non_matching_threads = False ,
23022327 )
2303- for _ in range ( MAX_TRIES ):
2304- traces = unwinder . get_stack_trace ()
2305- statuses = self . _get_thread_statuses ( traces )
2306-
2307- if (
2308- exception_tid in statuses
2309- and normal_tid in statuses
2310- and ( statuses [ exception_tid ] & THREAD_STATUS_HAS_EXCEPTION )
2311- and not (statuses [normal_tid ] & THREAD_STATUS_HAS_EXCEPTION )
2312- ):
2313- break
2314- time . sleep ( 0.5 )
2328+ for _ in busy_retry ( SHORT_TIMEOUT ):
2329+ with contextlib . suppress ( * TRANSIENT_ERRORS ):
2330+ traces = unwinder . get_stack_trace ( )
2331+ statuses = self . _get_thread_statuses ( traces )
2332+
2333+ if (
2334+ exception_tid in statuses
2335+ and normal_tid in statuses
2336+ and (statuses [exception_tid ] & THREAD_STATUS_HAS_EXCEPTION )
2337+ and not ( statuses [ normal_tid ] & THREAD_STATUS_HAS_EXCEPTION )
2338+ ):
2339+ break
23152340
23162341 self .assertIn (exception_tid , statuses )
23172342 self .assertIn (normal_tid , statuses )
@@ -2343,18 +2368,18 @@ def test_thread_status_exception_mode_filtering(self):
23432368 mode = PROFILING_MODE_EXCEPTION ,
23442369 skip_non_matching_threads = True ,
23452370 )
2346- for _ in range ( MAX_TRIES ):
2347- traces = unwinder . get_stack_trace ()
2348- statuses = self . _get_thread_statuses ( traces )
2349-
2350- if exception_tid in statuses :
2351- self . assertNotIn (
2352- normal_tid ,
2353- statuses ,
2354- "Normal thread should be filtered out in exception mode" ,
2355- )
2356- return
2357- time . sleep ( 0.5 )
2371+ for _ in busy_retry ( SHORT_TIMEOUT ):
2372+ with contextlib . suppress ( * TRANSIENT_ERRORS ):
2373+ traces = unwinder . get_stack_trace ( )
2374+ statuses = self . _get_thread_statuses ( traces )
2375+
2376+ if exception_tid in statuses :
2377+ self . assertNotIn (
2378+ normal_tid ,
2379+ statuses ,
2380+ "Normal thread should be filtered out in exception mode" ,
2381+ )
2382+ return
23582383
23592384 self .fail ("Never found exception thread in exception mode" )
23602385
@@ -2497,49 +2522,61 @@ def _run_scenario_process(self, scenario):
24972522 finally :
24982523 _cleanup_sockets (client_socket , server_socket )
24992524
2500- def _check_exception_status (self , p , thread_tid , expect_exception ):
2501- """Helper to check if thread has expected exception status."""
def _check_thread_status(
    self, p, thread_tid, condition, condition_name="condition"
):
    """Helper to check thread status with a custom condition.

    This waits until the condition has returned True for 3 samples with no
    failing sample in between, which confirms the thread has reached and is
    stable in the expected state.  A sample where the condition returns
    False resets the count (the thread may not have reached the expected
    state yet).  Samples that do not contain ``thread_tid`` and samples
    that raise one of TRANSIENT_ERRORS neither count nor reset.

    Args:
        p: Process object with pid attribute
        thread_tid: Thread ID to check
        condition: Callable(statuses, thread_tid) -> bool that returns
            True when the thread is in the expected state
        condition_name: Description of condition for error messages

    Fails the test (via ``self.fail``) if the retry loop ends before 3
    such samples were observed.
    """
    unwinder = RemoteUnwinder(
        p.pid,
        all_threads=True,
        mode=PROFILING_MODE_ALL,
        skip_non_matching_threads=False,
    )

    # Wait for 3 consecutive samples matching expected state
    matching_samples = 0
    # error=False lets the loop end quietly at the deadline so that the
    # descriptive self.fail() below is reached instead of the retry
    # helper's own generic timeout AssertionError.
    # NOTE(review): assumes busy_retry is test.support.busy_retry (its import
    # is not visible in this excerpt) -- confirm.
    for _ in busy_retry(SHORT_TIMEOUT, error=False):
        with contextlib.suppress(*TRANSIENT_ERRORS):
            traces = unwinder.get_stack_trace()
            statuses = self._get_thread_statuses(traces)

            if thread_tid in statuses:
                if condition(statuses, thread_tid):
                    matching_samples += 1
                    if matching_samples >= 3:
                        return  # Success - confirmed stable in expected state
                else:
                    # Thread not yet in expected state, reset counter
                    matching_samples = 0

    self.fail(
        f"Thread did not stabilize in expected state "
        f"({condition_name}) within timeout"
    )
25272569
2528- # Check majority of samples match expected
2529- if not results :
2530- self .fail ("Never found target thread in stack traces" )
def _check_exception_status(self, p, thread_tid, expect_exception):
    """Helper to check if thread has expected exception status.

    Delegates to ``self._check_thread_status`` with a condition that
    compares the thread's THREAD_STATUS_HAS_EXCEPTION bit against
    ``expect_exception``.

    Args:
        p: Process object, passed through unchanged.
        thread_tid: Thread ID whose status is inspected.
        expect_exception: Truthy if the thread is expected to have the
            HAS_EXCEPTION flag set, falsy if it is expected to be clear.
    """
    # Normalise once so truthy non-bool values compare correctly below.
    expected = bool(expect_exception)

    def condition(statuses, tid):
        has_exc = bool(statuses[tid] & THREAD_STATUS_HAS_EXCEPTION)
        return has_exc == expected

    self._check_thread_status(
        p, thread_tid, condition,
        condition_name=f"expect_exception={expect_exception}",
    )
25432580
25442581 @unittest .skipIf (
25452582 sys .platform not in ("linux" , "darwin" , "win32" ),
@@ -3445,7 +3482,7 @@ def test_get_stats(self):
34453482 _wait_for_signal (client_socket , b"ready" )
34463483
34473484 # Take a sample
3448- unwinder . get_stack_trace ( )
3485+ _get_stack_trace_with_retry ( unwinder )
34493486
34503487 stats = unwinder .get_stats ()
34513488 client_socket .sendall (b"done" )
0 commit comments