| import contextlib |
| import unittest |
| import os |
| import textwrap |
| import importlib |
| import sys |
| import socket |
| import threading |
| import time |
| from contextlib import contextmanager |
| from asyncio import staggered, taskgroups, base_events, tasks |
| from unittest.mock import ANY |
| from test.support import ( |
| os_helper, |
| SHORT_TIMEOUT, |
| busy_retry, |
| requires_gil_enabled, |
| ) |
| from test.support.script_helper import make_script |
| from test.support.socket_helper import find_unused_port |
| |
| import subprocess |
| |
# Profiling mode constants
# (numeric modes passed to / reported by the remote profiler machinery)
PROFILING_MODE_WALL = 0
PROFILING_MODE_CPU = 1
PROFILING_MODE_GIL = 2
PROFILING_MODE_ALL = 3

# Thread status flags
# (independent bits, combinable as a mask)
THREAD_STATUS_HAS_GIL = 1 << 0
THREAD_STATUS_ON_CPU = 1 << 1
THREAD_STATUS_UNKNOWN = 1 << 2

# Maximum number of retry attempts for operations that may fail transiently
MAX_TRIES = 10

# Subinterpreter support is optional; tests that need it are gated with
# the requires_subinterpreters decorator below.
try:
    from concurrent import interpreters
except ImportError:
    interpreters = None

# Default before the real value is (possibly) imported from the extension.
PROCESS_VM_READV_SUPPORTED = False

# The whole module is meaningless without the _remote_debugging extension,
# so skip everything at import time when it is unavailable.
try:
    from _remote_debugging import PROCESS_VM_READV_SUPPORTED
    from _remote_debugging import RemoteUnwinder
    from _remote_debugging import FrameInfo, CoroInfo, TaskInfo
except ImportError:
    raise unittest.SkipTest(
        "Test only runs when _remote_debugging is available"
    )
| |
| |
| # ============================================================================ |
| # Module-level helper functions |
| # ============================================================================ |
| |
| |
def _make_test_script(script_dir, script_basename, source):
    """Write a test script to *script_dir* and refresh import caches.

    Returns the full path of the generated script file.
    """
    path = make_script(script_dir, script_basename, source)
    # Make sure the freshly written file is visible to subsequent imports.
    importlib.invalidate_caches()
    return path
| |
| |
def _create_server_socket(port, backlog=1):
    """Return a listening localhost TCP socket used for test communication."""
    srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Allow immediate rebinding of the port across quick test restarts.
    srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    srv.settimeout(SHORT_TIMEOUT)
    srv.bind(("localhost", port))
    srv.listen(backlog)
    return srv
| |
| |
def _wait_for_signal(sock, expected_signals, timeout=SHORT_TIMEOUT):
    """
    Wait for expected signal(s) from a socket with proper timeout and EOF handling.

    Args:
        sock: Connected socket to read from
        expected_signals: Single bytes object or list of bytes objects to wait for
        timeout: Socket timeout in seconds

    Returns:
        bytes: Complete accumulated response buffer

    Raises:
        RuntimeError: If connection closed before signal received or timeout
    """
    # Normalize a single pattern to a list so the loop condition is uniform.
    if isinstance(expected_signals, bytes):
        expected_signals = [expected_signals]

    sock.settimeout(timeout)
    received = b""

    # Keep reading until every expected pattern appears somewhere in the
    # accumulated buffer (signals may be split across recv() boundaries).
    while not all(sig in received for sig in expected_signals):
        try:
            chunk = sock.recv(4096)
        except socket.timeout:
            raise RuntimeError(
                f"Timeout waiting for signals. "
                f"Expected: {expected_signals}, Got: {received[-200:]!r}"
            )
        if not chunk:
            # EOF - connection closed
            raise RuntimeError(
                f"Connection closed before receiving expected signals. "
                f"Expected: {expected_signals}, Got: {received[-200:]!r}"
            )
        received += chunk

    return received
| |
| |
def _wait_for_n_signals(sock, signal_pattern, count, timeout=SHORT_TIMEOUT):
    """
    Wait for N occurrences of a signal pattern.

    Args:
        sock: Connected socket to read from
        signal_pattern: bytes pattern to count (e.g., b"ready")
        count: Number of occurrences expected
        timeout: Socket timeout in seconds

    Returns:
        bytes: Complete accumulated response buffer

    Raises:
        RuntimeError: If connection closed or timeout before receiving all signals
    """
    sock.settimeout(timeout)
    collected = b""
    seen = 0

    while seen < count:
        try:
            chunk = sock.recv(4096)
        except socket.timeout:
            raise RuntimeError(
                f"Timeout waiting for {count} signals (found {seen}). "
                f"Last 200 bytes: {collected[-200:]!r}"
            )
        if not chunk:
            raise RuntimeError(
                f"Connection closed after {seen}/{count} signals. "
                f"Last 200 bytes: {collected[-200:]!r}"
            )
        collected += chunk
        # Recount over the whole buffer so a pattern split across recv()
        # boundaries is still detected once completed.
        seen = collected.count(signal_pattern)

    return collected
| |
| |
@contextmanager
def _managed_subprocess(args, timeout=SHORT_TIMEOUT):
    """
    Context manager for subprocess lifecycle management.

    Ensures process is properly terminated and cleaned up even on exceptions.
    Uses graceful termination first, then forceful kill if needed.
    """
    proc = subprocess.Popen(args)
    try:
        yield proc
    finally:
        # OSError here means the process is already gone; nothing to do.
        with contextlib.suppress(OSError):
            proc.terminate()
            try:
                proc.wait(timeout=timeout)
            except subprocess.TimeoutExpired:
                proc.kill()
                # If it still refuses to die there is nothing more we can do.
                with contextlib.suppress(subprocess.TimeoutExpired):
                    proc.wait(timeout=timeout)
| |
| |
| def _cleanup_sockets(*sockets): |
| """Safely close multiple sockets, ignoring errors.""" |
| for sock in sockets: |
| if sock is not None: |
| try: |
| sock.close() |
| except OSError: |
| pass |
| |
| |
| # ============================================================================ |
| # Decorators and skip conditions |
| # ============================================================================ |
| |
# Class/method decorator: remote inspection is only implemented for the
# three listed platforms; everything else skips.
skip_if_not_supported = unittest.skipIf(
    (
        sys.platform != "darwin"
        and sys.platform != "linux"
        and sys.platform != "win32"
    ),
    "Test only runs on Linux, Windows and MacOS",
)
| |
| |
def requires_subinterpreters(meth):
    """Decorator to skip a test if subinterpreters are not supported."""
    skip = unittest.skipIf(interpreters is None, "subinterpreters required")
    return skip(meth)
| |
| |
| # ============================================================================ |
| # Simple wrapper functions for RemoteUnwinder |
| # ============================================================================ |
| |
| # Errors that can occur transiently when reading process memory without synchronization |
| RETRIABLE_ERRORS = ( |
| "Task list appears corrupted", |
| "Invalid linked list structure reading remote memory", |
| "Unknown error reading memory", |
| "Unhandled frame owner", |
| "Failed to parse initial frame", |
| "Failed to process frame chain", |
| "Failed to unwind stack", |
| ) |
| |
| |
| def _is_retriable_error(exc): |
| """Check if an exception is a transient error that should be retried.""" |
| msg = str(exc) |
| return any(msg.startswith(err) or err in msg for err in RETRIABLE_ERRORS) |
| |
| |
def get_stack_trace(pid):
    """Return the synchronous stack trace of *pid*, retrying transient errors."""
    for _ in busy_retry(SHORT_TIMEOUT):
        try:
            unwinder = RemoteUnwinder(pid, all_threads=True, debug=True)
            return unwinder.get_stack_trace()
        except RuntimeError as exc:
            # Non-transient failures propagate immediately.
            if not _is_retriable_error(exc):
                raise
    raise RuntimeError("Failed to get stack trace after retries")
| |
| |
def get_async_stack_trace(pid):
    """Return the asyncio stack trace of *pid*, retrying transient errors."""
    for _ in busy_retry(SHORT_TIMEOUT):
        try:
            unwinder = RemoteUnwinder(pid, debug=True)
            return unwinder.get_async_stack_trace()
        except RuntimeError as exc:
            # Non-transient failures propagate immediately.
            if not _is_retriable_error(exc):
                raise
    raise RuntimeError("Failed to get async stack trace after retries")
| |
| |
def get_all_awaited_by(pid):
    """Return awaited_by information for *pid*, retrying transient errors."""
    for _ in busy_retry(SHORT_TIMEOUT):
        try:
            unwinder = RemoteUnwinder(pid, debug=True)
            return unwinder.get_all_awaited_by()
        except RuntimeError as exc:
            # Non-transient failures propagate immediately.
            if not _is_retriable_error(exc):
                raise
    raise RuntimeError("Failed to get all awaited_by after retries")
| |
| |
| # ============================================================================ |
| # Base test class with shared infrastructure |
| # ============================================================================ |
| |
| |
class RemoteInspectionTestBase(unittest.TestCase):
    """Base class for remote inspection tests with common helpers."""

    # The expected structures compared in these tests are large and deeply
    # nested; never truncate assertion diffs.
    maxDiff = None

    def _run_script_and_get_trace(
        self,
        script,
        trace_func,
        wait_for_signals=None,
        port=None,
        backlog=1,
    ):
        """
        Common pattern: run a script, wait for signals, get trace.

        Args:
            script: Script content (will be formatted with port if {port} present)
            trace_func: Function to call with pid to get trace (e.g., get_stack_trace)
            wait_for_signals: Signal(s) to wait for before getting trace
            port: Port to use (auto-selected if None)
            backlog: Socket listen backlog

        Returns:
            tuple: (trace_result, script_name)
        """
        if port is None:
            port = find_unused_port()

        # Format script with port if needed
        # (accepts both the "{port}" and the escaped "{{port}}" spelling)
        if "{port}" in script or "{{port}}" in script:
            script = script.replace("{{port}}", "{port}").format(port=port)

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port, backlog)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    # Once the child has connected, the listening socket is
                    # no longer needed; close it eagerly and clear the name
                    # so the finally block does not close it a second time.
                    server_socket.close()
                    server_socket = None

                    if wait_for_signals:
                        _wait_for_signal(client_socket, wait_for_signals)

                    try:
                        trace = trace_func(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )
                    return trace, script_name
            finally:
                _cleanup_sockets(client_socket, server_socket)

    def _find_frame_in_trace(self, stack_trace, predicate):
        """
        Find a frame matching predicate in stack trace.

        Args:
            stack_trace: List of InterpreterInfo objects
            predicate: Function(frame) -> bool

        Returns:
            FrameInfo or None
        """
        for interpreter_info in stack_trace:
            for thread_info in interpreter_info.threads:
                for frame in thread_info.frame_info:
                    if predicate(frame):
                        return frame
        return None

    def _find_thread_by_id(self, stack_trace, thread_id):
        """Find a thread by its native thread ID."""
        for interpreter_info in stack_trace:
            for thread_info in interpreter_info.threads:
                if thread_info.thread_id == thread_id:
                    return thread_info
        return None

    def _find_thread_with_frame(self, stack_trace, frame_predicate):
        """Find a thread containing a frame matching predicate."""
        for interpreter_info in stack_trace:
            for thread_info in interpreter_info.threads:
                for frame in thread_info.frame_info:
                    if frame_predicate(frame):
                        return thread_info
        return None

    def _get_thread_statuses(self, stack_trace):
        """Extract thread_id -> status mapping from stack trace."""
        statuses = {}
        for interpreter_info in stack_trace:
            for thread_info in interpreter_info.threads:
                statuses[thread_info.thread_id] = thread_info.status
        return statuses

    def _get_task_id_map(self, stack_trace):
        """Create task_id -> task mapping from async stack trace."""
        return {task.task_id: task for task in stack_trace[0].awaited_by}

    def _get_awaited_by_relationships(self, stack_trace):
        """Extract task name to awaited_by set mapping."""
        id_to_task = self._get_task_id_map(stack_trace)
        return {
            task.task_name: set(
                # NOTE(review): for awaited_by entries, the CoroInfo
                # "task_name" field is used here as a key into the
                # task_id -> task map, i.e. it appears to carry the
                # awaiting task's id — confirm against _remote_debugging.
                id_to_task[awaited.task_name].task_name
                for awaited in task.awaited_by
            )
            for task in stack_trace[0].awaited_by
        }

    def _extract_coroutine_stacks(self, stack_trace):
        """Extract and format coroutine stacks from tasks."""
        return {
            task.task_name: sorted(
                tuple(tuple(frame) for frame in coro.call_stack)
                for coro in task.coroutine_stack
            )
            for task in stack_trace[0].awaited_by
        }
