diff --git a/Lib/profile/__init__.py b/Lib/profile/__init__.py
new file mode 100644
index 00000000000000..21c886448ac2af
--- /dev/null
+++ b/Lib/profile/__init__.py
@@ -0,0 +1,6 @@
+from .profile import run
+from .profile import runctx
+from .profile import Profile
+from .profile import _Utils
+
+__all__ = ['run', 'runctx', 'Profile']
diff --git a/Lib/profile/__main__.py b/Lib/profile/__main__.py
new file mode 100644
index 00000000000000..2e8d5b3827dccc
--- /dev/null
+++ b/Lib/profile/__main__.py
@@ -0,0 +1,81 @@
+import io
+import importlib.machinery
+import os
+import sys
+from optparse import OptionParser
+
+from .profile import runctx
+
+
+def main():
+    usage = "profile.py [-o output_file_path] [-s sort] [-m module | scriptfile] [arg] ..."
+    parser = OptionParser(usage=usage)
+    parser.allow_interspersed_args = False
+    parser.add_option(
+        "-o",
+        "--outfile",
+        dest="outfile",
+        help="Save stats to <outfile>",
+        default=None,
+    )
+    parser.add_option(
+        "-m",
+        dest="module",
+        action="store_true",
+        help="Profile a library module.",
+        default=False,
+    )
+    parser.add_option(
+        "-s",
+        "--sort",
+        dest="sort",
+        help="Sort order when printing to stdout, based on pstats.Stats class",
+        default=-1,
+    )
+
+    if not sys.argv[1:]:
+        parser.print_usage()
+        sys.exit(2)
+
+    (options, args) = parser.parse_args()
+    sys.argv[:] = args
+
+    # The script that we're profiling may chdir, so capture the absolute path
+    # to the output file at startup.
+    if options.outfile is not None:
+        options.outfile = os.path.abspath(options.outfile)
+
+    if len(args) > 0:
+        if options.module:
+            import runpy
+
+            code = "run_module(modname, run_name='__main__')"
+            globs = {"run_module": runpy.run_module, "modname": args[0]}
+        else:
+            progname = args[0]
+            sys.path.insert(0, os.path.dirname(progname))
+            with io.open_code(progname) as fp:
+                code = compile(fp.read(), progname, "exec")
+            spec = importlib.machinery.ModuleSpec(
+                name="__main__", loader=None, origin=progname
+            )
+            globs = {
+                "__spec__": spec,
+                "__file__": spec.origin,
+                "__name__": spec.name,
+                "__package__": None,
+                "__cached__": None,
+            }
+        try:
+            runctx(code, globs, None, options.outfile, options.sort)
+        except BrokenPipeError as exc:
+            # Prevent "Exception ignored" during interpreter shutdown.
+            sys.stdout = None
+            sys.exit(exc.errno)
+    else:
+        parser.print_usage()
+    return parser
+
+
+if __name__ == "__main__":
+    main()
diff --git a/Lib/profile/collector.py b/Lib/profile/collector.py
new file mode 100644
index 00000000000000..28286120aefc67
--- /dev/null
+++ b/Lib/profile/collector.py
@@ -0,0 +1,11 @@
+from abc import ABC, abstractmethod
+
+
+class Collector(ABC):
+    @abstractmethod
+    def collect(self, stack_frames):
+        """Collect profiling data from stack frames."""
+
+    @abstractmethod
+    def export(self, filename):
+        """Export collected data to a file."""
diff --git a/Lib/profile.py b/Lib/profile/profile.py
similarity index 90%
rename from Lib/profile.py
rename to Lib/profile/profile.py
index a5afb12c9d121a..1d9e2fd41f85b4 100644
--- a/Lib/profile.py
+++ b/Lib/profile/profile.py
@@ -550,66 +550,3 @@ def f(m, f1=f1):
         return mean
 
 #****************************************************************************
-
-def main():
-    import os
-    from optparse import OptionParser
-
-    usage = "profile.py [-o output_file_path] [-s sort] [-m module | scriptfile] [arg] ..."
-    parser = OptionParser(usage=usage)
-    parser.allow_interspersed_args = False
-    parser.add_option('-o', '--outfile', dest="outfile",
-        help="Save stats to <outfile>", default=None)
-    parser.add_option('-m', dest="module", action="store_true",
-        help="Profile a library module.", default=False)
-    parser.add_option('-s', '--sort', dest="sort",
-        help="Sort order when printing to stdout, based on pstats.Stats class",
-        default=-1)
-
-    if not sys.argv[1:]:
-        parser.print_usage()
-        sys.exit(2)
-
-    (options, args) = parser.parse_args()
-    sys.argv[:] = args
-
-    # The script that we're profiling may chdir, so capture the absolute path
-    # to the output file at startup.
-    if options.outfile is not None:
-        options.outfile = os.path.abspath(options.outfile)
-
-    if len(args) > 0:
-        if options.module:
-            import runpy
-            code = "run_module(modname, run_name='__main__')"
-            globs = {
-                'run_module': runpy.run_module,
-                'modname': args[0]
-            }
-        else:
-            progname = args[0]
-            sys.path.insert(0, os.path.dirname(progname))
-            with io.open_code(progname) as fp:
-                code = compile(fp.read(), progname, 'exec')
-            spec = importlib.machinery.ModuleSpec(name='__main__', loader=None,
-                                                  origin=progname)
-            globs = {
-                '__spec__': spec,
-                '__file__': spec.origin,
-                '__name__': spec.name,
-                '__package__': None,
-                '__cached__': None,
-            }
-        try:
-            runctx(code, globs, None, options.outfile, options.sort)
-        except BrokenPipeError as exc:
-            # Prevent "Exception ignored" during interpreter shutdown.
-            sys.stdout = None
-            sys.exit(exc.errno)
-    else:
-        parser.print_usage()
-    return parser
-
-# When invoked as main program, invoke the profiler on a script
-if __name__ == '__main__':
-    main()
diff --git a/Lib/profile/pstats_collector.py b/Lib/profile/pstats_collector.py
new file mode 100644
index 00000000000000..67507a6c554886
--- /dev/null
+++ b/Lib/profile/pstats_collector.py
@@ -0,0 +1,83 @@
+import collections
+import marshal
+
+from .collector import Collector
+
+
+class PstatsCollector(Collector):
+    def __init__(self, sample_interval_usec):
+        self.result = collections.defaultdict(
+            lambda: dict(total_calls=0, total_rec_calls=0, inline_calls=0)
+        )
+        self.stats = {}
+        self.sample_interval_usec = sample_interval_usec
+        self.callers = collections.defaultdict(
+            lambda: collections.defaultdict(int)
+        )
+
+    def collect(self, stack_frames):
+        for thread_id, frames in stack_frames:
+            if not frames:
+                continue
+
+            top_frame = frames[0]
+            top_location = (
+                top_frame.filename,
+                top_frame.lineno,
+                top_frame.funcname,
+            )
+
+            self.result[top_location]["inline_calls"] += 1
+            self.result[top_location]["total_calls"] += 1
+
+            for i in range(1, len(frames)):
+                callee_frame = frames[i - 1]
+                caller_frame = frames[i]
+
+                callee = (
+                    callee_frame.filename,
+                    callee_frame.lineno,
+                    callee_frame.funcname,
+                )
+                caller = (
+                    caller_frame.filename,
+                    caller_frame.lineno,
+                    caller_frame.funcname,
+                )
+
+                self.callers[callee][caller] += 1
+
+            if len(frames) <= 1:
+                continue
+
+            for frame in frames[1:]:
+                location = (frame.filename, frame.lineno, frame.funcname)
+                self.result[location]["total_calls"] += 1
+
+    def export(self, filename):
+        self.create_stats()
+        self._dump_stats(filename)
+
+    def _dump_stats(self, file):
+        stats_with_marker = dict(self.stats)
+        stats_with_marker[("__sampled__",)] = True
+        with open(file, "wb") as f:
+            marshal.dump(stats_with_marker, f)
+
+    # Needed for compatibility with pstats.Stats
+    def create_stats(self):
+        sample_interval_sec = self.sample_interval_usec / 1_000_000
+        callers = {}
+        for fname, call_counts in self.result.items():
+            total = call_counts["inline_calls"] * sample_interval_sec
+            cumulative = call_counts["total_calls"] * sample_interval_sec
+            callers = dict(self.callers.get(fname, {}))
+            self.stats[fname] = (
+                call_counts["total_calls"],
+                call_counts["total_rec_calls"]
+                if call_counts["total_rec_calls"]
+                else call_counts["total_calls"],
+                total,
+                cumulative,
+                callers,
+            )
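Reviewer note: a minimal sketch (not part of the patch) of what `PstatsCollector` does with a single sample. The `FrameInfo` namedtuple below is a stand-in for the frame records produced by `_remote_debugging.RemoteUnwinder`, which are assumed to expose `filename`, `lineno`, and `funcname`, with the leaf frame first:

```python
from collections import namedtuple

from profile.pstats_collector import PstatsCollector

# Stand-in for the unwinder's frame records (assumed shape).
FrameInfo = namedtuple("FrameInfo", ["filename", "lineno", "funcname"])

collector = PstatsCollector(sample_interval_usec=1_000_000)  # 1 s per sample
# One sample from thread 1: func1 (leaf) was running, called by func2.
collector.collect(
    [(1, [FrameInfo("file.py", 10, "func1"),
          FrameInfo("file.py", 20, "func2")])]
)
collector.create_stats()
# The leaf gets inline ("tottime") credit; both frames get cumulative credit:
#   ("file.py", 10, "func1") -> (1, 1, 1.0, 1.0, {("file.py", 20, "func2"): 1})
#   ("file.py", 20, "func2") -> (1, 1, 0.0, 1.0, {})
print(collector.stats)
```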
diff --git a/Lib/profile/sample.py b/Lib/profile/sample.py
new file mode 100644
index 00000000000000..011d395b1a27d7
--- /dev/null
+++ b/Lib/profile/sample.py
@@ -0,0 +1,407 @@
+import argparse
+import _colorize
+import _remote_debugging
+import pstats
+import time
+from _colorize import ANSIColors
+
+from .pstats_collector import PstatsCollector
+from .stack_collectors import CollapsedStackCollector
+
+
+class SampleProfiler:
+    def __init__(self, pid, sample_interval_usec, all_threads):
+        self.pid = pid
+        self.sample_interval_usec = sample_interval_usec
+        self.all_threads = all_threads
+        self.unwinder = _remote_debugging.RemoteUnwinder(
+            self.pid, all_threads=self.all_threads
+        )
+
+    def sample(self, collector, duration_sec=10):
+        sample_interval_sec = self.sample_interval_usec / 1_000_000
+        running_time = 0
+        num_samples = 0
+        errors = 0
+        start_time = next_time = time.perf_counter()
+        while running_time < duration_sec:
+            if next_time < time.perf_counter():
+                try:
+                    stack_frames = self.unwinder.get_stack_trace()
+                    collector.collect(stack_frames)
+                except (RuntimeError, UnicodeDecodeError, OSError):
+                    errors += 1
+
+                num_samples += 1
+                next_time += sample_interval_sec
+
+            running_time = time.perf_counter() - start_time
+
+        print(f"Captured {num_samples} samples in {running_time:.2f} seconds")
+        print(f"Sample rate: {num_samples / running_time:.2f} samples/sec")
+        print(f"Error rate: {(errors / num_samples) * 100:.2f}%")
+
+        expected_samples = int(duration_sec / sample_interval_sec)
+        if num_samples < expected_samples:
+            print(
+                f"Warning: missed {expected_samples - num_samples} samples "
+                f"from the expected total of {expected_samples} "
+                f"({(expected_samples - num_samples) / expected_samples * 100:.2f}%)"
+            )
+
+
+def print_sampled_stats(stats, sort=-1, limit=None, show_summary=True):
+    if not isinstance(sort, tuple):
+        sort = (sort,)
+
+    # Get the stats data
+    stats_list = []
+    for func, (cc, nc, tt, ct, callers) in stats.stats.items():
+        stats_list.append((func, cc, nc, tt, ct, callers))
+
+    # Sort based on the requested field
+    sort_field = sort[0]
+    if sort_field == -1:  # stdname
+        stats_list.sort(key=lambda x: str(x[0]))
+    elif sort_field == 0:  # calls
+        stats_list.sort(key=lambda x: x[2], reverse=True)
+    elif sort_field == 1:  # time
+        stats_list.sort(key=lambda x: x[3], reverse=True)
+    elif sort_field == 2:  # cumulative
+        stats_list.sort(key=lambda x: x[4], reverse=True)
+    elif sort_field == 3:  # percall
+        stats_list.sort(
+            key=lambda x: x[3] / x[2] if x[2] > 0 else 0, reverse=True
+        )
+    elif sort_field == 4:  # cumpercall
+        stats_list.sort(
+            key=lambda x: x[4] / x[2] if x[2] > 0 else 0, reverse=True
+        )
+
+    # Apply limit if specified
+    if limit is not None:
+        stats_list = stats_list[:limit]
+
+    # Find the maximum values for each column to determine units
+    max_tt = max((tt for _, _, _, tt, _, _ in stats_list), default=0)
+    max_ct = max((ct for _, _, _, _, ct, _ in stats_list), default=0)
+
+    # Determine appropriate units and format strings
+    if max_tt >= 1.0:
+        tt_unit = "s"
+        tt_scale = 1.0
+    elif max_tt >= 0.001:
+        tt_unit = "ms"
+        tt_scale = 1000.0
+    else:
+        tt_unit = "μs"
+        tt_scale = 1000000.0
+
+    if max_ct >= 1.0:
+        ct_unit = "s"
+        ct_scale = 1.0
+    elif max_ct >= 0.001:
+        ct_unit = "ms"
+        ct_scale = 1000.0
+    else:
+        ct_unit = "μs"
+        ct_scale = 1000000.0
+
+    # Print header with colors and units
+    header = (
+        f"{ANSIColors.BOLD_BLUE}Profile Stats:{ANSIColors.RESET}\n"
+        f"{ANSIColors.BOLD_BLUE}nsamples{ANSIColors.RESET} "
+        f"{ANSIColors.BOLD_BLUE}tottime ({tt_unit}){ANSIColors.RESET} "
+        f"{ANSIColors.BOLD_BLUE}persample ({tt_unit}){ANSIColors.RESET} "
+        f"{ANSIColors.BOLD_BLUE}cumtime ({ct_unit}){ANSIColors.RESET} "
+        f"{ANSIColors.BOLD_BLUE}persample ({ct_unit}){ANSIColors.RESET} "
+        f"{ANSIColors.BOLD_BLUE}filename:lineno(function){ANSIColors.RESET}"
+    )
+    print(header)
+
+    # Print each line with colors
+    for func, cc, nc, tt, ct, callers in stats_list:
+        if nc != cc:
+            ncalls = f"{nc}/{cc}"
+        else:
+            ncalls = str(nc)
+
+        # Format numbers with proper alignment and precision (no colors)
+        tottime = f"{tt * tt_scale:8.3f}"
+        percall = f"{(tt / nc) * tt_scale:8.3f}" if nc > 0 else "     N/A"
+        cumtime = f"{ct * ct_scale:8.3f}"
+        cumpercall = f"{(ct / nc) * ct_scale:8.3f}" if nc > 0 else "     N/A"
+
+        # Format the function name with colors
+        func_name = (
+            f"{ANSIColors.GREEN}{func[0]}{ANSIColors.RESET}:"
+            f"{ANSIColors.YELLOW}{func[1]}{ANSIColors.RESET}("
+            f"{ANSIColors.CYAN}{func[2]}{ANSIColors.RESET})"
+        )
+
+        # Print the formatted line
+        print(
+            f"{ncalls:>8} {tottime} {percall} {cumtime} {cumpercall} {func_name}"
+        )
+
+    def _format_func_name(func):
+        """Format function name with colors."""
+        return (
+            f"{ANSIColors.GREEN}{func[0]}{ANSIColors.RESET}:"
+            f"{ANSIColors.YELLOW}{func[1]}{ANSIColors.RESET}("
+            f"{ANSIColors.CYAN}{func[2]}{ANSIColors.RESET})"
+        )
+
+    def _print_top_functions(stats_list, title, key_func, format_line, n=3):
+        """Print top N functions sorted by key_func with formatted output."""
+        print(f"\n{ANSIColors.BOLD_BLUE}{title}:{ANSIColors.RESET}")
+        sorted_stats = sorted(stats_list, key=key_func, reverse=True)
+        for stat in sorted_stats[:n]:
+            if line := format_line(stat):
+                print(f"  {line}")
+
+    # Print summary of interesting functions if enabled
+    if show_summary and stats_list:
+        print(
+            f"\n{ANSIColors.BOLD_BLUE}Summary of Interesting Functions:{ANSIColors.RESET}"
+        )
+
+        # Most time-consuming functions (by total time)
+        def format_time_consuming(stat):
+            func, _, nc, tt, _, _ = stat
+            if tt > 0:
+                return (
+                    f"{tt * tt_scale:8.3f} {tt_unit} total time, "
+                    f"{(tt / nc) * tt_scale:8.3f} {tt_unit} per call: {_format_func_name(func)}"
+                )
+            return None
+
+        _print_top_functions(
+            stats_list,
+            "Most Time-Consuming Functions",
+            key_func=lambda x: x[3],
+            format_line=format_time_consuming,
+        )
+
+        # Most called functions
+        def format_most_called(stat):
+            func, _, nc, tt, _, _ = stat
+            if nc > 0:
+                return (
+                    f"{nc:8d} calls, {(tt / nc) * tt_scale:8.3f} {tt_unit} "
+                    f"per call: {_format_func_name(func)}"
+                )
+            return None
+
+        _print_top_functions(
+            stats_list,
+            "Most Called Functions",
+            key_func=lambda x: x[2],
+            format_line=format_most_called,
+        )
+
+        # Functions with highest per-call overhead
+        def format_overhead(stat):
+            func, _, nc, tt, _, _ = stat
+            if nc > 0 and tt > 0:
+                return (
+                    f"{(tt / nc) * tt_scale:8.3f} {tt_unit} per call, "
+                    f"{nc:8d} calls: {_format_func_name(func)}"
+                )
+            return None
+
+        _print_top_functions(
+            stats_list,
+            "Functions with Highest Per-Call Overhead",
+            key_func=lambda x: x[3] / x[2] if x[2] > 0 else 0,
+            format_line=format_overhead,
+        )
+
+        # Functions with highest cumulative impact
+        def format_cumulative(stat):
+            func, _, nc, _, ct, _ = stat
+            if ct > 0:
+                return (
+                    f"{ct * ct_scale:8.3f} {ct_unit} cumulative time, "
+                    f"{(ct / nc) * ct_scale:8.3f} {ct_unit} per call: "
+                    f"{_format_func_name(func)}"
+                )
+            return None
+
+        _print_top_functions(
+            stats_list,
+            "Functions with Highest Cumulative Impact",
+            key_func=lambda x: x[4],
+            format_line=format_cumulative,
+        )
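Reviewer note: since this is a statistical profiler, the time columns above are estimates derived purely from sample counts; each sample is credited with one sampling interval. A back-of-envelope check with assumed numbers:

```python
interval_usec = 100          # assumed sampling interval
duration_sec = 10            # assumed run length
interval_sec = interval_usec / 1_000_000

# SampleProfiler.sample() warns when it falls short of this target:
expected_samples = int(duration_sec / interval_sec)   # 100_000

# PstatsCollector.create_stats() credits time as samples * interval, so a
# function seen as the leaf frame in 2_500 samples is estimated at:
tottime_estimate = 2_500 * interval_sec               # 0.25 seconds
```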
+def sample(
+    pid,
+    *,
+    sort=-1,
+    sample_interval_usec=100,
+    duration_sec=10,
+    filename=None,
+    all_threads=False,
+    limit=None,
+    show_summary=True,
+    output_format="pstats",
+):
+    profiler = SampleProfiler(
+        pid, sample_interval_usec, all_threads=all_threads
+    )
+
+    collector = None
+    match output_format:
+        case "pstats":
+            collector = PstatsCollector(sample_interval_usec)
+        case "collapsed":
+            collector = CollapsedStackCollector()
+            filename = filename or f"collapsed.{pid}.txt"
+        case _:
+            raise ValueError(f"Invalid output format: {output_format}")
+
+    profiler.sample(collector, duration_sec)
+
+    if output_format == "pstats" and not filename:
+        stats = pstats.SampledStats(collector).strip_dirs()
+        print_sampled_stats(stats, sort, limit, show_summary)
+    else:
+        collector.export(filename)
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description=(
+            "Sample a process's stack frames.\n\n"
+            "Sort options:\n"
+            "  --sort-calls       Sort by number of calls (most called functions first)\n"
+            "  --sort-time        Sort by total time (most time-consuming functions first)\n"
+            "  --sort-cumulative  Sort by cumulative time (functions with highest total impact first)\n"
+            "  --sort-percall     Sort by time per call (functions with highest per-call overhead first)\n"
+            "  --sort-cumpercall  Sort by cumulative time per call (functions with highest cumulative overhead per call)\n"
+            "  --sort-name        Sort by function name (alphabetical order)\n\n"
+            "The default sort is by cumulative time (--sort-cumulative).\n\n"
+            "Format descriptions:\n"
+            "  pstats     Standard Python profiler output format\n"
+            "  collapsed  Stack traces in collapsed format (file:function:line;file:function:line;... count)\n"
+            "             Useful for generating flamegraphs with tools like flamegraph.pl"
+        ),
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        color=True,
+    )
+    parser.add_argument("pid", type=int, help="Process ID to sample.")
+    parser.add_argument(
+        "-i",
+        "--interval",
+        type=int,
+        default=10,
+        help="Sampling interval in microseconds (default: 10 usec)",
+    )
+    parser.add_argument(
+        "-d",
+        "--duration",
+        type=int,
+        default=10,
+        help="Sampling duration in seconds (default: 10 seconds)",
+    )
+    parser.add_argument(
+        "-a",
+        "--all-threads",
+        action="store_true",
+        help="Sample all threads in the process",
+    )
+    parser.add_argument("-o", "--outfile", help="Save stats to <outfile>")
+    parser.add_argument(
+        "--no-color",
+        action="store_true",
+        help="Disable color output",
+    )
+    parser.add_argument(
+        "-l",
+        "--limit",
+        type=int,
+        help="Limit the number of rows in the output",
+    )
+    parser.add_argument(
+        "--no-summary",
+        action="store_true",
+        help="Disable the summary section at the end of the output",
+    )
+    parser.add_argument(
+        "--format",
+        choices=["pstats", "collapsed"],
+        default="pstats",
+        help="Output format (default: pstats)",
+    )
+
+    # Add sorting options
+    sort_group = parser.add_mutually_exclusive_group()
+    sort_group.add_argument(
+        "--sort-calls",
+        action="store_const",
+        const=0,
+        dest="sort",
+        help="Sort by number of calls (most called functions first)",
+    )
+    sort_group.add_argument(
+        "--sort-time",
+        action="store_const",
+        const=1,
+        dest="sort",
+        help="Sort by total time (most time-consuming functions first)",
+    )
+    sort_group.add_argument(
+        "--sort-cumulative",
+        action="store_const",
+        const=2,
+        dest="sort",
+        help="Sort by cumulative time (functions with highest total impact first)",
+    )
+    sort_group.add_argument(
+        "--sort-percall",
+        action="store_const",
+        const=3,
+        dest="sort",
+        help="Sort by time per call (functions with highest per-call overhead first)",
+    )
+    sort_group.add_argument(
+        "--sort-cumpercall",
+        action="store_const",
+        const=4,
+        dest="sort",
+        help="Sort by cumulative time per call (functions with highest cumulative overhead per call)",
+    )
+    sort_group.add_argument(
+        "--sort-name",
+        action="store_const",
+        const=5,
+        dest="sort",
+        help="Sort by function name (alphabetical order)",
+    )
+
+    # Set default sort to cumulative time
+    parser.set_defaults(sort=2)
+
+    args = parser.parse_args()
+
+    # Set color theme based on --no-color flag
+    if args.no_color:
+        _colorize.set_theme(_colorize.theme_no_color)
+
+    sample(
+        args.pid,
+        sample_interval_usec=args.interval,
+        duration_sec=args.duration,
+        filename=args.outfile,
+        all_threads=args.all_threads,
+        limit=args.limit,
+        sort=args.sort,
+        show_summary=not args.no_summary,
+        output_format=args.format,
+    )
+
+
+if __name__ == "__main__":
+    main()
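Reviewer note: for context, a hedged sketch of how the new entry point is meant to be driven. The PID is a placeholder, and attaching typically requires debugger-level privileges; the CLI equivalent of the call below would be `python -m profile.sample --sort-time -d 5 1234`:

```python
import profile.sample

# Programmatic use of the API added in this file (placeholder PID 1234).
profile.sample.sample(
    1234,
    sort=1,                    # --sort-time
    duration_sec=5,            # -d 5
    sample_interval_usec=10,   # matches the CLI default (-i 10)
    show_summary=True,
)
```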
diff --git a/Lib/profile/stack_collectors.py b/Lib/profile/stack_collectors.py
new file mode 100644
index 00000000000000..fd4369356052b8
--- /dev/null
+++ b/Lib/profile/stack_collectors.py
@@ -0,0 +1,37 @@
+import collections
+import os
+
+from .collector import Collector
+
+
+class StackTraceCollector(Collector):
+    def __init__(self):
+        self.call_trees = []
+        self.function_samples = collections.defaultdict(int)
+
+    def collect(self, stack_frames):
+        for thread_id, frames in stack_frames:
+            if frames:
+                # Store the complete call stack (reverse order - root first)
+                call_tree = list(reversed(frames))
+                self.call_trees.append(call_tree)
+
+                # Count samples per function
+                for frame in frames:
+                    self.function_samples[frame] += 1
+
+
+class CollapsedStackCollector(StackTraceCollector):
+    def export(self, filename):
+        stack_counter = collections.Counter()
+        for call_tree in self.call_trees:
+            # Call tree is already in root->leaf order
+            stack_str = ";".join(
+                f"{os.path.basename(f[0])}:{f[2]}:{f[1]}" for f in call_tree
+            )
+            stack_counter[stack_str] += 1
+
+        with open(filename, "w") as f:
+            for stack, count in stack_counter.items():
+                f.write(f"{stack} {count}\n")
+        print(f"Collapsed stack output written to {filename}")
diff --git a/Lib/pstats.py b/Lib/pstats.py
index becaf35580eaee..079abd2c1b81df 100644
--- a/Lib/pstats.py
+++ b/Lib/pstats.py
@@ -139,7 +139,11 @@ def load_stats(self, arg):
             return
         elif isinstance(arg, str):
             with open(arg, 'rb') as f:
-                self.stats = marshal.load(f)
+                stats = marshal.load(f)
+                if ('__sampled__',) in stats:
+                    stats.pop(('__sampled__',))
+                    self.__class__ = SampledStats
+                self.stats = stats
             try:
                 file_stats = os.stat(arg)
                 arg = time.ctime(file_stats.st_mtime) + " " + arg
@@ -467,7 +471,10 @@ def print_call_heading(self, name_size, column_title):
                 subheader = isinstance(value, tuple)
                 break
         if subheader:
-            print(" "*name_size + "    ncalls  tottime  cumtime", file=self.stream)
+            self.print_call_subheading(name_size)
+
+    def print_call_subheading(self, name_size):
+        print(" "*name_size + "    ncalls  tottime  cumtime", file=self.stream)
 
     def print_call_line(self, name_size, source, call_dict, arrow="->"):
         print(func_std_string(source).ljust(name_size) + arrow, end=' ', file=self.stream)
@@ -516,6 +523,35 @@ def print_line(self, func):  # hack: should print percentages
             print(f8(ct/cc), end=' ', file=self.stream)
         print(func_std_string(func), file=self.stream)
 
+
+class SampledStats(Stats):
+    def __init__(self, *args, stream=None):
+        super().__init__(*args, stream=stream)
+
+        self.sort_arg_dict = {
+            "samples"   : (((1,-1),              ), "sample count"),
+            "nsamples"  : (((1,-1),              ), "sample count"),
+            "cumtime"   : (((3,-1),              ), "cumulative time"),
+            "cumulative": (((3,-1),              ), "cumulative time"),
+            "filename"  : (((4, 1),              ), "file name"),
+            "line"      : (((5, 1),              ), "line number"),
+            "module"    : (((4, 1),              ), "file name"),
+            "name"      : (((6, 1),              ), "function name"),
+            "nfl"       : (((6, 1),(4, 1),(5, 1),), "name/file/line"),
+            "psamples"  : (((0,-1),              ), "primitive call count"),
+            "stdname"   : (((7, 1),              ), "standard name"),
+            "time"      : (((2,-1),              ), "internal time"),
+            "tottime"   : (((2,-1),              ), "internal time"),
+        }
+
+    def print_call_subheading(self, name_size):
+        print(" "*name_size + "  nsamples  tottime  cumtime", file=self.stream)
+
+    def print_title(self):
+        print('  nsamples  tottime  persample  cumtime  persample', end=' ', file=self.stream)
+        print('filename:lineno(function)', file=self.stream)
+
+
 class TupleComp:
     """This class provides a generic function for comparing any two tuples.
     Each instance records a list of tuple-indices (from most significant
@@ -607,6 +643,24 @@ def f8(x):
 # Statistics browser added by ESR, April 2001
 #**************************************************************************
 
+class StatsLoaderShim:
+    """Compatibility shim implementing 'create_stats' needed by Stats classes
+    to handle already unmarshalled data."""
+    def __init__(self, raw_stats):
+        self.stats = raw_stats
+
+    def create_stats(self):
+        pass
+
+def stats_factory(raw_stats):
+    """Return a Stats or SampledStats instance based on the marker in raw_stats."""
+    if ('__sampled__',) in raw_stats:
+        raw_stats = dict(raw_stats)  # avoid mutating caller's dict
+        raw_stats.pop(('__sampled__',))
+        return SampledStats(StatsLoaderShim(raw_stats))
+    else:
+        return Stats(StatsLoaderShim(raw_stats))
+
 if __name__ == '__main__':
     import cmd
     try:
@@ -693,7 +747,15 @@ def help_quit(self):
         def do_read(self, line):
             if line:
                 try:
-                    self.stats = Stats(line)
+                    with open(line, 'rb') as f:
+                        raw_stats = marshal.load(f)
+                    self.stats = stats_factory(raw_stats)
+                    try:
+                        file_stats = os.stat(line)
+                        arg = time.ctime(file_stats.st_mtime) + " " + line
+                    except Exception:
+                        arg = line
+                    self.stats.files = [arg]
                 except OSError as err:
                     print(err.args[1], file=self.stream)
                     return
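Reviewer note: a minimal sketch of the loading path these `pstats` changes enable. A file written by `PstatsCollector.export()` carries the `('__sampled__',)` marker, so `Stats.load_stats()` swaps the instance over to `SampledStats`; the filename below is a placeholder:

```python
import pstats

# "app.pstats" is a placeholder for output from:
#   python -m profile.sample -o app.pstats <pid>
stats = pstats.Stats("app.pstats")
print(type(stats).__name__)   # SampledStats, via the marker check in load_stats
stats.sort_stats("cumulative").print_stats(10)   # headers read "nsamples"
```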
diff --git a/Lib/test/test_sample_profiler.py b/Lib/test/test_sample_profiler.py
new file mode 100644
index 00000000000000..f678e8b23601f5
--- /dev/null
+++ b/Lib/test/test_sample_profiler.py
@@ -0,0 +1,561 @@
+"""Tests for the sampling profiler (profile.sample)."""
+
+import contextlib
+import io
+import marshal
+import os
+import subprocess
+import sys
+import tempfile
+import time
+import unittest
+from unittest import mock
+
+import profile.sample
+from profile.pstats_collector import PstatsCollector
+from profile.stack_collectors import (
+    CollapsedStackCollector,
+)
+
+from test.support.os_helper import unlink
+
+
+class MockFrameInfo:
+    """Mock FrameInfo for testing since the real one isn't accessible."""
+
+    def __init__(self, filename, lineno, funcname):
+        self.filename = filename
+        self.lineno = lineno
+        self.funcname = funcname
+
+    def __repr__(self):
+        return f"MockFrameInfo(filename='{self.filename}', lineno={self.lineno}, funcname='{self.funcname}')"
+
+
+@contextlib.contextmanager
+def test_subprocess(script, startup_delay=0.1):
+    proc = subprocess.Popen(
+        [sys.executable, "-c", script],
+        stdout=subprocess.DEVNULL,
+        stderr=subprocess.DEVNULL,
+    )
+
+    try:
+        if startup_delay > 0:
+            time.sleep(startup_delay)
+        yield proc
+    finally:
+        if proc.poll() is None:
+            proc.kill()
+            proc.wait()
+
+
+def close_and_unlink(file):
+    file.close()
+    unlink(file.name)
+
+
+class TestSampleProfilerComponents(unittest.TestCase):
+    """Unit tests for individual profiler components."""
+
+    def test_pstats_collector_basic(self):
+        """Test basic PstatsCollector functionality."""
+        collector = PstatsCollector(sample_interval_usec=1000)
+
+        # Test empty state
+        self.assertEqual(len(collector.result), 0)
+        self.assertEqual(len(collector.stats), 0)
+
+        # Test collecting sample data
+        test_frames = [
+            (
+                1,
+                [
+                    MockFrameInfo("file.py", 10, "func1"),
+                    MockFrameInfo("file.py", 20, "func2"),
+                ],
+            )
+        ]
+        collector.collect(test_frames)
+
+        # Should have recorded calls for both functions
+        self.assertEqual(len(collector.result), 2)
+        self.assertIn(("file.py", 10, "func1"), collector.result)
+        self.assertIn(("file.py", 20, "func2"), collector.result)
+
+        # Top-level function should have inline call
+        self.assertEqual(
+            collector.result[("file.py", 10, "func1")]["inline_calls"], 1
+        )
+        self.assertEqual(
+            collector.result[("file.py", 10, "func1")]["total_calls"], 1
+        )
+
+        # Calling function should have total call
+        self.assertEqual(
+            collector.result[("file.py", 20, "func2")]["total_calls"], 1
+        )
+
+    def test_pstats_collector_create_stats(self):
+        """Test PstatsCollector stats creation."""
+        collector = PstatsCollector(
+            sample_interval_usec=1000000
+        )  # 1 second intervals
+
+        test_frames = [
+            (
+                1,
+                [
+                    MockFrameInfo("file.py", 10, "func1"),
+                    MockFrameInfo("file.py", 20, "func2"),
+                ],
+            )
+        ]
+        collector.collect(test_frames)
+        collector.collect(test_frames)  # Collect twice
+
+        collector.create_stats()
+
+        # Check stats format: (cc, nc, tt, ct, callers)
+        func1_stats = collector.stats[("file.py", 10, "func1")]
+        self.assertEqual(func1_stats[0], 2)  # total_calls
+        self.assertEqual(func1_stats[1], 2)  # nc (non-recursive calls)
+        self.assertEqual(
+            func1_stats[2], 2.0
+        )  # tt (total time - 2 samples * 1 sec)
+        self.assertEqual(func1_stats[3], 2.0)  # ct (cumulative time)
+
+        func2_stats = collector.stats[("file.py", 20, "func2")]
+        self.assertEqual(func2_stats[0], 2)  # total_calls
+        self.assertEqual(func2_stats[2], 0.0)  # tt (no inline calls)
+        self.assertEqual(func2_stats[3], 2.0)  # ct (cumulative time)
+
+    def test_collapsed_stack_collector_basic(self):
+        collector = CollapsedStackCollector()
+
+        # Test empty state
+        self.assertEqual(len(collector.call_trees), 0)
+        self.assertEqual(len(collector.function_samples), 0)
+
+        # Test collecting sample data
+        test_frames = [
+            (1, [("file.py", 10, "func1"), ("file.py", 20, "func2")])
+        ]
+        collector.collect(test_frames)
+
+        # Should store call tree (reversed)
+        self.assertEqual(len(collector.call_trees), 1)
+        expected_tree = [("file.py", 20, "func2"), ("file.py", 10, "func1")]
+        self.assertEqual(collector.call_trees[0], expected_tree)
+
+        # Should count function samples
+        self.assertEqual(
+            collector.function_samples[("file.py", 10, "func1")], 1
+        )
+        self.assertEqual(
+            collector.function_samples[("file.py", 20, "func2")], 1
+        )
+
+    def test_collapsed_stack_collector_export(self):
+        collapsed_out = tempfile.NamedTemporaryFile(delete=False)
+        self.addCleanup(close_and_unlink, collapsed_out)
+
+        collector = CollapsedStackCollector()
+
+        test_frames1 = [
+            (1, [("file.py", 10, "func1"), ("file.py", 20, "func2")])
+        ]
+        test_frames2 = [
+            (1, [("file.py", 10, "func1"), ("file.py", 20, "func2")])
+        ]  # Same stack
+        test_frames3 = [(1, [("other.py", 5, "other_func")])]
+
+        collector.collect(test_frames1)
+        collector.collect(test_frames2)
+        collector.collect(test_frames3)
+
+        collector.export(collapsed_out.name)
+
+        # Check file contents
+        with open(collapsed_out.name, "r") as f:
+            content = f.read()
+
+        lines = content.strip().split("\n")
+        self.assertEqual(len(lines), 2)  # Two unique stacks
+
+        # Check collapsed format: file:func:line;file:func:line count
+        stack1_expected = "file.py:func2:20;file.py:func1:10 2"
+        stack2_expected = "other.py:other_func:5 1"
+
+        self.assertIn(stack1_expected, lines)
+        self.assertIn(stack2_expected, lines)
+
+    def test_pstats_collector_export(self):
+        collector = PstatsCollector(
+            sample_interval_usec=1000000
+        )  # 1 second intervals
+
+        test_frames1 = [
+            (
+                1,
+                [
+                    MockFrameInfo("file.py", 10, "func1"),
+                    MockFrameInfo("file.py", 20, "func2"),
+                ],
+            )
+        ]
+        test_frames2 = [
+            (
+                1,
+                [
+                    MockFrameInfo("file.py", 10, "func1"),
+                    MockFrameInfo("file.py", 20, "func2"),
+                ],
+            )
+        ]  # Same stack
+        test_frames3 = [(1, [MockFrameInfo("other.py", 5, "other_func")])]
+
+        collector.collect(test_frames1)
+        collector.collect(test_frames2)
+        collector.collect(test_frames3)
+
+        pstats_out = tempfile.NamedTemporaryFile(
+            suffix=".pstats", delete=False
+        )
+        self.addCleanup(close_and_unlink, pstats_out)
+        collector.export(pstats_out.name)
+
+        # Check file can be loaded with marshal
+        with open(pstats_out.name, "rb") as f:
+            stats_data = marshal.load(f)
+
+        # Should be a dictionary with the sampled marker
+        self.assertIsInstance(stats_data, dict)
+        self.assertIn(("__sampled__",), stats_data)
+        self.assertTrue(stats_data[("__sampled__",)])
+
+        # Should have function data
+        function_entries = [
+            k for k in stats_data.keys() if k != ("__sampled__",)
+        ]
+        self.assertGreater(len(function_entries), 0)
+
+        # Check specific function stats format: (cc, nc, tt, ct, callers)
+        func1_key = ("file.py", 10, "func1")
+        func2_key = ("file.py", 20, "func2")
+        other_key = ("other.py", 5, "other_func")
+
+        self.assertIn(func1_key, stats_data)
+        self.assertIn(func2_key, stats_data)
+        self.assertIn(other_key, stats_data)
+
+        # Check func1 stats (should have 2 samples)
+        func1_stats = stats_data[func1_key]
+        self.assertEqual(func1_stats[0], 2)  # total_calls
+        self.assertEqual(func1_stats[1], 2)  # nc (non-recursive calls)
+        self.assertEqual(func1_stats[2], 2.0)  # tt (total time)
+        self.assertEqual(func1_stats[3], 2.0)  # ct (cumulative time)
+
+
+class TestSampleProfilerIntegration(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.python_exe = sys.executable
+
+        cls.test_script = '''
+import time
+import os
+
+def slow_fibonacci(n):
+    """Recursive fibonacci - should show up prominently in profiler."""
+    if n <= 1:
+        return n
+    return slow_fibonacci(n-1) + slow_fibonacci(n-2)
+
+def cpu_intensive_work():
+    """CPU intensive work that should show in profiler."""
+    result = 0
+    for i in range(10000):
+        result += i * i
+        if i % 100 == 0:
+            result = result % 1000000
+    return result
+
+def medium_computation():
+    """Medium complexity function."""
+    result = 0
+    for i in range(100):
+        result += i * i
+    return result
+
+def fast_loop():
+    """Fast simple loop."""
+    total = 0
+    for i in range(50):
+        total += i
+    return total
+
+def nested_calls():
+    """Test nested function calls."""
+    def level1():
+        def level2():
+            return medium_computation()
+        return level2()
+    return level1()
+
+def main_loop():
+    """Main test loop with different execution paths."""
+    iteration = 0
+
+    while True:
+        iteration += 1
+
+        # Different execution paths - focus on CPU intensive work
+        if iteration % 3 == 0:
+            # Very CPU intensive
+            result = cpu_intensive_work()
+        elif iteration % 5 == 0:
+            # Expensive recursive operation
+            result = slow_fibonacci(12)
+        else:
+            # Medium operation
+            result = nested_calls()
+
+        # No sleep - keep CPU busy
+
+if __name__ == "__main__":
+    main_loop()
+'''
+
+    def test_sampling_basic_functionality(self):
+        with (
+            test_subprocess(self.test_script) as proc,
+            io.StringIO() as captured_output,
+            mock.patch("sys.stdout", captured_output),
+        ):
+            try:
+                profile.sample.sample(
+                    proc.pid,
+                    duration_sec=2,
+                    sample_interval_usec=1000,  # 1ms
+                    show_summary=False,
+                )
+            except PermissionError:
+                self.skipTest("Insufficient permissions for remote profiling")
+
+            output = captured_output.getvalue()
+
+            # Basic checks on output
+            self.assertIn("Captured", output)
+            self.assertIn("samples", output)
+            self.assertIn("Profile Stats", output)
+
+            # Should see some of our test functions
+            self.assertIn("slow_fibonacci", output)
+
+    def test_sampling_with_pstats_export(self):
+        pstats_out = tempfile.NamedTemporaryFile(
+            suffix=".pstats", delete=False
+        )
+        self.addCleanup(close_and_unlink, pstats_out)
+
+        with test_subprocess(self.test_script) as proc:
+            # Suppress profiler output when testing file export
+            with (
+                io.StringIO() as captured_output,
+                mock.patch("sys.stdout", captured_output),
+            ):
+                try:
+                    profile.sample.sample(
+                        proc.pid,
+                        duration_sec=1,
+                        filename=pstats_out.name,
+                        sample_interval_usec=10000,
+                    )
+                except PermissionError:
+                    self.skipTest(
+                        "Insufficient permissions for remote profiling"
+                    )
+
+        # Verify file was created and contains valid data
+        self.assertTrue(os.path.exists(pstats_out.name))
+        self.assertGreater(os.path.getsize(pstats_out.name), 0)
+
+        # Try to load the stats file
+        with open(pstats_out.name, "rb") as f:
+            stats_data = marshal.load(f)
+
+        # Should be a dictionary with the sampled marker
+        self.assertIsInstance(stats_data, dict)
+        self.assertIn(("__sampled__",), stats_data)
+        self.assertTrue(stats_data[("__sampled__",)])
+
+        # Should have some function data
+        function_entries = [
+            k for k in stats_data.keys() if k != ("__sampled__",)
+        ]
+        self.assertGreater(len(function_entries), 0)
+
+    def test_sampling_with_collapsed_export(self):
+        collapsed_file = tempfile.NamedTemporaryFile(
+            suffix=".txt", delete=False
+        )
+        self.addCleanup(close_and_unlink, collapsed_file)
+
+        with (
+            test_subprocess(self.test_script) as proc,
+        ):
+            # Suppress profiler output when testing file export
+            with (
+                io.StringIO() as captured_output,
+                mock.patch("sys.stdout", captured_output),
+            ):
+                try:
+                    profile.sample.sample(
+                        proc.pid,
+                        duration_sec=1,
+                        filename=collapsed_file.name,
+                        output_format="collapsed",
+                        sample_interval_usec=10000,
+                    )
+                except PermissionError:
+                    self.skipTest(
+                        "Insufficient permissions for remote profiling"
+                    )
+
+        # Verify file was created and contains valid data
+        self.assertTrue(os.path.exists(collapsed_file.name))
+        self.assertGreater(os.path.getsize(collapsed_file.name), 0)
+
+        # Check file format
+        with open(collapsed_file.name, "r") as f:
+            content = f.read()
+
+        lines = content.strip().split("\n")
+        self.assertGreater(len(lines), 0)
+
+        # Each line should have format: stack_trace count
+        for line in lines:
+            parts = line.rsplit(" ", 1)
+            self.assertEqual(len(parts), 2)
+
+            stack_trace, count_str = parts
+            self.assertGreater(len(stack_trace), 0)
+            self.assertTrue(count_str.isdigit())
+            self.assertGreater(int(count_str), 0)
+
+            # Stack trace should contain semicolon-separated entries
+            if ";" in stack_trace:
+                stack_parts = stack_trace.split(";")
+                for part in stack_parts:
+                    # Each part should be file:function:line
+                    self.assertIn(":", part)
+
+    def test_sampling_all_threads(self):
+        with (
+            test_subprocess(self.test_script) as proc,
+            # Suppress profiler output
+            io.StringIO() as captured_output,
+            mock.patch("sys.stdout", captured_output),
+        ):
+            try:
+                profile.sample.sample(
+                    proc.pid,
+                    duration_sec=1,
+                    all_threads=True,
+                    sample_interval_usec=10000,
+                    show_summary=False,
+                )
+            except PermissionError:
+                self.skipTest("Insufficient permissions for remote profiling")
+
+        # Just verify that sampling completed without error
+        # We're not testing output format here
+
+
+class TestSampleProfilerErrorHandling(unittest.TestCase):
+    def test_invalid_pid(self):
+        with self.assertRaises((OSError, RuntimeError)):
+            profile.sample.sample(-1, duration_sec=1)
+
+    def test_process_dies_during_sampling(self):
+        with test_subprocess("import time; time.sleep(0.5); exit()") as proc:
+            with (
+                io.StringIO() as captured_output,
+                mock.patch("sys.stdout", captured_output),
+            ):
+                try:
+                    profile.sample.sample(
+                        proc.pid,
+                        duration_sec=2,  # Longer than process lifetime
+                        sample_interval_usec=50000,
+                    )
+                except PermissionError:
+                    self.skipTest(
+                        "Insufficient permissions for remote profiling"
+                    )
+
+                output = captured_output.getvalue()
+
+        self.assertIn("Error rate", output)
+
+    def test_invalid_output_format(self):
+        with self.assertRaises(ValueError):
+            profile.sample.sample(
+                os.getpid(),
+                duration_sec=1,
+                output_format="invalid_format",
+            )
+
+
+class TestSampleProfilerCLI(unittest.TestCase):
+    def test_argument_parsing_basic(self):
+        test_args = ["profile.sample", "12345"]
+
+        with (
+            mock.patch("sys.argv", test_args),
+            mock.patch("profile.sample.sample") as mock_sample,
+        ):
+            profile.sample.main()
+
+            mock_sample.assert_called_once_with(
+                12345,
+                sample_interval_usec=10,
+                duration_sec=10,
+                filename=None,
+                all_threads=False,
+                limit=None,
+                sort=2,
+                show_summary=True,
+                output_format="pstats",
+            )
+
+    def test_sort_options(self):
+        sort_options = [
+            ("--sort-calls", 0),
+            ("--sort-time", 1),
+            ("--sort-cumulative", 2),
+            ("--sort-percall", 3),
+            ("--sort-cumpercall", 4),
+            ("--sort-name", 5),
+        ]
+
+        for option, expected_sort_value in sort_options:
+            test_args = ["profile.sample", option, "12345"]
+
+            with (
+                mock.patch("sys.argv", test_args),
+                mock.patch("profile.sample.sample") as mock_sample,
+            ):
+                profile.sample.main()
+
+                mock_sample.assert_called_once()
+                call_args = mock_sample.call_args[1]
+                self.assertEqual(
+                    call_args["sort"],
+                    expected_sort_value,
+                )
+                mock_sample.reset_mock()
+
+
+if __name__ == "__main__":
+    unittest.main()
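Reviewer note: a small sketch (mirroring the collector tests above) of the collapsed format that `CollapsedStackCollector` writes, one semicolon-joined root-to-leaf stack plus a sample count per line, which is what flamegraph tooling consumes; the output path is a placeholder:

```python
from profile.stack_collectors import CollapsedStackCollector

collector = CollapsedStackCollector()
frames = [(1, [("file.py", 10, "func1"), ("file.py", 20, "func2")])]
collector.collect(frames)   # sampled twice on the same stack...
collector.collect(frames)
collector.collect([(1, [("other.py", 5, "other_func")])])  # ...once here

collector.export("stacks.txt")   # placeholder path
# stacks.txt now contains:
#   file.py:func2:20;file.py:func1:10 2
#   other.py:other_func:5 1
```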