Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 39 additions & 7 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ def _find_all_functions_via_language_support(file_path: Path) -> dict[Path, list
lang_support = get_language_support(file_path)
require_return = lang_support.language != Language.JAVA
criteria = FunctionFilterCriteria(require_return=require_return)
functions[file_path] = lang_support.discover_functions(file_path, criteria)
source = file_path.read_text(encoding="utf-8")
functions[file_path] = lang_support.discover_functions(source, file_path, criteria)
except Exception as e:
logger.debug(f"Failed to discover functions in {file_path}: {e}")

Expand Down Expand Up @@ -226,7 +227,9 @@ def get_functions_to_optimize(
if optimize_all:
logger.info("!lsp|Finding all functions in the module '%s'…", optimize_all)
console.rule()
functions = get_all_files_and_functions(Path(optimize_all), ignore_paths)
functions = get_all_files_and_functions(
Path(optimize_all), ignore_paths, tests_root=test_cfg.tests_root, module_root=module_root
)
elif replay_test:
functions, trace_file_path = get_all_replay_test_functions(
replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
Expand Down Expand Up @@ -317,7 +320,12 @@ def get_functions_to_optimize(
logger.info("Finding all functions modified in the current git diff ...")
console.rule()
ph("cli-optimizing-git-diff")
functions = get_functions_within_git_diff(uncommitted_changes=False)
functions = get_functions_within_git_diff(
uncommitted_changes=False,
tests_root=test_cfg.tests_root,
ignore_paths=ignore_paths,
module_root=module_root,
)
filtered_modified_functions, functions_count = filter_functions(
functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions
)
Expand All @@ -326,9 +334,16 @@ def get_functions_to_optimize(
return filtered_modified_functions, functions_count, trace_file_path


def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[Path, list[FunctionToOptimize]]:
def get_functions_within_git_diff(
uncommitted_changes: bool,
tests_root: Path | None = None,
ignore_paths: list[Path] | None = None,
module_root: Path | None = None,
) -> dict[Path, list[FunctionToOptimize]]:
modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes)
return get_functions_within_lines(modified_lines)
return get_functions_within_lines(
modified_lines, tests_root=tests_root, ignore_paths=ignore_paths, module_root=module_root
)


def closest_matching_file_function_name(
Expand Down Expand Up @@ -406,12 +421,20 @@ def get_functions_inside_a_commit(commit_hash: str) -> dict[Path, list[FunctionT
return get_functions_within_lines(modified_lines)


def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[Path, list[FunctionToOptimize]]:
def get_functions_within_lines(
modified_lines: dict[str, list[int]],
tests_root: Path | None = None,
ignore_paths: list[Path] | None = None,
module_root: Path | None = None,
) -> dict[Path, list[FunctionToOptimize]]:
functions: dict[Path, list[FunctionToOptimize]] = {}
for path_str, lines_in_file in modified_lines.items():
path = Path(path_str)
if not path.exists():
continue
if tests_root is not None and module_root is not None:
if not filter_files_optimized(path, tests_root, ignore_paths or [], module_root):
continue
all_functions = find_all_functions_in_file(path)
functions[path] = [
func
Expand All @@ -424,21 +447,30 @@ def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[Pat


def get_all_files_and_functions(
module_root_path: Path, ignore_paths: list[Path], language: Language | None = None
module_root_path: Path,
ignore_paths: list[Path],
language: Language | None = None,
tests_root: Path | None = None,
module_root: Path | None = None,
) -> dict[Path, list[FunctionToOptimize]]:
"""Get all optimizable functions from files in the module root.

Args:
module_root_path: Root path to search for source files.
ignore_paths: List of paths to ignore.
language: Optional specific language to filter for. If None, includes all supported languages.
tests_root: Test root path for prefiltering files before reading (avoids unnecessary I/O).
module_root: Module root path for prefiltering files before reading.

Returns:
Dictionary mapping file paths to lists of FunctionToOptimize.

"""
functions: dict[Path, list[FunctionToOptimize]] = {}
for file_path in get_files_for_language(module_root_path, ignore_paths, language):
if tests_root is not None and module_root is not None:
if not filter_files_optimized(file_path, tests_root, ignore_paths, module_root):
continue
functions.update(find_all_functions_in_file(file_path).items())
# Randomize the order of the files to optimize to avoid optimizing the same file in the same order every time.
# Helpful if an optimize-all run is stuck and we restart it.
Expand Down
181 changes: 181 additions & 0 deletions tests/test_discovery_prefilter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from __future__ import annotations

from pathlib import Path
from unittest.mock import patch

from codeflash.discovery.functions_to_optimize import (
get_all_files_and_functions,
get_functions_within_lines,
)


def test_prefilter_skips_test_files(tmp_path: Path) -> None:
"""Files in tests_root should be skipped before read_text() is called."""
module_root = tmp_path / "src"
module_root.mkdir()
tests_root = tmp_path / "tests"
tests_root.mkdir()

source_file = module_root / "app.py"
source_file.write_text("def compute():\n return 1\n", encoding="utf-8")

test_file = tests_root / "test_app.py"
test_file.write_text("def test_compute():\n return True\n", encoding="utf-8")

with patch("codeflash.discovery.functions_to_optimize.get_files_for_language") as mock_get_files:
mock_get_files.return_value = [source_file, test_file]
result = get_all_files_and_functions(
module_root, ignore_paths=[], tests_root=tests_root, module_root=module_root
)

assert source_file in result
assert test_file not in result


def test_prefilter_skips_ignored_paths(tmp_path: Path) -> None:
"""Files in ignore_paths should be skipped before read_text() is called."""
module_root = tmp_path / "src"
module_root.mkdir()
tests_root = tmp_path / "tests"
tests_root.mkdir()
ignored_dir = module_root / "vendor"
ignored_dir.mkdir()

source_file = module_root / "app.py"
source_file.write_text("def compute():\n return 1\n", encoding="utf-8")

vendor_file = ignored_dir / "lib.py"
vendor_file.write_text("def helper():\n return 2\n", encoding="utf-8")

with patch("codeflash.discovery.functions_to_optimize.get_files_for_language") as mock_get_files:
mock_get_files.return_value = [source_file, vendor_file]
result = get_all_files_and_functions(
module_root, ignore_paths=[ignored_dir], tests_root=tests_root, module_root=module_root
)

assert source_file in result
assert vendor_file not in result


def test_prefilter_skips_files_outside_module_root(tmp_path: Path) -> None:
"""Files outside module_root should be skipped before read_text() is called."""
module_root = tmp_path / "src"
module_root.mkdir()
tests_root = tmp_path / "tests"
tests_root.mkdir()
other_dir = tmp_path / "other"
other_dir.mkdir()

source_file = module_root / "app.py"
source_file.write_text("def compute():\n return 1\n", encoding="utf-8")

outside_file = other_dir / "stray.py"
outside_file.write_text("def stray():\n return 3\n", encoding="utf-8")

with patch("codeflash.discovery.functions_to_optimize.get_files_for_language") as mock_get_files:
mock_get_files.return_value = [source_file, outside_file]
result = get_all_files_and_functions(
module_root, ignore_paths=[], tests_root=tests_root, module_root=module_root
)

assert source_file in result
assert outside_file not in result


def test_prefilter_disabled_without_params(tmp_path: Path) -> None:
"""Without tests_root/module_root, no prefiltering occurs (backward compat)."""
module_root = tmp_path / "src"
module_root.mkdir()

source_file = module_root / "app.py"
source_file.write_text("def compute():\n return 1\n", encoding="utf-8")

with patch("codeflash.discovery.functions_to_optimize.get_files_for_language") as mock_get_files:
mock_get_files.return_value = [source_file]
result = get_all_files_and_functions(module_root, ignore_paths=[])

assert source_file in result


def test_prefilter_in_get_functions_within_lines(tmp_path: Path) -> None:
"""get_functions_within_lines should skip test files when prefilter params are provided."""
module_root = tmp_path / "src"
module_root.mkdir()
tests_root = tmp_path / "tests"
tests_root.mkdir()

source_file = module_root / "app.py"
source_file.write_text("def compute():\n return 1\n", encoding="utf-8")

test_file = tests_root / "test_app.py"
test_file.write_text("def test_compute():\n return True\n", encoding="utf-8")

modified_lines = {
str(source_file): [1, 2],
str(test_file): [1, 2],
}

result = get_functions_within_lines(
modified_lines, tests_root=tests_root, ignore_paths=[], module_root=module_root
)

assert source_file in result
assert test_file not in result


def test_prefilter_avoids_reading_skipped_files(tmp_path: Path) -> None:
"""Verify that find_all_functions_in_file is NOT called for prefiltered files (the core perf win)."""
module_root = tmp_path / "src"
module_root.mkdir()
tests_root = tmp_path / "tests"
tests_root.mkdir()

source_file = module_root / "app.py"
source_file.write_text("def compute():\n return 1\n", encoding="utf-8")

test_file = tests_root / "test_app.py"
test_file.write_text("def test_compute():\n return True\n", encoding="utf-8")

with (
patch("codeflash.discovery.functions_to_optimize.get_files_for_language") as mock_get_files,
patch("codeflash.discovery.functions_to_optimize.find_all_functions_in_file") as mock_find,
):
mock_get_files.return_value = [source_file, test_file]
mock_find.return_value = {}
get_all_files_and_functions(
module_root, ignore_paths=[], tests_root=tests_root, module_root=module_root
)

# find_all_functions_in_file (which does read_text) should only be called for source_file
assert mock_find.call_count == 1
mock_find.assert_called_once_with(source_file)


def test_prefilter_skips_submodule_paths(tmp_path: Path) -> None:
"""Submodule paths should be skipped by prefilter."""
module_root = tmp_path / "src"
module_root.mkdir()
tests_root = tmp_path / "tests"
tests_root.mkdir()
submodule_dir = module_root / "vendor_submodule"
submodule_dir.mkdir()

source_file = module_root / "app.py"
source_file.write_text("def compute():\n return 1\n", encoding="utf-8")

submodule_file = submodule_dir / "lib.py"
submodule_file.write_text("def helper():\n return 2\n", encoding="utf-8")

with (
patch("codeflash.discovery.functions_to_optimize.get_files_for_language") as mock_get_files,
patch(
"codeflash.discovery.functions_to_optimize.ignored_submodule_paths", return_value=[submodule_dir]
),
):
mock_get_files.return_value = [source_file, submodule_file]
result = get_all_files_and_functions(
module_root, ignore_paths=[], tests_root=tests_root, module_root=module_root
)

assert source_file in result
assert submodule_file not in result
Loading