| |
| |
| # ============================================================================ |
| # Test classes |
| # ============================================================================ |
| |
| |
| class TestGetStackTrace(RemoteInspectionTestBase): |
| @skip_if_not_supported |
| @unittest.skipIf( |
| sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, |
| "Test only runs on Linux with process_vm_readv support", |
| ) |
    def test_remote_stack_trace(self):
        """Unwind a live child process and verify both its threads' stacks.

        The frame line numbers asserted below (15, 12, 9, 19) refer to
        lines of the generated script, so the script text must not change.
        """
        port = find_unused_port()
        script = textwrap.dedent(
            f"""\
            import time, sys, socket, threading

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            def bar():
                for x in range(100):
                    if x == 50:
                        baz()

            def baz():
                foo()

            def foo():
                sock.sendall(b"ready:thread\\n"); time.sleep(10_000)

            t = threading.Thread(target=bar)
            t.start()
            sock.sendall(b"ready:main\\n"); t.join()
            """
        )

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    # Listening socket is done once the child connected.
                    server_socket.close()
                    server_socket = None

                    # Both the worker thread and the main thread must have
                    # reached their blocking points before unwinding.
                    _wait_for_signal(
                        client_socket, [b"ready:main", b"ready:thread"]
                    )

                    try:
                        stack_trace = get_stack_trace(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    thread_expected_stack_trace = [
                        FrameInfo([script_name, 15, "foo"]),
                        FrameInfo([script_name, 12, "baz"]),
                        FrameInfo([script_name, 9, "bar"]),
                        FrameInfo([threading.__file__, ANY, "Thread.run"]),
                        FrameInfo(
                            [
                                threading.__file__,
                                ANY,
                                "Thread._bootstrap_inner",
                            ]
                        ),
                        FrameInfo(
                            [threading.__file__, ANY, "Thread._bootstrap"]
                        ),
                    ]

                    # Find expected thread stack
                    found_thread = self._find_thread_with_frame(
                        stack_trace,
                        lambda f: f.funcname == "foo" and f.lineno == 15,
                    )
                    self.assertIsNotNone(
                        found_thread, "Expected thread stack trace not found"
                    )
                    self.assertEqual(
                        found_thread.frame_info, thread_expected_stack_trace
                    )

                    # Check main thread
                    # (blocked in t.join() on line 19 of the script)
                    main_frame = FrameInfo([script_name, 19, "<module>"])
                    found_main = self._find_frame_in_trace(
                        stack_trace, lambda f: f == main_frame
                    )
                    self.assertIsNotNone(
                        found_main, "Main thread stack trace not found"
                    )
            finally:
                _cleanup_sockets(client_socket, server_socket)
| |
| @skip_if_not_supported |
| @unittest.skipIf( |
| sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, |
| "Test only runs on Linux with process_vm_readv support", |
| ) |
    def test_async_remote_stack_trace(self):
        """Verify async task trees (TaskGroup) under both loop factories.

        The same script is run with the default event loop and with an
        eager-task-factory loop; the reconstructed task names, awaited_by
        relationships and coroutine stacks must be identical in both cases.
        Line numbers in the expectations refer to the generated script.
        """
        port = find_unused_port()
        script = textwrap.dedent(
            f"""\
            import asyncio
            import time
            import sys
            import socket

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            def c5():
                sock.sendall(b"ready"); time.sleep(10_000)

            async def c4():
                await asyncio.sleep(0)
                c5()

            async def c3():
                await c4()

            async def c2():
                await c3()

            async def c1(task):
                await task

            async def main():
                async with asyncio.TaskGroup() as tg:
                    task = tg.create_task(c2(), name="c2_root")
                    tg.create_task(c1(task), name="sub_main_1")
                    tg.create_task(c1(task), name="sub_main_2")

            def new_eager_loop():
                loop = asyncio.new_event_loop()
                eager_task_factory = asyncio.create_eager_task_factory(
                    asyncio.Task)
                loop.set_task_factory(eager_task_factory)
                return loop

            asyncio.run(main(), loop_factory={{TASK_FACTORY}})
            """
        )

        for task_factory_variant in "asyncio.new_event_loop", "new_eager_loop":
            with (
                self.subTest(task_factory_variant=task_factory_variant),
                os_helper.temp_dir() as work_dir,
            ):
                script_dir = os.path.join(work_dir, "script_pkg")
                os.mkdir(script_dir)

                server_socket = _create_server_socket(port)
                script_name = _make_test_script(
                    script_dir,
                    "script",
                    script.format(TASK_FACTORY=task_factory_variant),
                )
                client_socket = None

                try:
                    with _managed_subprocess(
                        [sys.executable, script_name]
                    ) as p:
                        client_socket, _ = server_socket.accept()
                        server_socket.close()
                        server_socket = None

                        # The child blocks in c5() after sending "ready",
                        # freezing the task tree for inspection.
                        response = _wait_for_signal(client_socket, b"ready")
                        self.assertIn(b"ready", response)

                        try:
                            stack_trace = get_async_stack_trace(p.pid)
                        except PermissionError:
                            self.skipTest(
                                "Insufficient permissions to read the stack trace"
                            )

                        # Check all tasks are present
                        tasks_names = [
                            task.task_name
                            for task in stack_trace[0].awaited_by
                        ]
                        for task_name in [
                            "c2_root",
                            "sub_main_1",
                            "sub_main_2",
                        ]:
                            self.assertIn(task_name, tasks_names)

                        # Check awaited_by relationships
                        relationships = self._get_awaited_by_relationships(
                            stack_trace
                        )
                        self.assertEqual(
                            relationships,
                            {
                                "c2_root": {
                                    "Task-1",
                                    "sub_main_1",
                                    "sub_main_2",
                                },
                                "Task-1": set(),
                                "sub_main_1": {"Task-1"},
                                "sub_main_2": {"Task-1"},
                            },
                        )

                        # Check coroutine stacks
                        coroutine_stacks = self._extract_coroutine_stacks(
                            stack_trace
                        )
                        self.assertEqual(
                            coroutine_stacks,
                            {
                                "Task-1": [
                                    (
                                        tuple(
                                            [
                                                taskgroups.__file__,
                                                ANY,
                                                "TaskGroup._aexit",
                                            ]
                                        ),
                                        tuple(
                                            [
                                                taskgroups.__file__,
                                                ANY,
                                                "TaskGroup.__aexit__",
                                            ]
                                        ),
                                        tuple([script_name, 26, "main"]),
                                    )
                                ],
                                "c2_root": [
                                    (
                                        tuple([script_name, 10, "c5"]),
                                        tuple([script_name, 14, "c4"]),
                                        tuple([script_name, 17, "c3"]),
                                        tuple([script_name, 20, "c2"]),
                                    )
                                ],
                                "sub_main_1": [
                                    (tuple([script_name, 23, "c1"]),)
                                ],
                                "sub_main_2": [
                                    (tuple([script_name, 23, "c1"]),)
                                ],
                            },
                        )

                        # Check awaited_by coroutine stacks
                        id_to_task = self._get_task_id_map(stack_trace)
                        awaited_by_coroutine_stacks = {
                            task.task_name: sorted(
                                (
                                    id_to_task[coro.task_name].task_name,
                                    tuple(
                                        tuple(frame)
                                        for frame in coro.call_stack
                                    ),
                                )
                                for coro in task.awaited_by
                            )
                            for task in stack_trace[0].awaited_by
                        }
                        self.assertEqual(
                            awaited_by_coroutine_stacks,
                            {
                                "Task-1": [],
                                "c2_root": [
                                    (
                                        "Task-1",
                                        (
                                            tuple(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup._aexit",
                                                ]
                                            ),
                                            tuple(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup.__aexit__",
                                                ]
                                            ),
                                            tuple([script_name, 26, "main"]),
                                        ),
                                    ),
                                    (
                                        "sub_main_1",
                                        (tuple([script_name, 23, "c1"]),),
                                    ),
                                    (
                                        "sub_main_2",
                                        (tuple([script_name, 23, "c1"]),),
                                    ),
                                ],
                                "sub_main_1": [
                                    (
                                        "Task-1",
                                        (
                                            tuple(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup._aexit",
                                                ]
                                            ),
                                            tuple(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup.__aexit__",
                                                ]
                                            ),
                                            tuple([script_name, 26, "main"]),
                                        ),
                                    )
                                ],
                                "sub_main_2": [
                                    (
                                        "Task-1",
                                        (
                                            tuple(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup._aexit",
                                                ]
                                            ),
                                            tuple(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup.__aexit__",
                                                ]
                                            ),
                                            tuple([script_name, 26, "main"]),
                                        ),
                                    )
                                ],
                            },
                        )
                finally:
                    _cleanup_sockets(client_socket, server_socket)
