8000 gh-142654: show the clear error message when sampling on an unknown PID by kemingy · Pull Request #142655 · python/cpython · GitHub
[go: up one dir, main page]

Skip to content
Merged
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Lib/profiling/sampling/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"""

from .cli import main
from .errors import SamplingUnknownProcessError, SamplingModuleNotFoundError, SamplingScriptNotFoundError

def handle_permission_error():
"""Handle PermissionError by displaying appropriate error message."""
Expand All @@ -64,3 +65,9 @@ def handle_permission_error():
main()
except PermissionError:
handle_permission_error()
except SamplingUnknownProcessError as err:
print(f"Tachyon cannot find the process: {err}", file=sys.stderr)
sys.exit(1)
except (SamplingModuleNotFoundError, SamplingScriptNotFoundError) as err:
print(f"Tachyon cannot find the target: {err}", file=sys.stderr)
sys.exit(1)
9 changes: 6 additions & 3 deletions Lib/profiling/sampling/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import time
from contextlib import nullcontext

from .sample import sample, sample_live
from .errors import SamplingUnknownProcessError, SamplingModuleNotFoundError, SamplingScriptNotFoundError
from .sample import sample, sample_live, _is_process_running
from .pstats_collector import PstatsCollector
from .stack_collector import CollapsedStackCollector, FlamegraphCollector
from .heatmap_collector import HeatmapCollector
Expand Down Expand Up @@ -743,6 +744,8 @@ def main():

def _handle_attach(args):
"""Handle the 'attach' command."""
if not _is_process_running(args.pid):
raise SamplingUnknownProcessError(args.pid)
# Check if live mode is requested
if args.live:
_handle_live_attach(args, args.pid)
Expand Down Expand Up @@ -792,13 +795,13 @@ def _handle_run(args):
added_cwd = True
try:
if importlib.util.find_spec(args.target) is None:
sys.exit(f"Error: Module not found: {args.target}")
raise SamplingModuleNotFoundError(args.target)
finally:
if added_cwd:
sys.path.remove(cwd)
else:
if not os.path.exists(args.target):
sys.exit(f"Error: Script not found: {args.target}")
raise SamplingScriptNotFoundError(args.target)

# Check if live mode is requested
if args.live:
Expand Down
19 changes: 19 additions & 0 deletions Lib/profiling/sampling/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Custom exceptions for the sampling profiler."""

class SamplingProfilerError(Exception):
"""Base exception for sampling profiler errors."""

class SamplingUnknownProcessError(SamplingProfilerError):
def __init__(self, pid):
self.pid = pid
super().__init__(f"Process with PID '{pid}' does not exist.")

class SamplingScriptNotFoundError(SamplingProfilerError):
def __init__(self, script_path):
self.script_path = script_path
super().__init__(f"Script '{script_path}' not found.")

