Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,10 +421,13 @@ def get_run_tmp_file(file_path: Path | str) -> Path:
return _run_tmpdir_path / file_path


@lru_cache(maxsize=1)
def _get_site_packages_paths() -> tuple[Path, ...]:
return tuple(Path(p).resolve() for p in site.getsitepackages())


def path_belongs_to_site_packages(file_path: Path) -> bool:
    """Return True if *file_path* resolves to a location inside any site-packages directory.

    The path is resolved once (normalizing symlinks and relative segments)
    before the containment checks, instead of re-resolving per directory.
    """
    resolved = file_path.resolve()
    return any(resolved.is_relative_to(site_dir) for site_dir in _get_site_packages_paths())


def is_class_defined_in_file(class_name: str, file_path: Path) -> bool:
Expand Down
186 changes: 122 additions & 64 deletions codeflash/discovery/functions_to_optimize.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,12 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
def _extract_names_from_annotation(self, node: cst.CSTNode) -> None:
if isinstance(node, cst.Name):
name = node.value
if name in self.definitions and name != self.current_top_level_name and self.current_top_level_name:
if (
name in self.definitions
and name != self.current_top_level_name
and self.current_top_level_name
and self.current_top_level_name in self.definitions
):
self.definitions[self.current_top_level_name].dependencies.add(name)
elif isinstance(node, cst.Subscript):
self._extract_names_from_annotation(node.value)
Expand Down
7 changes: 3 additions & 4 deletions codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,10 +553,9 @@ def handle_duplicate_candidate(
self.optimized_runtimes[candidate.optimization_id] = self.optimized_runtimes.get(past_opt_id)

# Line profiler results only available for successful runs
if past_opt_id in self.optimized_line_profiler_results:
self.optimized_line_profiler_results[candidate.optimization_id] = self.optimized_line_profiler_results[
past_opt_id
]
line_profiler_result = self.optimized_line_profiler_results.get(past_opt_id)
if line_profiler_result is not None:
self.optimized_line_profiler_results[candidate.optimization_id] = line_profiler_result

self.optimizations_post[candidate.optimization_id] = self.ast_code_to_id[normalized_code][
"shorter_source_code"
Expand Down
4 changes: 2 additions & 2 deletions codeflash/optimization/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,8 +826,8 @@ def worktree_mode(self) -> Result[bool, str]:
return Success(True) # noqa: FBT003

def mirror_paths_for_worktree_mode(self, worktree_dir: Path) -> None:
original_args = copy.deepcopy(self.args)
original_test_cfg = copy.deepcopy(self.test_cfg)
original_args = copy.copy(self.args)
original_test_cfg = copy.copy(self.test_cfg)
self.original_args_and_test_cfg = (original_args, original_test_cfg)

original_git_root = git_root_dir()
Expand Down
46 changes: 26 additions & 20 deletions codeflash/tracing/profile_stats.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,44 @@
import json
from __future__ import annotations

import pstats
import sqlite3
from copy import copy
from pathlib import Path
from typing import Any, TextIO

from codeflash.cli_cmds.console import logger


class ProfileStats(pstats.Stats):
# Attributes set by pstats.Stats.init() — stubs don't expose them
files: list[str]
stream: TextIO
top_level: set[tuple[str, int, str]]
total_calls: int
prim_calls: int
total_tt: float
max_name_len: int
fcn_list: list[tuple[str, int, str]] | None
sort_arg_dict: dict[str, tuple[Any, ...]]
all_callees: dict[tuple[str, int, str], dict[tuple[str, int, str], tuple[int, int, float, float]]] | None
stats: dict[tuple[str, int, str], tuple[int, int, int | float, int | float, dict[Any, Any]]]

def __init__(self, trace_file_path: str, time_unit: str = "ns") -> None:
    """Load profiling stats from a codeflash trace SQLite file.

    Args:
        trace_file_path: Path to an existing trace database file.
        time_unit: Unit used when reporting times; one of "ns", "us", "ms", "s".

    Raises:
        AssertionError: If the trace file does not exist or the unit is invalid.
    """
    assert Path(trace_file_path).is_file(), f"Trace file {trace_file_path} does not exist"
    assert time_unit in {"ns", "us", "ms", "s"}, f"Invalid time unit {time_unit}"
    self.trace_file_path = trace_file_path
    self.time_unit = time_unit
    logger.debug(hasattr(self, "create_stats"))
    # pstats.Stats duck-types its argument through create_stats(); the call must
    # happen exactly once (the diff residue had it duplicated).
    super().__init__(copy(self))  # type: ignore[arg-type] # pstats uses duck-typed create_stats interface

def create_stats(self) -> None:
self.con = sqlite3.connect(self.trace_file_path)
cur = self.con.cursor()
pdata = cur.execute("SELECT * FROM pstats").fetchall()
pdata = cur.execute(
"SELECT filename, line_number, function, class_name,"
" call_count_nonrecursive, num_callers, total_time_ns, cumulative_time_ns"
" FROM pstats"
).fetchall()
self.con.close()
time_conversion_factor = {"ns": 1, "us": 1e3, "ms": 1e6, "s": 1e9}[self.time_unit]
self.stats = {}
Expand All @@ -32,31 +51,18 @@ def create_stats(self) -> None:
num_callers,
total_time_ns,
cumulative_time_ns,
callers,
) in pdata:
loaded_callers = json.loads(callers)
unmapped_callers = {}
for caller in loaded_callers:
caller_key = caller["key"]
if isinstance(caller_key, list):
caller_key = tuple(caller_key)
elif not isinstance(caller_key, tuple):
caller_key = (caller_key,) if not isinstance(caller_key, (list, tuple)) else tuple(caller_key)
unmapped_callers[caller_key] = caller["value"]

# Create function key with class name if present (matching tracer.py format)
function_name = f"{class_name}.{function}" if class_name else function

self.stats[(filename, line_number, function_name)] = (
call_count_nonrecursive,
num_callers,
total_time_ns / time_conversion_factor if time_conversion_factor != 1 else total_time_ns,
cumulative_time_ns / time_conversion_factor if time_conversion_factor != 1 else cumulative_time_ns,
unmapped_callers,
{},
)

def print_stats(self, *amount) -> pstats.Stats: # noqa: ANN002
# Copied from pstats.Stats.print_stats and modified to print the correct time unit
def print_stats(self, *amount: str | float) -> ProfileStats:
for filename in self.files:
print(filename, file=self.stream)
if self.files:
Expand All @@ -74,8 +80,8 @@ def print_stats(self, *amount) -> pstats.Stats: # noqa: ANN002
_width, list_ = self.get_print_list(amount)
if list_:
self.print_title()
for func in list_:
self.print_line(func)
for fn in list_:
self.print_line(fn)
print(file=self.stream)
print(file=self.stream)
return self
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ module = ["jedi", "jedi.api.classes", "inquirer", "inquirer.themes", "numba", "d
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = ["codeflash.code_utils.codeflash_wrap_decorator", "codeflash.code_utils.instrument_existing_tests", "codeflash.optimization.optimizer"]
module = ["codeflash.code_utils.codeflash_wrap_decorator", "codeflash.code_utils.instrument_existing_tests", "codeflash.optimization.optimizer", "codeflash.languages.python.context.unused_definition_remover"]
disable_error_code = ["attr-defined", "return-value", "no-untyped-call", "no-untyped-def", "arg-type", "assignment", "var-annotated", "no-any-return", "call-overload", "union-attr", "unreachable", "list-item"]

[tool.pydantic-mypy]
Expand Down
17 changes: 11 additions & 6 deletions tests/test_code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

from codeflash.code_utils.code_utils import (
_get_site_packages_paths,
cleanup_paths,
exit_with_message,
file_name_from_test_module_name,
Expand Down Expand Up @@ -263,6 +264,7 @@ def test_get_run_tmp_file_reuses_temp_directory() -> None:
def test_path_belongs_to_site_packages_with_site_package_path(monkeypatch: pytest.MonkeyPatch) -> None:
    """A file under a patched site-packages directory is recognized."""
    fake_site_packages = [Path("/usr/local/lib/python3.9/site-packages").resolve()]
    monkeypatch.setattr(site, "getsitepackages", lambda: fake_site_packages)
    _get_site_packages_paths.cache_clear()

    candidate = Path("/usr/local/lib/python3.9/site-packages/some_package")
    assert path_belongs_to_site_packages(candidate) is True
Expand All @@ -271,6 +273,7 @@ def test_path_belongs_to_site_packages_with_site_package_path(monkeypatch: pytes
def test_path_belongs_to_site_packages_with_non_site_package_path(monkeypatch: pytest.MonkeyPatch) -> None:
    """A file outside every site-packages directory is rejected."""
    fake_site_packages = [Path("/usr/local/lib/python3.9/site-packages")]
    monkeypatch.setattr(site, "getsitepackages", lambda: fake_site_packages)
    _get_site_packages_paths.cache_clear()

    candidate = Path("/usr/local/lib/python3.9/other_directory/some_package")
    assert path_belongs_to_site_packages(candidate) is False
Expand All @@ -279,6 +282,7 @@ def test_path_belongs_to_site_packages_with_non_site_package_path(monkeypatch: p
def test_path_belongs_to_site_packages_with_relative_path(monkeypatch: pytest.MonkeyPatch) -> None:
    """A bare relative path (resolved against CWD) does not match the patched site dir."""
    fake_site_packages = [Path("/usr/local/lib/python3.9/site-packages")]
    monkeypatch.setattr(site, "getsitepackages", lambda: fake_site_packages)
    _get_site_packages_paths.cache_clear()

    candidate = Path("some_package")
    assert path_belongs_to_site_packages(candidate) is False
Expand All @@ -298,6 +302,7 @@ def test_path_belongs_to_site_packages_with_symlinked_site_packages(
package_file.write_text("# package file")

monkeypatch.setattr(site, "getsitepackages", lambda: [str(symlinked_site_packages)])
_get_site_packages_paths.cache_clear()

assert path_belongs_to_site_packages(package_file) is True

Expand All @@ -321,6 +326,7 @@ def test_path_belongs_to_site_packages_with_complex_symlinks(monkeypatch: pytest

site_packages_via_links = link2 / "lib" / "python3.9" / "site-packages"
monkeypatch.setattr(site, "getsitepackages", lambda: [str(site_packages_via_links)])
_get_site_packages_paths.cache_clear()

assert path_belongs_to_site_packages(package_file) is True

Expand All @@ -341,6 +347,7 @@ def test_path_belongs_to_site_packages_resolved_paths_normalization(

complex_site_packages_path = tmp_path / "lib" / "python3.9" / "other" / ".." / "site-packages" / "."
monkeypatch.setattr(site, "getsitepackages", lambda: [str(complex_site_packages_path)])
_get_site_packages_paths.cache_clear()

assert path_belongs_to_site_packages(package_file) is True

Expand Down Expand Up @@ -380,18 +387,16 @@ def my_function():


@pytest.fixture
def mock_code_context() -> MagicMock:
    """Mock CodeOptimizationContext for testing extract_dependent_function."""
    from unittest.mock import MagicMock

    from codeflash.models.models import CodeOptimizationContext

    # spec= keeps the mock's attribute surface honest against the real class.
    context = MagicMock(spec=CodeOptimizationContext)
    context.preexisting_objects = []
    return context


def test_extract_dependent_function_sync_and_async(mock_code_context):
def test_extract_dependent_function_sync_and_async(mock_code_context: MagicMock) -> None:
"""Test extract_dependent_function with both sync and async functions."""
# Test sync function extraction
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
Expand All @@ -417,7 +422,7 @@ async def async_helper_function():
assert extract_dependent_function("main_function", mock_code_context) == "async_helper_function"


def test_extract_dependent_function_edge_cases(mock_code_context):
def test_extract_dependent_function_edge_cases(mock_code_context: MagicMock) -> None:
"""Test extract_dependent_function edge cases."""
# No dependent functions
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
Expand All @@ -441,7 +446,7 @@ async def helper2():
assert extract_dependent_function("main_function", mock_code_context) is False


def test_extract_dependent_function_mixed_scenarios(mock_code_context):
def test_extract_dependent_function_mixed_scenarios(mock_code_context: MagicMock) -> None:
"""Test extract_dependent_function with mixed sync/async scenarios."""
# Async main with sync helper
mock_code_context.testgen_context = CodeStringsMarkdown.parse_markdown_code("""```python:file.py
Expand Down
141 changes: 141 additions & 0 deletions tests/test_discovery_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
from __future__ import annotations

import ast
import tempfile
from pathlib import Path
from unittest.mock import patch

from codeflash.discovery.functions_to_optimize import (
discovery_cache,
find_all_functions_in_file,
inspect_top_level_functions_or_methods,
parse_ast_cached,
read_file_cached,
)


def test_read_file_cached_without_context_manager(tmp_path: Path) -> None:
    """read_file_cached works as a plain reader outside any cache context."""
    source_file = tmp_path / "sample.py"
    source_file.write_text("x = 1\n", encoding="utf-8")
    assert read_file_cached(source_file) == "x = 1\n"


def test_read_file_cached_returns_same_object_within_context(tmp_path: Path) -> None:
    """Repeated reads inside one discovery_cache context return the cached object."""
    target = tmp_path / "sample.py"
    target.write_text("x = 1\n", encoding="utf-8")
    with discovery_cache():
        first = read_file_cached(target)
        second = read_file_cached(target)
        assert first is second


def test_read_file_cached_does_not_persist_across_contexts(tmp_path: Path) -> None:
    """Each discovery_cache context starts fresh; on-disk changes are picked up."""
    target = tmp_path / "sample.py"
    target.write_text("x = 1\n", encoding="utf-8")
    with discovery_cache():
        before = read_file_cached(target)
    target.write_text("x = 2\n", encoding="utf-8")
    with discovery_cache():
        after = read_file_cached(target)
    assert before != after


def test_parse_ast_cached_returns_same_object_within_context(tmp_path: Path) -> None:
    """Parsing the same file twice in one context yields the identical AST object."""
    target = tmp_path / "sample.py"
    target.write_text("def foo():\n return 1\n", encoding="utf-8")
    with discovery_cache():
        first_tree = parse_ast_cached(target)
        second_tree = parse_ast_cached(target)
        assert first_tree is second_tree
        assert isinstance(first_tree, ast.Module)


def test_parse_ast_cached_uses_provided_source(tmp_path: Path) -> None:
    """An explicit source= string takes precedence over the file's on-disk contents."""
    target = tmp_path / "sample.py"
    target.write_text("x = 1\n", encoding="utf-8")
    override = "y = 2\n"
    with discovery_cache():
        tree = parse_ast_cached(target, source=override)
        assigned_names = [
            node.targets[0].id
            for node in ast.walk(tree)
            if isinstance(node, ast.Assign) and isinstance(node.targets[0], ast.Name)
        ]
        assert "y" in assigned_names


def test_discovery_cache_avoids_redundant_reads(tmp_path: Path) -> None:
    """Three cached reads of the same file trigger only one underlying read_text."""
    target = tmp_path / "module.py"
    target.write_text("def bar():\n return 42\n", encoding="utf-8")
    with discovery_cache(), patch.object(Path, "read_text", wraps=target.read_text) as spy:
        for _ in range(3):
            read_file_cached(target)
        assert spy.call_count == 1


def test_find_all_functions_in_file_uses_cache(tmp_path: Path) -> None:
    """Discovery inside a cache context still finds the expected function."""
    target = tmp_path / "module.py"
    target.write_text("def compute(x):\n return x * 2\n", encoding="utf-8")
    with discovery_cache():
        discovered = find_all_functions_in_file(target)
        assert target in discovered
        assert discovered[target][0].function_name == "compute"


def test_inspect_top_level_functions_uses_cache(tmp_path: Path) -> None:
    """Function inspection inside a cache context reports the expected properties."""
    target = tmp_path / "module.py"
    target.write_text("def top_func(a, b):\n return a + b\n", encoding="utf-8")
    with discovery_cache():
        properties = inspect_top_level_functions_or_methods(target, "top_func")
        assert properties is not None
        assert properties.is_top_level
        assert properties.has_args


def test_find_and_inspect_share_cached_content(tmp_path: Path) -> None:
    """find_all_functions_in_file and inspect_* share a single cached file read."""
    target = tmp_path / "module.py"
    target.write_text(
        "class MyClass:\n def method(self):\n return 1\n\ndef standalone():\n return 2\n",
        encoding="utf-8",
    )
    with discovery_cache(), patch.object(Path, "read_text", wraps=target.read_text) as spy:
        find_all_functions_in_file(target)
        properties = inspect_top_level_functions_or_methods(target, "method", class_name="MyClass")
        assert spy.call_count == 1
        assert properties is not None
        assert properties.is_top_level


def test_discovery_results_correct_with_multiple_files(tmp_path: Path) -> None:
    """Caching does not cross-contaminate discovery results between files."""
    first_file = tmp_path / "a.py"
    first_file.write_text("def alpha():\n return 'a'\n", encoding="utf-8")
    second_file = tmp_path / "b.py"
    second_file.write_text("def beta(x):\n return x + 1\n", encoding="utf-8")

    with discovery_cache():
        first_result = find_all_functions_in_file(first_file)
        second_result = find_all_functions_in_file(second_file)

    assert first_result[first_file][0].function_name == "alpha"
    assert second_result[second_file][0].function_name == "beta"


def test_cache_handles_invalid_syntax_gracefully(tmp_path: Path) -> None:
    """A file that fails to parse yields an empty discovery result, not an exception."""
    broken_file = tmp_path / "broken.py"
    broken_file.write_text("def incomplete(:\n", encoding="utf-8")
    with discovery_cache():
        assert find_all_functions_in_file(broken_file) == {}


def test_cache_handles_nonexistent_file_in_parse_ast(tmp_path: Path) -> None:
    """parse_ast_cached must raise OSError for a missing file.

    Uses try/except/else with an explicit AssertionError instead of
    ``assert False`` (which is stripped when Python runs with -O), and
    catches only OSError since FileNotFoundError is already a subclass.
    """
    missing = tmp_path / "nonexistent.py"
    with discovery_cache():
        try:
            parse_ast_cached(missing)
        except OSError:
            pass
        else:
            raise AssertionError("parse_ast_cached should raise for a missing file")
Loading
Loading