| |
| @skip_if_not_supported |
| @unittest.skipIf( |
| sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, |
| "Test only runs on Linux with process_vm_readv support", |
| ) |
    def test_asyncgen_remote_stack_trace(self):
        """Verify unwinding through an async generator frame.

        The child blocks inside a coroutine awaited from within an async
        generator; the reconstructed stack must include the generator frame.
        Line numbers in the expectations refer to the generated script.
        """
        port = find_unused_port()
        script = textwrap.dedent(
            f"""\
            import asyncio
            import time
            import sys
            import socket

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            async def gen_nested_call():
                sock.sendall(b"ready"); time.sleep(10_000)

            async def gen():
                for num in range(2):
                    yield num
                    if num == 1:
                        await gen_nested_call()

            async def main():
                async for el in gen():
                    pass

            asyncio.run(main())
            """
        )

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None

                    # Child is now blocked in gen_nested_call().
                    response = _wait_for_signal(client_socket, b"ready")
                    self.assertIn(b"ready", response)

                    try:
                        stack_trace = get_async_stack_trace(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # For this simple asyncgen test, we only expect one task
                    self.assertEqual(len(stack_trace[0].awaited_by), 1)
                    task = stack_trace[0].awaited_by[0]
                    self.assertEqual(task.task_name, "Task-1")

                    # Check the coroutine stack
                    # (gen_nested_call <- gen <- main, by script line)
                    coroutine_stack = sorted(
                        tuple(tuple(frame) for frame in coro.call_stack)
                        for coro in task.coroutine_stack
                    )
                    self.assertEqual(
                        coroutine_stack,
                        [
                            (
                                tuple([script_name, 10, "gen_nested_call"]),
                                tuple([script_name, 16, "gen"]),
                                tuple([script_name, 19, "main"]),
                            )
                        ],
                    )

                    # No awaited_by relationships expected
                    self.assertEqual(task.awaited_by, [])
            finally:
                _cleanup_sockets(client_socket, server_socket)
| |
| @skip_if_not_supported |
| @unittest.skipIf( |
| sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, |
| "Test only runs on Linux with process_vm_readv support", |
| ) |
    def test_async_gather_remote_stack_trace(self):
        """Verify task relationships created by asyncio.gather().

        Task-1 (main) gathers two child coroutines; the child running c1()
        blocks inside deep(), and its awaited_by must point back at Task-1.
        Line numbers in the expectations refer to the generated script.
        """
        port = find_unused_port()
        script = textwrap.dedent(
            f"""\
            import asyncio
            import time
            import sys
            import socket

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            async def deep():
                await asyncio.sleep(0)
                sock.sendall(b"ready"); time.sleep(10_000)

            async def c1():
                await asyncio.sleep(0)
                await deep()

            async def c2():
                await asyncio.sleep(0)

            async def main():
                await asyncio.gather(c1(), c2())

            asyncio.run(main())
            """
        )

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None

                    # Child is now blocked in deep().
                    response = _wait_for_signal(client_socket, b"ready")
                    self.assertIn(b"ready", response)

                    try:
                        stack_trace = get_async_stack_trace(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Check all tasks are present
                    tasks_names = [
                        task.task_name for task in stack_trace[0].awaited_by
                    ]
                    for task_name in ["Task-1", "Task-2"]:
                        self.assertIn(task_name, tasks_names)

                    # Check awaited_by relationships
                    relationships = self._get_awaited_by_relationships(
                        stack_trace
                    )
                    self.assertEqual(
                        relationships,
                        {
                            "Task-1": set(),
                            "Task-2": {"Task-1"},
                        },
                    )

                    # Check coroutine stacks
                    coroutine_stacks = self._extract_coroutine_stacks(
                        stack_trace
                    )
                    self.assertEqual(
                        coroutine_stacks,
                        {
                            "Task-1": [(tuple([script_name, 21, "main"]),)],
                            "Task-2": [
                                (
                                    tuple([script_name, 11, "deep"]),
                                    tuple([script_name, 15, "c1"]),
                                )
                            ],
                        },
                    )

                    # Check awaited_by coroutine stacks
                    id_to_task = self._get_task_id_map(stack_trace)
                    awaited_by_coroutine_stacks = {
                        task.task_name: sorted(
                            (
                                id_to_task[coro.task_name].task_name,
                                tuple(
                                    tuple(frame) for frame in coro.call_stack
                                ),
                            )
                            for coro in task.awaited_by
                        )
                        for task in stack_trace[0].awaited_by
                    }
                    self.assertEqual(
                        awaited_by_coroutine_stacks,
                        {
                            "Task-1": [],
                            "Task-2": [
                                ("Task-1", (tuple([script_name, 21, "main"]),))
                            ],
                        },
                    )
            finally:
                _cleanup_sockets(client_socket, server_socket)
| |
| @skip_if_not_supported |
| @unittest.skipIf( |
| sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, |
| "Test only runs on Linux with process_vm_readv support", |
| ) |
    def test_async_staggered_race_remote_stack_trace(self):
        """Verify task relationships created by asyncio.staggered.staggered_race().

        The winning coroutine c1() blocks inside deep(); its task's stack
        must include the staggered helper's run_one_coro frame and its
        awaited_by must point back at the racing Task-1.
        Line numbers in the expectations refer to the generated script.
        """
        port = find_unused_port()
        script = textwrap.dedent(
            f"""\
            import asyncio.staggered
            import time
            import sys
            import socket

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            async def deep():
                await asyncio.sleep(0)
                sock.sendall(b"ready"); time.sleep(10_000)

            async def c1():
                await asyncio.sleep(0)
                await deep()

            async def c2():
                await asyncio.sleep(10_000)

            async def main():
                await asyncio.staggered.staggered_race(
                    [c1, c2],
                    delay=None,
                )

            asyncio.run(main())
            """
        )

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None

                    # Child is now blocked in deep().
                    response = _wait_for_signal(client_socket, b"ready")
                    self.assertIn(b"ready", response)

                    try:
                        stack_trace = get_async_stack_trace(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Check all tasks are present
                    tasks_names = [
                        task.task_name for task in stack_trace[0].awaited_by
                    ]
                    for task_name in ["Task-1", "Task-2"]:
                        self.assertIn(task_name, tasks_names)

                    # Check awaited_by relationships
                    relationships = self._get_awaited_by_relationships(
                        stack_trace
                    )
                    self.assertEqual(
                        relationships,
                        {
                            "Task-1": set(),
                            "Task-2": {"Task-1"},
                        },
                    )

                    # Check coroutine stacks
                    coroutine_stacks = self._extract_coroutine_stacks(
                        stack_trace
                    )
                    self.assertEqual(
                        coroutine_stacks,
                        {
                            "Task-1": [
                                (
                                    tuple(
                                        [
                                            staggered.__file__,
                                            ANY,
                                            "staggered_race",
                                        ]
                                    ),
                                    tuple([script_name, 21, "main"]),
                                )
                            ],
                            "Task-2": [
                                (
                                    tuple([script_name, 11, "deep"]),
                                    tuple([script_name, 15, "c1"]),
                                    tuple(
                                        [
                                            staggered.__file__,
                                            ANY,
                                            "staggered_race.<locals>.run_one_coro",
                                        ]
                                    ),
                                )
                            ],
                        },
                    )

                    # Check awaited_by coroutine stacks
                    id_to_task = self._get_task_id_map(stack_trace)
                    awaited_by_coroutine_stacks = {
                        task.task_name: sorted(
                            (
                                id_to_task[coro.task_name].task_name,
                                tuple(
                                    tuple(frame) for frame in coro.call_stack
                                ),
                            )
                            for coro in task.awaited_by
                        )
                        for task in stack_trace[0].awaited_by
                    }
                    self.assertEqual(
                        awaited_by_coroutine_stacks,
                        {
                            "Task-1": [],
                            "Task-2": [
                                (
                                    "Task-1",
                                    (
                                        tuple(
                                            [
                                                staggered.__file__,
                                                ANY,
                                                "staggered_race",
                                            ]
                                        ),
                                        tuple([script_name, 21, "main"]),
                                    ),
                                )
                            ],
                        },
                    )
            finally:
                _cleanup_sockets(client_socket, server_socket)
| |
| @skip_if_not_supported |
| @unittest.skipIf( |
| sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, |
| "Test only runs on Linux with process_vm_readv support", |
| ) |
    def test_async_global_awaited_by(self):
        """Stress get_all_awaited_by() against a process with many live tasks.

        The child runs an echo server plus NUM_TASKS echo clients spawned in
        a TaskGroup; after every client reports "ready" the global awaited_by
        snapshot must contain the main task, the server task, a representative
        client task, and at least NUM_TASKS tasks awaited by echo_client_spam.
        Script line numbers (36, 39, 52) appear in the expectations below.
        """
        # Reduced from 1000 to 100 to avoid file descriptor exhaustion
        # when running tests in parallel (e.g., -j 20)
        NUM_TASKS = 100

        port = find_unused_port()
        script = textwrap.dedent(
            f"""\
            import asyncio
            import os
            import random
            import sys
            import socket
            from string import ascii_lowercase, digits
            from test.support import socket_helper, SHORT_TIMEOUT

            HOST = '127.0.0.1'
            PORT = socket_helper.find_unused_port()
            connections = 0

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            class EchoServerProtocol(asyncio.Protocol):
                def connection_made(self, transport):
                    global connections
                    connections += 1
                    self.transport = transport

                def data_received(self, data):
                    self.transport.write(data)
                    self.transport.close()

            async def echo_client(message):
                reader, writer = await asyncio.open_connection(HOST, PORT)
                writer.write(message.encode())
                await writer.drain()

                data = await reader.read(100)
                assert message == data.decode()
                writer.close()
                await writer.wait_closed()
                sock.sendall(b"ready")
                await asyncio.sleep(SHORT_TIMEOUT)

            async def echo_client_spam(server):
                async with asyncio.TaskGroup() as tg:
                    while connections < {NUM_TASKS}:
                        msg = list(ascii_lowercase + digits)
                        random.shuffle(msg)
                        tg.create_task(echo_client("".join(msg)))
                        await asyncio.sleep(0)
                    server.close()
                    await server.wait_closed()

            async def main():
                loop = asyncio.get_running_loop()
                server = await loop.create_server(EchoServerProtocol, HOST, PORT)
                async with server:
                    async with asyncio.TaskGroup() as tg:
                        tg.create_task(server.serve_forever(), name="server task")
                        tg.create_task(echo_client_spam(server), name="echo client spam")

            asyncio.run(main())
            """
        )

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None

                    # Wait for NUM_TASKS "ready" signals
                    try:
                        _wait_for_n_signals(client_socket, b"ready", NUM_TASKS)
                    except RuntimeError as e:
                        self.fail(str(e))

                    try:
                        all_awaited_by = get_all_awaited_by(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Expected: a list of two elements: 1 thread, 1 interp
                    self.assertEqual(len(all_awaited_by), 2)
                    # Expected: a tuple with the thread ID and the awaited_by list
                    self.assertEqual(len(all_awaited_by[0]), 2)
                    # Expected: no tasks in the fallback per-interp task list
                    self.assertEqual(all_awaited_by[1], (0, []))

                    entries = all_awaited_by[0][1]
                    # Expected: at least NUM_TASKS pending tasks
                    self.assertGreaterEqual(len(entries), NUM_TASKS)

                    # Check the main task structure
                    main_stack = [
                        FrameInfo(
                            [taskgroups.__file__, ANY, "TaskGroup._aexit"]
                        ),
                        FrameInfo(
                            [taskgroups.__file__, ANY, "TaskGroup.__aexit__"]
                        ),
                        FrameInfo([script_name, 52, "main"]),
                    ]
                    self.assertIn(
                        TaskInfo(
                            [ANY, "Task-1", [CoroInfo([main_stack, ANY])], []]
                        ),
                        entries,
                    )
                    # The server task: blocked in serve_forever(), awaited by
                    # the TaskGroup in main().
                    self.assertIn(
                        TaskInfo(
                            [
                                ANY,
                                "server task",
                                [
                                    CoroInfo(
                                        [
                                            [
                                                FrameInfo(
                                                    [
                                                        base_events.__file__,
                                                        ANY,
                                                        "Server.serve_forever",
                                                    ]
                                                )
                                            ],
                                            ANY,
                                        ]
                                    )
                                ],
                                [
                                    CoroInfo(
                                        [
                                            [
                                                FrameInfo(
                                                    [
                                                        taskgroups.__file__,
                                                        ANY,
                                                        "TaskGroup._aexit",
                                                    ]
                                                ),
                                                FrameInfo(
                                                    [
                                                        taskgroups.__file__,
                                                        ANY,
                                                        "TaskGroup.__aexit__",
                                                    ]
                                                ),
                                                FrameInfo(
                                                    [script_name, ANY, "main"]
                                                ),
                                            ],
                                            ANY,
                                        ]
                                    )
                                ],
                            ]
                        ),
                        entries,
                    )
                    # A representative echo client task: sleeping on line 36,
                    # awaited by the TaskGroup in echo_client_spam (line 39).
                    self.assertIn(
                        TaskInfo(
                            [
                                ANY,
                                "Task-4",
                                [
                                    CoroInfo(
                                        [
                                            [
                                                FrameInfo(
                                                    [
                                                        tasks.__file__,
                                                        ANY,
                                                        "sleep",
                                                    ]
                                                ),
                                                FrameInfo(
                                                    [
                                                        script_name,
                                                        36,
                                                        "echo_client",
                                                    ]
                                                ),
                                            ],
                                            ANY,
                                        ]
                                    )
                                ],
                                [
                                    CoroInfo(
                                        [
                                            [
                                                FrameInfo(
                                                    [
                                                        taskgroups.__file__,
                                                        ANY,
                                                        "TaskGroup._aexit",
                                                    ]
                                                ),
                                                FrameInfo(
                                                    [
                                                        taskgroups.__file__,
                                                        ANY,
                                                        "TaskGroup.__aexit__",
                                                    ]
                                                ),
                                                FrameInfo(
                                                    [
                                                        script_name,
                                                        39,
                                                        "echo_client_spam",
                                                    ]
                                                ),
                                            ],
                                            ANY,
                                        ]
                                    )
                                ],
                            ]
                        ),
                        entries,
                    )

                    expected_awaited_by = [
                        CoroInfo(
                            [
                                [
                                    FrameInfo(
                                        [
                                            taskgroups.__file__,
                                            ANY,
                                            "TaskGroup._aexit",
                                        ]
                                    ),
                                    FrameInfo(
                                        [
                                            taskgroups.__file__,
                                            ANY,
                                            "TaskGroup.__aexit__",
                                        ]
                                    ),
                                    FrameInfo(
                                        [script_name, 39, "echo_client_spam"]
                                    ),
                                ],
                                ANY,
                            ]
                        )
                    ]
                    tasks_with_awaited = [
                        task
                        for task in entries
                        if task.awaited_by == expected_awaited_by
                    ]
                    self.assertGreaterEqual(len(tasks_with_awaited), NUM_TASKS)

                    # Final task should be from echo client spam (not on Windows)
                    if sys.platform != "win32":
                        self.assertEqual(
                            tasks_with_awaited[-1].awaited_by,
                            entries[-1].awaited_by,
                        )
            finally:
                _cleanup_sockets(client_socket, server_socket)
| |
    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_self_trace(self):
        stack_trace = get_stack_trace(os.getpid())
        # IMPORTANT: nothing (not even a comment or docstring) may be
        # inserted above the call on the previous line.  The expected
        # FrameInfo below pins that call's source line as
        # co_firstlineno + 6, where co_firstlineno refers to the first
        # decorator line of this method.

        # Find the sampled frame stack belonging to the OS thread that is
        # running this very test.
        this_thread_stack = None
        for interpreter_info in stack_trace:
            for thread_info in interpreter_info.threads:
                if thread_info.thread_id == threading.get_native_id():
                    this_thread_stack = thread_info.frame_info
                    break
            if this_thread_stack:
                break

        self.assertIsNotNone(this_thread_stack)
        # The two innermost frames must be the sampling helper and this
        # test method, at the exact expected file/line/function triples.
        self.assertEqual(
            this_thread_stack[:2],
            [
                FrameInfo(
                    [
                        __file__,
                        get_stack_trace.__code__.co_firstlineno + 4,
                        "get_stack_trace",
                    ]
                ),
                FrameInfo(
                    [
                        __file__,
                        self.test_self_trace.__code__.co_firstlineno + 6,
                        "TestGetStackTrace.test_self_trace",
                    ]
                ),
            ],
        )
| |
    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    @requires_subinterpreters
    def test_subinterpreter_stack_trace(self):
        """Sample a child process that runs code in both the main
        interpreter and a subinterpreter, and check that frames from both
        interpreters appear in the remote stack trace.
        """
        port = find_unused_port()

        import pickle

        # Code executed inside the subinterpreter: connect back to the
        # test, announce readiness, then block so it can be sampled.
        subinterp_code = textwrap.dedent(f"""
            import socket
            import time

            def sub_worker():
                def nested_func():
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                    sock.connect(('localhost', {port}))
                    sock.sendall(b"ready:sub\\n")
                    time.sleep(10_000)
                nested_func()

            sub_worker()
            """).strip()

        # Pickle the payload so it embeds cleanly (as a repr'd bytes
        # literal) inside the parent script's f-string below.
        pickled_code = pickle.dumps(subinterp_code)

        script = textwrap.dedent(
            f"""
            from concurrent import interpreters
            import time
            import sys
            import socket
            import threading

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            def main_worker():
                sock.sendall(b"ready:main\\n")
                time.sleep(10_000)

            def run_subinterp():
                subinterp = interpreters.create()
                import pickle
                pickled_code = {pickled_code!r}
                subinterp_code = pickle.loads(pickled_code)
                subinterp.exec(subinterp_code)

            sub_thread = threading.Thread(target=run_subinterp)
            sub_thread.start()

            main_thread = threading.Thread(target=main_worker)
            main_thread.start()

            main_thread.join()
            sub_thread.join()
            """
        )

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_sockets = []

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    # Accept connections from both main and subinterpreter
                    responses = set()
                    while len(responses) < 2:
                        try:
                            client_socket, _ = server_socket.accept()
                            client_sockets.append(client_socket)
                            response = client_socket.recv(1024)
                            if b"ready:main" in response:
                                responses.add("main")
                            if b"ready:sub" in response:
                                responses.add("sub")
                        except socket.timeout:
                            break

                    server_socket.close()
                    server_socket = None

                    try:
                        stack_trace = get_stack_trace(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Verify we have at least one interpreter
                    self.assertGreaterEqual(len(stack_trace), 1)

                    # Look for main interpreter (ID 0) and subinterpreter (ID > 0)
                    main_interp = None
                    sub_interp = None
                    for interpreter_info in stack_trace:
                        if interpreter_info.interpreter_id == 0:
                            main_interp = interpreter_info
                        elif interpreter_info.interpreter_id > 0:
                            sub_interp = interpreter_info

                    self.assertIsNotNone(
                        main_interp, "Main interpreter should be present"
                    )

                    # Check main interpreter has expected stack trace
                    main_found = self._find_frame_in_trace(
                        [main_interp], lambda f: f.funcname == "main_worker"
                    )
                    self.assertIsNotNone(
                        main_found,
                        "Main interpreter should have main_worker in stack",
                    )

                    # If subinterpreter is present, check its stack trace
                    if sub_interp:
                        sub_found = self._find_frame_in_trace(
                            [sub_interp],
                            lambda f: f.funcname
                            in ("sub_worker", "nested_func"),
                        )
                        self.assertIsNotNone(
                            sub_found,
                            "Subinterpreter should have sub_worker or nested_func in stack",
                        )
            finally:
                _cleanup_sockets(*client_sockets, server_socket)
| |
    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    @requires_subinterpreters
    def test_multiple_subinterpreters_with_threads(self):
        """Sample a child with two subinterpreters, each running two
        threads, and verify interpreters, threads and expected function
        names are all reported in the remote stack trace.
        """
        port = find_unused_port()

        import pickle

        # First subinterpreter: two worker threads, each connecting back
        # with a unique tag and then blocking so it can be sampled.
        subinterp1_code = textwrap.dedent(f"""
            import socket
            import time
            import threading

            def worker1():
                def nested_func():
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                    sock.connect(('localhost', {port}))
                    sock.sendall(b"ready:sub1-t1\\n")
                    time.sleep(10_000)
                nested_func()

            def worker2():
                def nested_func():
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                    sock.connect(('localhost', {port}))
                    sock.sendall(b"ready:sub1-t2\\n")
                    time.sleep(10_000)
                nested_func()

            t1 = threading.Thread(target=worker1)
            t2 = threading.Thread(target=worker2)
            t1.start()
            t2.start()
            t1.join()
            t2.join()
            """).strip()

        # Second subinterpreter: identical structure, different tags.
        subinterp2_code = textwrap.dedent(f"""
            import socket
            import time
            import threading

            def worker1():
                def nested_func():
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                    sock.connect(('localhost', {port}))
                    sock.sendall(b"ready:sub2-t1\\n")
                    time.sleep(10_000)
                nested_func()

            def worker2():
                def nested_func():
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                    sock.connect(('localhost', {port}))
                    sock.sendall(b"ready:sub2-t2\\n")
                    time.sleep(10_000)
                nested_func()

            t1 = threading.Thread(target=worker1)
            t2 = threading.Thread(target=worker2)
            t1.start()
            t2.start()
            t1.join()
            t2.join()
            """).strip()

        # Pickle the payloads so they embed cleanly in the parent
        # script's f-string as repr'd bytes literals.
        pickled_code1 = pickle.dumps(subinterp1_code)
        pickled_code2 = pickle.dumps(subinterp2_code)

        script = textwrap.dedent(
            f"""
            from concurrent import interpreters
            import time
            import sys
            import socket
            import threading

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            def main_worker():
                sock.sendall(b"ready:main\\n")
                time.sleep(10_000)

            def run_subinterp1():
                subinterp = interpreters.create()
                import pickle
                pickled_code = {pickled_code1!r}
                subinterp_code = pickle.loads(pickled_code)
                subinterp.exec(subinterp_code)

            def run_subinterp2():
                subinterp = interpreters.create()
                import pickle
                pickled_code = {pickled_code2!r}
                subinterp_code = pickle.loads(pickled_code)
                subinterp.exec(subinterp_code)

            sub1_thread = threading.Thread(target=run_subinterp1)
            sub2_thread = threading.Thread(target=run_subinterp2)
            sub1_thread.start()
            sub2_thread.start()

            main_thread = threading.Thread(target=main_worker)
            main_thread.start()

            main_thread.join()
            sub1_thread.join()
            sub2_thread.join()
            """
        )

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port, backlog=5)
            script_name = _make_test_script(script_dir, "script", script)
            client_sockets = []

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    # Accept connections from main and all subinterpreter threads
                    expected_responses = {
                        "ready:main",
                        "ready:sub1-t1",
                        "ready:sub1-t2",
                        "ready:sub2-t1",
                        "ready:sub2-t2",
                    }
                    responses = set()

                    while len(responses) < 5:
                        try:
                            client_socket, _ = server_socket.accept()
                            client_sockets.append(client_socket)
                            response = client_socket.recv(1024)
                            response_str = response.decode().strip()
                            if response_str in expected_responses:
                                responses.add(response_str)
                        except socket.timeout:
                            break

                    server_socket.close()
                    server_socket = None

                    try:
                        stack_trace = get_stack_trace(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Verify we have multiple interpreters
                    self.assertGreaterEqual(len(stack_trace), 2)

                    # Count interpreters by ID
                    interpreter_ids = {
                        interp.interpreter_id for interp in stack_trace
                    }
                    self.assertIn(
                        0,
                        interpreter_ids,
                        "Main interpreter should be present",
                    )
                    self.assertGreaterEqual(len(interpreter_ids), 3)

                    # Count total threads
                    total_threads = sum(
                        len(interp.threads) for interp in stack_trace
                    )
                    self.assertGreaterEqual(total_threads, 5)

                    # Look for expected function names
                    all_funcnames = set()
                    for interpreter_info in stack_trace:
                        for thread_info in interpreter_info.threads:
                            for frame in thread_info.frame_info:
                                all_funcnames.add(frame.funcname)

                    expected_funcs = {
                        "main_worker",
                        "worker1",
                        "worker2",
                        "nested_func",
                    }
                    found_funcs = expected_funcs.intersection(all_funcnames)
                    self.assertGreater(len(found_funcs), 0)
            finally:
                _cleanup_sockets(*client_sockets, server_socket)
| |
    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    @requires_gil_enabled("Free threaded builds don't have an 'active thread'")
    def test_only_active_thread(self):
        """With only_active_thread=True the unwinder must report exactly
        one thread, and that thread must also appear in the all-threads
        sample.
        """
        port = find_unused_port()
        # Child: three parked worker threads plus a main thread doing
        # busy counting, so the main thread is the plausible GIL holder
        # whenever it is sampled.
        script = textwrap.dedent(
            f"""\
            import time, sys, socket, threading

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            def worker_thread(name, barrier, ready_event):
                barrier.wait()
                ready_event.wait()
                time.sleep(10_000)

            def main_work():
                sock.sendall(b"working\\n")
                count = 0
                while count < 100000000:
                    count += 1
                    if count % 10000000 == 0:
                        pass
                sock.sendall(b"done\\n")

            num_threads = 3
            barrier = threading.Barrier(num_threads + 1)
            ready_event = threading.Event()

            threads = []
            for i in range(num_threads):
                t = threading.Thread(target=worker_thread, args=(f"Worker-{{i}}", barrier, ready_event))
                t.start()
                threads.append(t)

            barrier.wait()
            sock.sendall(b"ready\\n")
            ready_event.set()
            main_work()
            """
        )

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None

                    # Wait for ready and working signals
                    _wait_for_signal(client_socket, [b"ready", b"working"])

                    try:
                        # Get stack trace with all threads
                        unwinder_all = RemoteUnwinder(p.pid, all_threads=True)
                        for _ in range(MAX_TRIES):
                            all_traces = unwinder_all.get_stack_trace()
                            # lineno > 12 filters for samples taken past the
                            # initial sendall in main_work, i.e. inside the
                            # busy counting loop (line numbers are relative
                            # to the generated script above).
                            found = self._find_frame_in_trace(
                                all_traces,
                                lambda f: f.funcname == "main_work"
                                and f.lineno > 12,
                            )
                            if found:
                                break
                            time.sleep(0.1)
                        else:
                            self.fail(
                                "Main thread did not start its busy work on time"
                            )

                        # Get stack trace with only GIL holder
                        unwinder_gil = RemoteUnwinder(
                            p.pid, only_active_thread=True
                        )
                        gil_traces = unwinder_gil.get_stack_trace()
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Count threads
                    total_threads = sum(
                        len(interp.threads) for interp in all_traces
                    )
                    self.assertGreater(total_threads, 1)

                    total_gil_threads = sum(
                        len(interp.threads) for interp in gil_traces
                    )
                    self.assertEqual(total_gil_threads, 1)

                    # Get the GIL holder thread ID
                    gil_thread_id = None
                    for interpreter_info in gil_traces:
                        if interpreter_info.threads:
                            gil_thread_id = interpreter_info.threads[
                                0
                            ].thread_id
                            break

                    # Get all thread IDs
                    all_thread_ids = []
                    for interpreter_info in all_traces:
                        for thread_info in interpreter_info.threads:
                            all_thread_ids.append(thread_info.thread_id)

                    self.assertIn(gil_thread_id, all_thread_ids)
            finally:
                _cleanup_sockets(client_socket, server_socket)
| |
| |
class TestUnsupportedPlatformHandling(unittest.TestCase):
    """Verify the error raised when remote unwinding is not supported."""

    @unittest.skipIf(
        sys.platform in ("linux", "darwin", "win32"),
        "Test only runs on unsupported platforms (not Linux, macOS, or Windows)",
    )
    @unittest.skipIf(
        sys.platform == "android", "Android raises Linux-specific exception"
    )
    def test_unsupported_platform_error(self):
        # Constructing an unwinder on an unsupported OS must fail loudly
        # with a RuntimeError naming the PyRuntime section.
        expected_fragment = (
            "Reading the PyRuntime section is not supported on this platform"
        )
        with self.assertRaises(RuntimeError) as caught:
            RemoteUnwinder(os.getpid())

        self.assertIn(expected_fragment, str(caught.exception))
| |
| |
class TestDetectionOfThreadStatus(RemoteInspectionTestBase):
    """Tests for remote detection of per-thread status flags
    (THREAD_STATUS_HAS_GIL / THREAD_STATUS_ON_CPU) across the profiling
    modes.
    """

    def _run_thread_status_test(self, mode, check_condition):
        """
        Common pattern for thread status detection tests.

        Args:
            mode: Profiling mode (PROFILING_MODE_CPU, PROFILING_MODE_GIL, etc.)
            check_condition: Function(statuses, sleeper_tid, busy_tid) -> bool
        """
        port = find_unused_port()
        # Child process: one thread that sleeps (expected off-CPU, no GIL)
        # and one that spins with short naps.  Each thread reports its
        # native thread id back over the socket before settling in.
        script = textwrap.dedent(
            f"""\
            import time, sys, socket, threading
            import os

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            def sleeper():
                tid = threading.get_native_id()
                sock.sendall(f'ready:sleeper:{{tid}}\\n'.encode())
                time.sleep(10000)

            def busy():
                tid = threading.get_native_id()
                sock.sendall(f'ready:busy:{{tid}}\\n'.encode())
                x = 0
                while True:
                    x = x + 1
                    time.sleep(0.5)

            t1 = threading.Thread(target=sleeper)
            t2 = threading.Thread(target=busy)
            t1.start()
            t2.start()
            sock.sendall(b'ready:main\\n')
            t1.join()
            t2.join()
            sock.close()
            """
        )

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port)
            script_name = _make_test_script(
                script_dir, "thread_status_script", script
            )
            client_socket = None

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None

                    # Wait for all ready signals and parse TIDs
                    response = _wait_for_signal(
                        client_socket,
                        [b"ready:main", b"ready:sleeper", b"ready:busy"],
                    )

                    # Extract the native thread ids from the tagged
                    # "ready:<name>:<tid>" lines; malformed lines are
                    # ignored and caught by the assertions below.
                    sleeper_tid = None
                    busy_tid = None
                    for line in response.split(b"\n"):
                        if line.startswith(b"ready:sleeper:"):
                            try:
                                sleeper_tid = int(line.split(b":")[-1])
                            except (ValueError, IndexError):
                                pass
                        elif line.startswith(b"ready:busy:"):
                            try:
                                busy_tid = int(line.split(b":")[-1])
                            except (ValueError, IndexError):
                                pass

                    self.assertIsNotNone(
                        sleeper_tid, "Sleeper thread id not received"
                    )
                    self.assertIsNotNone(
                        busy_tid, "Busy thread id not received"
                    )

                    # Sample until we see expected thread states
                    statuses = {}
                    try:
                        unwinder = RemoteUnwinder(
                            p.pid,
                            all_threads=True,
                            mode=mode,
                            skip_non_matching_threads=False,
                        )
                        for _ in range(MAX_TRIES):
                            traces = unwinder.get_stack_trace()
                            statuses = self._get_thread_statuses(traces)

                            if check_condition(
                                statuses, sleeper_tid, busy_tid
                            ):
                                break
                            time.sleep(0.5)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Return the last sampled statuses; callers assert on
                    # them so a never-satisfied condition still fails
                    # with a useful message.
                    return statuses, sleeper_tid, busy_tid
            finally:
                _cleanup_sockets(client_socket, server_socket)

    @unittest.skipIf(
        sys.platform not in ("linux", "darwin", "win32"),
        "Test only runs on supported platforms (Linux, macOS, or Windows)",
    )
    @unittest.skipIf(
        sys.platform == "android", "Android raises Linux-specific exception"
    )
    def test_thread_status_detection(self):
        # CPU mode: the sleeping thread must be sampled off-CPU while the
        # busy thread is seen on-CPU.
        def check_cpu_status(statuses, sleeper_tid, busy_tid):
            return (
                sleeper_tid in statuses
                and busy_tid in statuses
                and not (statuses[sleeper_tid] & THREAD_STATUS_ON_CPU)
                and (statuses[busy_tid] & THREAD_STATUS_ON_CPU)
            )

        statuses, sleeper_tid, busy_tid = self._run_thread_status_test(
            PROFILING_MODE_CPU, check_cpu_status
        )

        self.assertIn(sleeper_tid, statuses)
        self.assertIn(busy_tid, statuses)
        self.assertFalse(
            statuses[sleeper_tid] & THREAD_STATUS_ON_CPU,
            "Sleeper thread should be off CPU",
        )
        self.assertTrue(
            statuses[busy_tid] & THREAD_STATUS_ON_CPU,
            "Busy thread should be on CPU",
        )

    @unittest.skipIf(
        sys.platform not in ("linux", "darwin", "win32"),
        "Test only runs on supported platforms (Linux, macOS, or Windows)",
    )
    @unittest.skipIf(
        sys.platform == "android", "Android raises Linux-specific exception"
    )
    def test_thread_status_gil_detection(self):
        # GIL mode: only the busy thread should be observed holding the GIL.
        def check_gil_status(statuses, sleeper_tid, busy_tid):
            return (
                sleeper_tid in statuses
                and busy_tid in statuses
                and not (statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL)
                and (statuses[busy_tid] & THREAD_STATUS_HAS_GIL)
            )

        statuses, sleeper_tid, busy_tid = self._run_thread_status_test(
            PROFILING_MODE_GIL, check_gil_status
        )

        self.assertIn(sleeper_tid, statuses)
        self.assertIn(busy_tid, statuses)
        self.assertFalse(
            statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL,
            "Sleeper thread should not have GIL",
        )
        self.assertTrue(
            statuses[busy_tid] & THREAD_STATUS_HAS_GIL,
            "Busy thread should have GIL",
        )

    @unittest.skipIf(
        sys.platform not in ("linux", "darwin", "win32"),
        "Test only runs on supported platforms (Linux, macOS, or Windows)",
    )
    @unittest.skipIf(
        sys.platform == "android", "Android raises Linux-specific exception"
    )
    def test_thread_status_all_mode_detection(self):
        # ALL mode must report GIL and CPU flags simultaneously; this
        # variant uses one connection per thread instead of a shared
        # socket, so the child script differs from the common helper.
        port = find_unused_port()
        script = textwrap.dedent(
            f"""\
            import socket
            import threading
            import time
            import sys

            def sleeper_thread():
                conn = socket.create_connection(("localhost", {port}))
                conn.sendall(b"sleeper:" + str(threading.get_native_id()).encode())
                while True:
                    time.sleep(1)

            def busy_thread():
                conn = socket.create_connection(("localhost", {port}))
                conn.sendall(b"busy:" + str(threading.get_native_id()).encode())
                while True:
                    sum(range(100000))

            t1 = threading.Thread(target=sleeper_thread)
            t2 = threading.Thread(target=busy_thread)
            t1.start()
            t2.start()
            t1.join()
            t2.join()
            """
        )

        with os_helper.temp_dir() as tmp_dir:
            script_file = make_script(tmp_dir, "script", script)
            server_socket = _create_server_socket(port, backlog=2)
            client_sockets = []

            try:
                with _managed_subprocess(
                    [sys.executable, script_file],
                ) as p:
                    sleeper_tid = None
                    busy_tid = None

                    # Receive thread IDs from the child process
                    for _ in range(2):
                        client_socket, _ = server_socket.accept()
                        client_sockets.append(client_socket)
                        line = client_socket.recv(1024)
                        if line:
                            if line.startswith(b"sleeper:"):
                                try:
                                    sleeper_tid = int(line.split(b":")[-1])
                                except (ValueError, IndexError):
                                    pass
                            elif line.startswith(b"busy:"):
                                try:
                                    busy_tid = int(line.split(b":")[-1])
                                except (ValueError, IndexError):
                                    pass

                    server_socket.close()
                    server_socket = None

                    statuses = {}
                    try:
                        unwinder = RemoteUnwinder(
                            p.pid,
                            all_threads=True,
                            mode=PROFILING_MODE_ALL,
                            skip_non_matching_threads=False,
                        )
                        for _ in range(MAX_TRIES):
                            traces = unwinder.get_stack_trace()
                            statuses = self._get_thread_statuses(traces)

                            # Check ALL mode provides both GIL and CPU info
                            if (
                                sleeper_tid in statuses
                                and busy_tid in statuses
                                and not (
                                    statuses[sleeper_tid]
                                    & THREAD_STATUS_ON_CPU
                                )
                                and not (
                                    statuses[sleeper_tid]
                                    & THREAD_STATUS_HAS_GIL
                                )
                                and (statuses[busy_tid] & THREAD_STATUS_ON_CPU)
                                and (
                                    statuses[busy_tid] & THREAD_STATUS_HAS_GIL
                                )
                            ):
                                break
                            time.sleep(0.5)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    self.assertIsNotNone(
                        sleeper_tid, "Sleeper thread id not received"
                    )
                    self.assertIsNotNone(
                        busy_tid, "Busy thread id not received"
                    )
                    self.assertIn(sleeper_tid, statuses)
                    self.assertIn(busy_tid, statuses)

                    # Sleeper: off CPU, no GIL
                    self.assertFalse(
                        statuses[sleeper_tid] & THREAD_STATUS_ON_CPU,
                        "Sleeper should be off CPU",
                    )
                    self.assertFalse(
                        statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL,
                        "Sleeper should not have GIL",
                    )

                    # Busy: on CPU, has GIL
                    self.assertTrue(
                        statuses[busy_tid] & THREAD_STATUS_ON_CPU,
                        "Busy should be on CPU",
                    )
                    self.assertTrue(
                        statuses[busy_tid] & THREAD_STATUS_HAS_GIL,
                        "Busy should have GIL",
                    )
            finally:
                _cleanup_sockets(*client_sockets, server_socket)
| |
| |
| class TestFrameCaching(RemoteInspectionTestBase): |
| """Test that frame caching produces correct results. |
| |
| Uses socket-based synchronization for deterministic testing. |
| All tests verify cache reuse via object identity checks (assertIs). |
| """ |
| |
    @contextmanager
    def _target_process(self, script_body):
        """Context manager for running a target process with socket sync.

        Yields (process, client_socket, make_unwinder) where
        make_unwinder(cache_frames=True) builds a RemoteUnwinder for the
        child.  The child script is the given body prefixed with a
        bootstrap that connects a socket named ``sock`` back to the test.
        """
        port = find_unused_port()
        script = f"""\
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(('localhost', {port}))
{textwrap.dedent(script_body)}
"""

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    # The listening socket is no longer needed once the
                    # single client connection is established.
                    server_socket.close()
                    server_socket = None

                    def make_unwinder(cache_frames=True):
                        # Factory so each test can opt in/out of caching.
                        return RemoteUnwinder(
                            p.pid, all_threads=True, cache_frames=cache_frames
                        )

                    yield p, client_socket, make_unwinder

            except PermissionError:
                # Propagated out of the with-body (e.g. from the unwinder
                # factory) when we may not attach to the child process.
                self.skipTest(
                    "Insufficient permissions to read the stack trace"
                )
            finally:
                _cleanup_sockets(client_socket, server_socket)
| |
| def _get_frames_with_retry(self, unwinder, required_funcs): |
| """Get frames containing required_funcs, with retry for transient errors.""" |
| for _ in range(MAX_TRIES): |
| try: |
| traces = unwinder.get_stack_trace() |
| for interp in traces: |
| for thread in interp.threads: |
| funcs = {f.funcname for f in thread.frame_info} |
| if required_funcs.issubset(funcs): |
| return thread.frame_info |
| except RuntimeError as e: |
| if _is_retriable_error(e): |
| pass |
| else: |
| raise |
| time.sleep(0.1) |
| return None |
| |
| def _sample_frames( |
| self, |
| client_socket, |
| unwinder, |
| wait_signal, |
| send_ack, |
| required_funcs, |
| expected_frames=1, |
| ): |
| """Wait for signal, sample frames with retry until required funcs present, send ack.""" |
| _wait_for_signal(client_socket, wait_signal) |
| frames = None |
| for _ in range(MAX_TRIES): |
| frames = self._get_frames_with_retry(unwinder, required_funcs) |
| if frames and len(frames) >= expected_frames: |
| break |
| time.sleep(0.1) |
| client_socket.sendall(send_ack) |
| return frames |
| |
| @skip_if_not_supported |
| @unittest.skipIf( |
| sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, |
| "Test only runs on Linux with process_vm_readv support", |
| ) |
| def test_cache_hit_same_stack(self): |
| """Test that consecutive samples reuse cached parent frame objects. |
| |
| The current frame (index 0) is always re-read from memory to get |
| updated line numbers, so it may be a different object. Parent frames |
| (index 1+) should be identical objects from cache. |
| """ |
| script_body = """\ |
| def level3(): |
| sock.sendall(b"sync1") |
| sock.recv(16) |
| sock.sendall(b"sync2") |
| sock.recv(16) |
| sock.sendall(b"sync3") |
| sock.recv(16) |
| |
| def level2(): |
| level3() |
| |
| def level1(): |
| level2() |
| |
| level1() |
| """ |
| |
| with self._target_process(script_body) as ( |
| p, |
| client_socket, |
| make_unwinder, |
| ): |
| unwinder = make_unwinder(cache_frames=True) |
| expected = {"level1", "level2", "level3"} |
| |
| frames1 = self._sample_frames( |
| client_socket, unwinder, b"sync1", b"ack", expected |
| ) |
| frames2 = self._sample_frames( |
| client_socket, unwinder, b"sync2", b"ack", expected |
| ) |
| frames3 = self._sample_frames( |
| client_socket, unwinder, b"sync3", b"done", expected |
| ) |
| |
| self.assertIsNotNone(frames1) |
| self.assertIsNotNone(frames2) |
| self.assertIsNotNone(frames3) |
| self.assertEqual(len(frames1), len(frames2)) |
| self.assertEqual(len(frames2), len(frames3)) |
| |
| # Current frame (index 0) is always re-read, so check value equality |
| self.assertEqual(frames1[0].funcname, frames2[0].funcname) |
| self.assertEqual(frames2[0].funcname, frames3[0].funcname) |
| |
| # Parent frames (index 1+) must be identical objects (cache reuse) |
| for i in range(1, len(frames1)): |
| f1, f2, f3 = frames1[i], frames2[i], frames3[i] |
| self.assertIs( |
| f1, f2, f"Frame {i}: samples 1-2 must be same object" |
| ) |
| self.assertIs( |
| f2, f3, f"Frame {i}: samples 2-3 must be same object" |
| ) |
| |
| @skip_if_not_supported |
| @unittest.skipIf( |
| sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, |
| "Test only runs on Linux with process_vm_readv support", |
| ) |
| def test_line_number_updates_in_same_frame(self): |
| """Test that line numbers are correctly updated when execution moves within a function. |
| |
| When the profiler samples at different points within the same function, |
| it must report the correct line number for each sample, not stale cached values. |
| """ |
| script_body = """\ |
| def outer(): |
| inner() |
| |
| def inner(): |
| sock.sendall(b"line_a"); sock.recv(16) |
| sock.sendall(b"line_b"); sock.recv(16) |
| sock.sendall(b"line_c"); sock.recv(16) |
| sock.sendall(b"line_d"); sock.recv(16) |
| |
| outer() |
| """ |
| |
| with self._target_process(script_body) as ( |
| p, |
| client_socket, |
| make_unwinder, |
| ): |
| unwinder = make_unwinder(cache_frames=True) |
| |
| frames_a = self._sample_frames( |
| client_socket, unwinder, b"line_a", b"ack", {"inner"} |
| ) |
| frames_b = self._sample_frames( |
| client_socket, unwinder, b"line_b", b"ack", {"inner"} |
| ) |
| frames_c = self._sample_frames( |
| client_socket, unwinder, b"line_c", b"ack", {"inner"} |
| ) |
| frames_d = self._sample_frames( |
| client_socket, unwinder, b"line_d", b"done", {"inner"} |
| ) |
| |
| self.assertIsNotNone(frames_a) |
| self.assertIsNotNone(frames_b) |
| self.assertIsNotNone(frames_c) |
| self.assertIsNotNone(frames_d) |
| |
| # Get the 'inner' frame from each sample (should be index 0) |
| inner_a = frames_a[0] |
| inner_b = frames_b[0] |
| inner_c = frames_c[0] |
| inner_d = frames_d[0] |
| |
| self.assertEqual(inner_a.funcname, "inner") |
| self.assertEqual(inner_b.funcname, "inner") |
| self.assertEqual(inner_c.funcname, "inner") |
| self.assertEqual(inner_d.funcname, "inner") |
| |
| # Line numbers must be different and increasing (execution moves forward) |
| self.assertLess( |
| inner_a.lineno, inner_b.lineno, "Line B should be after line A" |
| ) |
| self.assertLess( |
| inner_b.lineno, inner_c.lineno, "Line C should be after line B" |
| ) |
| self.assertLess( |
| inner_c.lineno, inner_d.lineno, "Line D should be after line C" |
| ) |
| |
| @skip_if_not_supported |
| @unittest.skipIf( |
| sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, |
| "Test only runs on Linux with process_vm_readv support", |
| ) |
| def test_cache_invalidation_on_return(self): |
| """Test cache invalidation when stack shrinks (function returns).""" |
| script_body = """\ |
| def inner(): |
| sock.sendall(b"at_inner") |
| sock.recv(16) |
| |
| def outer(): |
| inner() |
| sock.sendall(b"at_outer") |
| sock.recv(16) |
| |
| outer() |
| """ |
| |
| with self._target_process(script_body) as ( |
| p, |
| client_socket, |
| make_unwinder, |
| ): |
| unwinder = make_unwinder(cache_frames=True) |
| |
| frames_deep = self._sample_frames( |
| client_socket, |
| unwinder, |
| b"at_inner", |
| b"ack", |
| {"inner", "outer"}, |
| ) |
| frames_shallow = self._sample_frames( |
| client_socket, unwinder, b"at_outer", b"done", {"outer"} |
| ) |
| |
| self.assertIsNotNone(frames_deep) |
| self.assertIsNotNone(frames_shallow) |
| |
| funcs_deep = [f.funcname for f in frames_deep] |
| funcs_shallow = [f.funcname for f in frames_shallow] |
| |
| self.assertIn("inner", funcs_deep) |
| self.assertIn("outer", funcs_deep) |
| self.assertNotIn("inner", funcs_shallow) |
| self.assertIn("outer", funcs_shallow) |
| |
| @skip_if_not_supported |
| @unittest.skipIf( |
| sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, |
| "Test only runs on Linux with process_vm_readv support", |
| ) |
| def test_cache_invalidation_on_call(self): |
| """Test cache invalidation when stack grows (new function called).""" |
| script_body = """\ |
| def deeper(): |
| sock.sendall(b"at_deeper") |
| sock.recv(16) |
| |
| def middle(): |
| sock.sendall(b"at_middle") |
| sock.recv(16) |
| deeper() |
| |
| def top(): |
| middle() |
| |
| top() |
| """ |
| |
| with self._target_process(script_body) as ( |
| p, |
| client_socket, |
| make_unwinder, |
| ): |
| unwinder = make_unwinder(cache_frames=True) |
| |
| frames_before = self._sample_frames( |
| client_socket, |
| unwinder, |
| b"at_middle", |
| b"ack", |
| {"middle", "top"}, |
| ) |
| frames_after = self._sample_frames( |
| client_socket, |
| unwinder, |
| b"at_deeper", |
| b"done", |
| {"deeper", "middle", "top"}, |
| ) |
| |
| self.assertIsNotNone(frames_before) |
| self.assertIsNotNone(frames_after) |
| |
| funcs_before = [f.funcname for f in frames_before] |
| funcs_after = [f.funcname for f in frames_after] |
| |
| self.assertIn("middle", funcs_before) |
| self.assertIn("top", funcs_before) |
| self.assertNotIn("deeper", funcs_before) |
| |
| self.assertIn("deeper", funcs_after) |
| self.assertIn("middle", funcs_after) |
| self.assertIn("top", funcs_after) |
| |
| @skip_if_not_supported |
| @unittest.skipIf( |
| sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, |
| "Test only runs on Linux with process_vm_readv support", |
| ) |
| def test_partial_stack_reuse(self): |
| """Test that unchanged bottom frames are reused when top changes (A→B→C to A→B→D).""" |
| script_body = """\ |
| def func_c(): |
| sock.sendall(b"at_c") |
| sock.recv(16) |
| |
| def func_d(): |
| sock.sendall(b"at_d") |
| sock.recv(16) |
| |
| def func_b(): |
| func_c() |
| func_d() |
| |
| def func_a(): |
| func_b() |
| |
| func_a() |
| """ |
| |
| with self._target_process(script_body) as ( |
| p, |
| client_socket, |
| make_unwinder, |
| ): |
| unwinder = make_unwinder(cache_frames=True) |
| |
| # Sample at C: stack is A→B→C |
| frames_c = self._sample_frames( |
| client_socket, |
| unwinder, |
| b"at_c", |
| b"ack", |
| {"func_a", "func_b", "func_c"}, |
| ) |
| # Sample at D: stack is A→B→D (C returned, D called) |
| frames_d = self._sample_frames( |
| client_socket, |
| unwinder, |
| b"at_d", |
| b"done", |
| {"func_a", "func_b", "func_d"}, |
| ) |
| |
| self.assertIsNotNone(frames_c) |
| self.assertIsNotNone(frames_d) |
| |
| # Find func_a and func_b frames in both samples |
| def find_frame(frames, funcname): |
| for f in frames: |
| if f.funcname == funcname: |
| return f |
| return None |
| |
| frame_a_in_c = find_frame(frames_c, "func_a") |
| frame_b_in_c = find_frame(frames_c, "func_b") |
| frame_a_in_d = find_frame(frames_d, "func_a") |
| frame_b_in_d = find_frame(frames_d, "func_b") |
| |
| self.assertIsNotNone(frame_a_in_c) |
| self.assertIsNotNone(frame_b_in_c) |
| self.assertIsNotNone(frame_a_in_d) |
| self.assertIsNotNone(frame_b_in_d) |
| |
| # The bottom frames (A, B) should be the SAME objects (cache reuse) |
| self.assertIs( |
| frame_a_in_c, |
| frame_a_in_d, |
| "func_a frame should be reused from cache", |
| ) |
| self.assertIs( |
| frame_b_in_c, |
| frame_b_in_d, |
| "func_b frame should be reused from cache", |
| ) |
| |
| @skip_if_not_supported |
| @unittest.skipIf( |
| sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, |
| "Test only runs on Linux with process_vm_readv support", |
| ) |
| def test_recursive_frames(self): |
| """Test caching with same function appearing multiple times (recursion).""" |
| script_body = """\ |
| def recurse(n): |
| if n <= 0: |
| sock.sendall(b"sync1") |
| sock.recv(16) |
| sock.sendall(b"sync2") |
| sock.recv(16) |
| else: |
| recurse(n - 1) |
| |
| recurse(5) |
| """ |
| |
| with self._target_process(script_body) as ( |
| p, |
| client_socket, |
| make_unwinder, |
| ): |
| unwinder = make_unwinder(cache_frames=True) |
| |
| frames1 = self._sample_frames( |
| client_socket, unwinder, b"sync1", b"ack", {"recurse"} |
| ) |
| frames2 = self._sample_frames( |
| client_socket, unwinder, b"sync2", b"done", {"recurse"} |
| ) |
| |
| self.assertIsNotNone(frames1) |
| self.assertIsNotNone(frames2) |
| |
| # Should have multiple "recurse" frames (6 total: recurse(5) down to recurse(0)) |
| recurse_count = sum(1 for f in frames1 if f.funcname == "recurse") |
| self.assertEqual(recurse_count, 6, "Should have 6 recursive frames") |
| |
| self.assertEqual(len(frames1), len(frames2)) |
| |
| # Current frame (index 0) is re-read, check value equality |
| self.assertEqual(frames1[0].funcname, frames2[0].funcname) |
| |
| # Parent frames (index 1+) should be identical objects (cache reuse) |
| for i in range(1, len(frames1)): |
| self.assertIs( |
| frames1[i], |
| frames2[i], |
| f"Frame {i}: recursive frames must be same object", |
| ) |
| |
| @skip_if_not_supported |
| @unittest.skipIf( |
| sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, |
| "Test only runs on Linux with process_vm_readv support", |
| ) |
| def test_cache_vs_no_cache_equivalence(self): |
| """Test that cache_frames=True and cache_frames=False produce equivalent results.""" |
| script_body = """\ |
| def level3(): |
| sock.sendall(b"ready"); sock.recv(16) |
| |
| def level2(): |
| level3() |
| |
| def level1(): |
| level2() |
| |
| level1() |
| """ |
| |
| with self._target_process(script_body) as ( |
| p, |
| client_socket, |
| make_unwinder, |
| ): |
| _wait_for_signal(client_socket, b"ready") |
| |
| # Sample with cache |
| unwinder_cache = make_unwinder(cache_frames=True) |
| frames_cached = self._get_frames_with_retry( |
| unwinder_cache, {"level1", "level2", "level3"} |
| ) |
| |
| # Sample without cache |
| unwinder_no_cache = make_unwinder(cache_frames=False) |
| frames_no_cache = self._get_frames_with_retry( |
| unwinder_no_cache, {"level1", "level2", "level3"} |
| ) |
| |
| client_socket.sendall(b"done") |
| |
| self.assertIsNotNone(frames_cached) |
| self.assertIsNotNone(frames_no_cache) |
| |
| # Same number of frames |
| self.assertEqual(len(frames_cached), len(frames_no_cache)) |
| |
| # Same function names in same order |
| funcs_cached = [f.funcname for f in frames_cached] |
| funcs_no_cache = [f.funcname for f in frames_no_cache] |
| self.assertEqual(funcs_cached, funcs_no_cache) |
| |
| # Same line numbers |
| lines_cached = [f.lineno for f in frames_cached] |
| lines_no_cache = [f.lineno for f in frames_no_cache] |
| self.assertEqual(lines_cached, lines_no_cache) |
| |
    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_cache_per_thread_isolation(self):
        """Test that frame cache is per-thread and cache invalidation works independently.

        The target runs two threads with parallel call trees (foo1/bar1/baz1
        vs foo2/bar2/baz2).  A shared lock in the target serializes sync
        points, so this test handles exactly one message at a time: snapshot
        the matching thread's stack, then release the target with one byte.
        Phase 1 (baz) checks no cross-thread contamination; phase 2 (blech)
        checks each thread's cache was invalidated independently after its
        bar/baz frames returned.
        """
        script_body = """\
import threading

lock = threading.Lock()

def sync(msg):
    with lock:
        sock.sendall(msg + b"\\n")
        sock.recv(1)

# Thread 1 functions
def baz1():
    sync(b"t1:baz1")

def bar1():
    baz1()

def blech1():
    sync(b"t1:blech1")

def foo1():
    bar1()  # Goes down to baz1, syncs
    blech1()  # Returns up, goes down to blech1, syncs

# Thread 2 functions
def baz2():
    sync(b"t2:baz2")

def bar2():
    baz2()

def blech2():
    sync(b"t2:blech2")

def foo2():
    bar2()  # Goes down to baz2, syncs
    blech2()  # Returns up, goes down to blech2, syncs

t1 = threading.Thread(target=foo1)
t2 = threading.Thread(target=foo2)
t1.start()
t2.start()
t1.join()
t2.join()
"""

        with self._target_process(script_body) as (
            p,
            client_socket,
            make_unwinder,
        ):
            unwinder = make_unwinder(cache_frames=True)
            # Accumulates bytes received so far; messages are newline-framed.
            buffer = b""

            def recv_msg():
                """Receive a single newline-terminated message from socket.

                Returns the message without the trailing newline, or None
                if the peer closed the connection mid-message.
                """
                nonlocal buffer
                while b"\n" not in buffer:
                    chunk = client_socket.recv(256)
                    if not chunk:
                        return None
                    buffer += chunk
                msg, buffer = buffer.split(b"\n", 1)
                return msg

            def get_thread_frames(target_funcs):
                """Get frames for thread matching target functions.

                Samples repeatedly (bounded by both SHORT_TIMEOUT and a
                hard cap of 5 attempts) and returns the function-name list
                of the first thread whose stack contains any of
                *target_funcs*, or None if no attempt matched.
                """
                retries = 0
                for _ in busy_retry(SHORT_TIMEOUT):
                    if retries >= 5:
                        break
                    retries += 1
                    # On Windows, ReadProcessMemory can fail with OSError
                    # (WinError 299) when frame pointers are in flux
                    with contextlib.suppress(RuntimeError, OSError):
                        traces = unwinder.get_stack_trace()
                        for interp in traces:
                            for thread in interp.threads:
                                funcs = [f.funcname for f in thread.frame_info]
                                if any(f in funcs for f in target_funcs):
                                    return funcs
                return None

            # Track results for each sync point
            results = {}

            # Process 4 sync points: baz1, baz2, blech1, blech2
            # With the lock, threads are serialized - handle one at a time
            for _ in range(4):
                msg = recv_msg()
                self.assertIsNotNone(msg, "Expected message from subprocess")

                # Determine which thread/function and take snapshot
                if msg == b"t1:baz1":
                    funcs = get_thread_frames(["baz1", "bar1", "foo1"])
                    self.assertIsNotNone(funcs, "Thread 1 not found at baz1")
                    results["t1:baz1"] = funcs
                elif msg == b"t2:baz2":
                    funcs = get_thread_frames(["baz2", "bar2", "foo2"])
                    self.assertIsNotNone(funcs, "Thread 2 not found at baz2")
                    results["t2:baz2"] = funcs
                elif msg == b"t1:blech1":
                    funcs = get_thread_frames(["blech1", "foo1"])
                    self.assertIsNotNone(funcs, "Thread 1 not found at blech1")
                    results["t1:blech1"] = funcs
                elif msg == b"t2:blech2":
                    funcs = get_thread_frames(["blech2", "foo2"])
                    self.assertIsNotNone(funcs, "Thread 2 not found at blech2")
                    results["t2:blech2"] = funcs

                # Release thread to continue
                client_socket.sendall(b"k")

            # Validate Phase 1: baz snapshots
            t1_baz = results.get("t1:baz1")
            t2_baz = results.get("t2:baz2")
            self.assertIsNotNone(t1_baz, "Missing t1:baz1 snapshot")
            self.assertIsNotNone(t2_baz, "Missing t2:baz2 snapshot")

            # Thread 1 at baz1: should have foo1->bar1->baz1
            self.assertIn("baz1", t1_baz)
            self.assertIn("bar1", t1_baz)
            self.assertIn("foo1", t1_baz)
            self.assertNotIn("blech1", t1_baz)
            # No cross-contamination
            self.assertNotIn("baz2", t1_baz)
            self.assertNotIn("bar2", t1_baz)
            self.assertNotIn("foo2", t1_baz)

            # Thread 2 at baz2: should have foo2->bar2->baz2
            self.assertIn("baz2", t2_baz)
            self.assertIn("bar2", t2_baz)
            self.assertIn("foo2", t2_baz)
            self.assertNotIn("blech2", t2_baz)
            # No cross-contamination
            self.assertNotIn("baz1", t2_baz)
            self.assertNotIn("bar1", t2_baz)
            self.assertNotIn("foo1", t2_baz)

            # Validate Phase 2: blech snapshots (cache invalidation test)
            t1_blech = results.get("t1:blech1")
            t2_blech = results.get("t2:blech2")
            self.assertIsNotNone(t1_blech, "Missing t1:blech1 snapshot")
            self.assertIsNotNone(t2_blech, "Missing t2:blech2 snapshot")

            # Thread 1 at blech1: bar1/baz1 should be GONE (cache invalidated)
            self.assertIn("blech1", t1_blech)
            self.assertIn("foo1", t1_blech)
            self.assertNotIn(
                "bar1", t1_blech, "Cache not invalidated: bar1 still present"
            )
            self.assertNotIn(
                "baz1", t1_blech, "Cache not invalidated: baz1 still present"
            )
            # No cross-contamination
            self.assertNotIn("blech2", t1_blech)

            # Thread 2 at blech2: bar2/baz2 should be GONE (cache invalidated)
            self.assertIn("blech2", t2_blech)
            self.assertIn("foo2", t2_blech)
            self.assertNotIn(
                "bar2", t2_blech, "Cache not invalidated: bar2 still present"
            )
            self.assertNotIn(
                "baz2", t2_blech, "Cache not invalidated: baz2 still present"
            )
            # No cross-contamination
            self.assertNotIn("blech1", t2_blech)
| |
| @skip_if_not_supported |
| @unittest.skipIf( |
| sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, |
| "Test only runs on Linux with process_vm_readv support", |
| ) |
| def test_new_unwinder_with_stale_last_profiled_frame(self): |
| """Test that a new unwinder returns complete stack when cache lookup misses.""" |
| script_body = """\ |
| def level4(): |
| sock.sendall(b"sync1") |
| sock.recv(16) |
| sock.sendall(b"sync2") |
| sock.recv(16) |
| |
| def level3(): |
| level4() |
| |
| def level2(): |
| level3() |
| |
| def level1(): |
| level2() |
| |
| level1() |
| """ |
| |
| with self._target_process(script_body) as ( |
| p, |
| client_socket, |
| make_unwinder, |
| ): |
| expected = {"level1", "level2", "level3", "level4"} |
| |
| # First unwinder samples - this sets last_profiled_frame in target |
| unwinder1 = make_unwinder(cache_frames=True) |
| frames1 = self._sample_frames( |
| client_socket, unwinder1, b"sync1", b"ack", expected |
| ) |
| |
| # Create NEW unwinder (empty cache) and sample |
| # The target still has last_profiled_frame set from unwinder1 |
| unwinder2 = make_unwinder(cache_frames=True) |
| frames2 = self._sample_frames( |
| client_socket, unwinder2, b"sync2", b"done", expected |
| ) |
| |
| self.assertIsNotNone(frames1) |
| self.assertIsNotNone(frames2) |
| |
| funcs1 = [f.funcname for f in frames1] |
| funcs2 = [f.funcname for f in frames2] |
| |
| # Both should have all levels |
| for level in ["level1", "level2", "level3", "level4"]: |
| self.assertIn(level, funcs1, f"{level} missing from first sample") |
| self.assertIn(level, funcs2, f"{level} missing from second sample") |
| |
| # Should have same stack depth |
| self.assertEqual( |
| len(frames1), |
| len(frames2), |
| "New unwinder should return complete stack despite stale last_profiled_frame", |
| ) |
| |
| @skip_if_not_supported |
| @unittest.skipIf( |
| sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, |
| "Test only runs on Linux with process_vm_readv support", |
| ) |
| def test_cache_exhaustion(self): |
| """Test cache works when frame limit (1024) is exceeded. |
| |
| FRAME_CACHE_MAX_FRAMES=1024. With 1100 recursive frames, |
| the cache can't store all of them but should still work. |
| """ |
| # Use 1100 to exceed FRAME_CACHE_MAX_FRAMES=1024 |
| depth = 1100 |
| script_body = f"""\ |
| import sys |
| sys.setrecursionlimit(2000) |
| |
| def recurse(n): |
| if n <= 0: |
| sock.sendall(b"ready") |
| sock.recv(16) # wait for ack |
| sock.sendall(b"ready2") |
| sock.recv(16) # wait for done |
| return |
| recurse(n - 1) |
| |
| recurse({depth}) |
| """ |
| |
| with self._target_process(script_body) as ( |
| p, |
| client_socket, |
| make_unwinder, |
| ): |
| unwinder_cache = make_unwinder(cache_frames=True) |
| unwinder_no_cache = make_unwinder(cache_frames=False) |
| |
| frames_cached = self._sample_frames( |
| client_socket, |
| unwinder_cache, |
| b"ready", |
| b"ack", |
| {"recurse"}, |
| expected_frames=1102, |
| ) |
| # Sample again with no cache for comparison |
| frames_no_cache = self._sample_frames( |
| client_socket, |
| unwinder_no_cache, |
| b"ready2", |
| b"done", |
| {"recurse"}, |
| expected_frames=1102, |
| ) |
| |
| self.assertIsNotNone(frames_cached) |
| self.assertIsNotNone(frames_no_cache) |
| |
| # Both should have many recurse frames (> 1024 limit) |
| cached_count = [f.funcname for f in frames_cached].count("recurse") |
| no_cache_count = [f.funcname for f in frames_no_cache].count("recurse") |
| |
| self.assertGreater( |
| cached_count, 1000, "Should have >1000 recurse frames" |
| ) |
| self.assertGreater( |
| no_cache_count, 1000, "Should have >1000 recurse frames" |
| ) |
| |
| # Both modes should produce same frame count |
| self.assertEqual( |
| len(frames_cached), |
| len(frames_no_cache), |
| "Cache exhaustion should not affect stack completeness", |
| ) |
| |
| @skip_if_not_supported |
| @unittest.skipIf( |
| sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, |
| "Test only runs on Linux with process_vm_readv support", |
| ) |
| def test_get_stats(self): |
| """Test that get_stats() returns statistics when stats=True.""" |
| script_body = """\ |
| sock.sendall(b"ready") |
| sock.recv(16) |
| """ |
| |
| with self._target_process(script_body) as (p, client_socket, _): |
| unwinder = RemoteUnwinder(p.pid, all_threads=True, stats=True) |
| _wait_for_signal(client_socket, b"ready") |
| |
| # Take a sample |
| unwinder.get_stack_trace() |
| |
| stats = unwinder.get_stats() |
| client_socket.sendall(b"done") |
| |
| # Verify expected keys exist |
| expected_keys = [ |
| "total_samples", |
| "frame_cache_hits", |
| "frame_cache_misses", |
| "frame_cache_partial_hits", |
| "frames_read_from_cache", |
| "frames_read_from_memory", |
| "frame_cache_hit_rate", |
| ] |
| for key in expected_keys: |
| self.assertIn(key, stats) |
| |
| self.assertEqual(stats["total_samples"], 1) |
| |
| @skip_if_not_supported |
| @unittest.skipIf( |
| sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED, |
| "Test only runs on Linux with process_vm_readv support", |
| ) |
| def test_get_stats_disabled_raises(self): |
| """Test that get_stats() raises RuntimeError when stats=False.""" |
| script_body = """\ |
| sock.sendall(b"ready") |
| sock.recv(16) |
| """ |
| |
| with self._target_process(script_body) as (p, client_socket, _): |
| unwinder = RemoteUnwinder( |
| p.pid, all_threads=True |
| ) # stats=False by default |
| _wait_for_signal(client_socket, b"ready") |
| |
| with self.assertRaises(RuntimeError): |
| unwinder.get_stats() |
| |
| client_socket.sendall(b"done") |
| |
| |
# Allow running this test module directly (e.g. ``python test_file.py``).
if __name__ == "__main__":
    unittest.main()