Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 26 additions & 20 deletions codeflash/tracing/profile_stats.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,44 @@
import json
from __future__ import annotations

import pstats
import sqlite3
from copy import copy
from pathlib import Path
from typing import Any, TextIO

from codeflash.cli_cmds.console import logger


class ProfileStats(pstats.Stats):
# Attributes set by pstats.Stats.init() — stubs don't expose them
files: list[str]
stream: TextIO
top_level: set[tuple[str, int, str]]
total_calls: int
prim_calls: int
total_tt: float
max_name_len: int
fcn_list: list[tuple[str, int, str]] | None
sort_arg_dict: dict[str, tuple[Any, ...]]
all_callees: dict[tuple[str, int, str], dict[tuple[str, int, str], tuple[int, int, float, float]]] | None
stats: dict[tuple[str, int, str], tuple[int, int, int | float, int | float, dict[Any, Any]]]

def __init__(self, trace_file_path: str, time_unit: str = "ns") -> None:
assert Path(trace_file_path).is_file(), f"Trace file {trace_file_path} does not exist"
assert time_unit in {"ns", "us", "ms", "s"}, f"Invalid time unit {time_unit}"
self.trace_file_path = trace_file_path
self.time_unit = time_unit
logger.debug(hasattr(self, "create_stats"))
super().__init__(copy(self))
super().__init__(copy(self)) # type: ignore[arg-type] # pstats uses duck-typed create_stats interface

def create_stats(self) -> None:
self.con = sqlite3.connect(self.trace_file_path)
cur = self.con.cursor()
pdata = cur.execute("SELECT * FROM pstats").fetchall()
pdata = cur.execute(
"SELECT filename, line_number, function, class_name,"
" call_count_nonrecursive, num_callers, total_time_ns, cumulative_time_ns"
" FROM pstats"
).fetchall()
self.con.close()
time_conversion_factor = {"ns": 1, "us": 1e3, "ms": 1e6, "s": 1e9}[self.time_unit]
self.stats = {}
Expand All @@ -32,31 +51,18 @@ def create_stats(self) -> None:
num_callers,
total_time_ns,
cumulative_time_ns,
callers,
) in pdata:
loaded_callers = json.loads(callers)
unmapped_callers = {}
for caller in loaded_callers:
caller_key = caller["key"]
if isinstance(caller_key, list):
caller_key = tuple(caller_key)
elif not isinstance(caller_key, tuple):
caller_key = (caller_key,) if not isinstance(caller_key, (list, tuple)) else tuple(caller_key)
unmapped_callers[caller_key] = caller["value"]

# Create function key with class name if present (matching tracer.py format)
function_name = f"{class_name}.{function}" if class_name else function

self.stats[(filename, line_number, function_name)] = (
call_count_nonrecursive,
num_callers,
total_time_ns / time_conversion_factor if time_conversion_factor != 1 else total_time_ns,
cumulative_time_ns / time_conversion_factor if time_conversion_factor != 1 else cumulative_time_ns,
unmapped_callers,
{},
)

def print_stats(self, *amount) -> pstats.Stats: # noqa: ANN002
# Copied from pstats.Stats.print_stats and modified to print the correct time unit
def print_stats(self, *amount: str | float) -> ProfileStats:
for filename in self.files:
print(filename, file=self.stream)
if self.files:
Expand All @@ -74,8 +80,8 @@ def print_stats(self, *amount) -> pstats.Stats: # noqa: ANN002
_width, list_ = self.get_print_list(amount)
if list_:
self.print_title()
for func in list_:
self.print_line(func)
for fn in list_:
self.print_line(fn)
print(file=self.stream)
print(file=self.stream)
return self
Expand Down
Loading