class SamplingModuleNotFoundError(SamplingProfilerError):
def __init__(self, module_name):
self.module_name = module_name
super().__init__(f"Module '{module_name}' not found.")
68 changes: 40 additions & 28 deletions Lib/profiling/sampling/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,29 @@ def __init__(self, pid, sample_interval_usec, all_threads, *, mode=PROFILING_MOD
self.all_threads = all_threads
self.mode = mode # Store mode for later use
self.collect_stats = collect_stats
try:
self.unwinder = self._new_unwinder(native, gc, opcodes, skip_non_matching_threads)
except RuntimeError as err:
raise SystemExit(err) from err
# Track sample intervals and total sample count
self.sample_intervals = deque(maxlen=100)
self.total_samples = 0
self.realtime_stats = False

def _new_unwinder(self, native, gc, opcodes, skip_non_matching_threads):
if _FREE_THREADED_BUILD:
self.unwinder = _remote_debugging.RemoteUnwinder(
self.pid, all_threads=self.all_threads, mode=mode, native=native, gc=gc,
unwinder = _remote_debugging.RemoteUnwinder(
self.pid, all_threads=self.all_threads, mode=self.mode, native=native, gc=gc,
opcodes=opcodes, skip_non_matching_threads=skip_non_matching_threads,
cache_frames=True, stats=collect_stats
cache_frames=True, stats=self.collect_stats
)
else:
only_active_threads = bool(self.all_threads)
self.unwinder = _remote_debugging.RemoteUnwinder(
self.pid, only_active_thread=only_active_threads, mode=mode, native=native, gc=gc,
unwinder = _remote_debugging.RemoteUnwinder(
self.pid, only_active_thread=bool(self.all_threads), mode=self.mode, native=native, gc=gc,
opcodes=opcodes, skip_non_matching_threads=skip_non_matching_threads,
cache_frames=True, stats=collect_stats
cache_frames=True, stats=self.collect_stats
)
# Track sample intervals and total sample count
self.sample_intervals = deque(maxlen=100)
self.total_samples = 0
self.realtime_stats = False
return unwinder

def sample(self, collector, duration_sec=10, *, async_aware=False):
sample_interval_sec = self.sample_interval_usec / 1_000_000
Expand Down Expand Up @@ -86,7 +92,7 @@ def sample(self, collector, duration_sec=10, *, async_aware=False):
collector.collect_failed_sample()
errors += 1
except Exception as e:
if not self._is_process_running():
if not _is_process_running(self.pid):
break
raise e from None

Expand Down Expand Up @@ -148,22 +154,6 @@ def sample(self, collector, duration_sec=10, *, async_aware=False):
f"({(expected_samples - num_samples) / expected_samples * 100:.2f}%)"
)

def _is_process_running(self):
if sys.platform == "linux" or sys.platform == "darwin":
try:
os.kill(self.pid, 0)
return True
except ProcessLookupError:
return False
elif sys.platform == "win32":
try:
_remote_debugging.RemoteUnwinder(self.pid)
except Exception:
return False
return True
else:
raise ValueError(f"Unsupported platform: {sys.platform}")

def _print_realtime_stats(self):
"""Print real-time sampling statistics."""
if len(self.sample_intervals) < 2:
Expand Down Expand Up @@ -279,6 +269,28 @@ def _print_unwinder_stats(self):
print(f" {ANSIColors.YELLOW}Stale cache invalidations: {stale_invalidations}{ANSIColors.RESET}")


def _is_process_running(pid):
if pid <= 0:
return False
if os.name == "posix":
try:
os.kill(pid, 0)
return True
except ProcessLookupError:
return False
except PermissionError:
# EPERM means process exists but we can't signal it
return True
elif sys.platform == "win32":
try:
_remote_debugging.RemoteUnwinder(pid)
except Exception:
return False
return True
else:
raise ValueError(f"Unsupported platform: {sys.platform}")


def sample(
pid,
collector,
Expand Down
26 changes: 20 additions & 6 deletions Lib/test/test_profiling/test_sampling_profiler/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from test.support import is_emscripten, requires_remote_subprocess_debugging

from profiling.sampling.cli import main
from profiling.sampling.errors import SamplingScriptNotFoundError, SamplingModuleNotFoundError, SamplingUnknownProcessError


class TestSampleProfilerCLI(unittest.TestCase):
Expand Down Expand Up @@ -203,12 +204,12 @@ def test_cli_mutually_exclusive_pid_script(self):
with (
mock.patch("sys.argv", test_args),
mock.patch("sys.stderr", io.StringIO()) as mock_stderr,
self.assertRaises(SystemExit) as cm,
self.assertRaises(SamplingScriptNotFoundError) as cm,
):
main()

# Verify the error is about the non-existent script
self.assertIn("12345", str(cm.exception.code))
self.assertIn("12345", str(cm.exception))

def test_cli_no_target_specified(self):
# In new CLI, must specify a subcommand
Expand Down Expand Up @@ -436,6 +437,7 @@ def test_cli_default_collapsed_filename(self):

with (
mock.patch("sys.argv", test_args),
52EC mock.patch("profiling.sampling.cli._is_process_running", return_value=True),
mock.patch("profiling.sampling.cli.sample") as mock_sample,
):
main()
Expand Down Expand Up @@ -475,6 +477,7 @@ def test_cli_custom_output_filenames(self):
for test_args, expected_filename, expected_format in test_cases:
with (
mock.patch("sys.argv", test_args),
mock.patch("profiling.sampling.cli._is_process_running", return_value=True),
mock.patch("profiling.sampling.cli.sample") as mock_sample,
):
main()
Expand Down Expand Up @@ -513,6 +516,7 @@ def test_argument_parsing_basic(self):

with (
mock.patch("sys.argv", test_args),
mock.patch("profiling.sampling.cli._is_process_running", return_value=True),
mock.patch("profiling.sampling.cli.sample") as mock_sample,
):
main()
Expand All @@ -534,6 +538,7 @@ def test_sort_options(self):

with (
mock.patch("sys.argv", test_args),
mock.patch("profiling.sampling.cli._is_process_running", return_value=True),
mock.patch("profiling.sampling.cli.sample") as mock_sample,
):
main()
Expand All @@ -547,6 +552,7 @@ def test_async_aware_flag_defaults_to_running(self):

with (
mock.patch("sys.argv", test_args),
mock.patch("profiling.sampling.cli._is_process_running", return_value=True),
mock.patch("profiling.sampling.cli.sample") as mock_sample,
):
main()
Expand All @@ -562,6 +568,7 @@ def test_async_aware_with_async_mode_all(self):

with (
mock.patch("sys.argv", test_args),
mock.patch("profiling.sampling.cli._is_process_running", return_value=True),
mock.patch("profiling.sampling.cli.sample") as mock_sample,
):
main()
Expand All @@ -576,6 +583,7 @@ def test_async_aware_default_is_none(self):

with (
mock.patch("sys.argv", test_args),
mock.patch("profiling.sampling.cli._is_process_running", return_value=True),
mock.patch("profiling.sampling.cli.sample") as mock_sample,
):
main()
Expand Down Expand Up @@ -697,14 +705,20 @@ def test_async_aware_incompatible_with_all_threads(self):
def test_run_nonexistent_script_exits_cleanly(self):
"""Test that running a non-existent script exits with a clean error."""
with mock.patch("sys.argv", ["profiling.sampling.cli", "run", "/nonexistent/script.py"]):
with self.assertRaises(SystemExit) as cm:
with self.assertRaisesRegex(SamplingScriptNotFoundError, "Script '[\\w/.]+' not found."):
main()
self.assertIn("Script not found", str(cm.exception.code))

@unittest.skipIf(is_emscripten, "subprocess not available")
def test_run_nonexistent_module_exits_cleanly(self):
"""Test that running a non-existent module exits with a clean error."""
with mock.patch("sys.argv", ["profiling.sampling.cli", "run", "-m", "nonexistent_module_xyz"]):
with self.assertRaises(SystemExit) as cm:
with self.assertRaisesRegex(SamplingModuleNotFoundError, "Module '[\\w/.]+' not found."):
main()

def test_cli_attach_nonexistent_pid(self):
fake_pid = "99999"
with mock.patch("sys.argv", ["profiling.sampling.cli", "attach", fake_pid]):
with self.assertRaises(SamplingUnknownProcessError) as cm:
main()
self.assertIn("Module not found", str(cm.exception.code))

self.assertIn(fake_pid, str(cm.exception))
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import profiling.sampling.sample
from profiling.sampling.pstats_collector import PstatsCollector
from profiling.sampling.stack_collector import CollapsedStackCollector
from profiling.sampling.sample import SampleProfiler
from profiling.sampling.sample import SampleProfiler, _is_process_running
except ImportError:
raise unittest.SkipTest(
"Test only runs when _remote_debugging is available"
Expand Down Expand Up @@ -602,7 +602,7 @@ def test_sample_target_module(self):
@requires_remote_subprocess_debugging()
class TestSampleProfilerErrorHandling(unittest.TestCase):
def test_invalid_pid(self):
with self.assertRaises((OSError, RuntimeError)):
with self.assertRaises((SystemExit, PermissionError)):
collector = PstatsCollector(sample_interval_usec=100, skip_idle=False)
profiling.sampling.sample.sample(-1, collector, duration_sec=1)

Expand Down Expand Up @@ -638,7 +638,7 @@ def test_is_process_running(self):
sample_interval_usec=1000,
all_threads=False,
)
self.assertTrue(profiler._is_process_running())
self.assertTrue(_is_process_running(profiler.pid))
self.assertIsNotNone(profiler.unwinder.get_stack_trace())
subproc.process.kill()
subproc.process.wait()
Expand All @@ -647,7 +647,7 @@ def test_is_process_running(self):
)

# Exit the context manager to ensure the process is terminated
self.assertFalse(profiler._is_process_running())
self.assertFalse(_is_process_running(profiler.pid))
self.assertRaises(
ProcessLookupError, profiler.unwinder.get_stack_trace
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def test_gil_mode_validation(self):

with (
mock.patch("sys.argv", test_args),
mock.patch("profiling.sampling.cli._is_process_running", return_value=True),
mock.patch("profiling.sampling.cli.sample") as mock_sample,
):
try:
Expand Down Expand Up @@ -313,6 +314,7 @@ def test_gil_mode_cli_argument_parsing(self):

with (
mock.patch("sys.argv", test_args),
mock.patch("profiling.sampling.cli._is_process_running", return_value=True),
mock.patch("profiling.sampling.cli.sample") as mock_sample,
):
try:
Expand Down Expand Up @@ -432,6 +434,7 @@ def test_exception_mode_validation(self):

with (
mock.patch("sys.argv", test_args),
mock.patch("profiling.sampling.cli._is_process_running", return_value=True),
mock.patch("profiling.sampling.cli.sample") as mock_sample,
):
try:
Expand Down Expand Up @@ -493,6 +496,7 @@ def test_exception_mode_cli_argument_parsing(self):

with (
mock.patch("sys.argv", test_args),
mock.patch("profiling.sampling.cli._is_process_running", return_value=True),
mock.patch("profiling.sampling.cli.sample") as mock_sample,
):
try:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Show the clearer error message when using ``profiling.sampling`` on an
unknown PID.
Loading
0