Skip to content
186 changes: 122 additions & 64 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

if TYPE_CHECKING:
from argparse import Namespace
from collections.abc import Generator

from codeflash.models.models import CodeOptimizationContext
from codeflash.verification.verification_utils import TestConfig
Expand All @@ -51,6 +52,46 @@ class FunctionProperties:
staticmethod_class_name: Optional[str]


# =============================================================================
# Discovery-scoped file/AST cache
# =============================================================================

_active_discovery_cache: dict[Path, str] | None = None
_active_ast_cache: dict[Path, ast.Module] | None = None


@contextlib.contextmanager
def discovery_cache() -> Generator[None, None, None]:
global _active_discovery_cache, _active_ast_cache
_active_discovery_cache = {}
_active_ast_cache = {}
try:
yield
finally:
_active_discovery_cache = None
_active_ast_cache = None


def read_file_cached(file_path: Path) -> str:
    """Return the UTF-8 text of ``file_path``, memoized while a cache is active.

    When ``_active_discovery_cache`` is ``None`` the file is read from disk on
    every call. Otherwise the first read for a given ``Path`` key is stored in
    the cache and later calls with an equal key return the stored text.

    Args:
        file_path: File to read; also used verbatim as the cache key.

    Returns:
        The file content decoded as UTF-8.
    """
    cache = _active_discovery_cache
    if cache is None:
        # No active scope: plain uncached read.
        return file_path.read_text(encoding="utf-8")

    # Stored values are always ``str``, so ``None`` reliably means "missing".
    text = cache.get(file_path)
    if text is None:
        text = file_path.read_text(encoding="utf-8")
        cache[file_path] = text
    return text


def parse_ast_cached(file_path: Path, source: str | None = None) -> ast.Module:
    """Return the parsed AST for ``file_path``, memoized while a cache is active.

    Args:
        file_path: File the AST belongs to; used verbatim as the cache key.
        source: Optional pre-read text to parse. When ``None`` the text is
            obtained through ``read_file_cached``.

    Returns:
        The ``ast.Module`` for the file. If a tree is already cached for
        ``file_path`` it is returned as-is and ``source`` is not consulted.
    """
    trees = _active_ast_cache
    if trees is not None and file_path in trees:
        return trees[file_path]

    text = read_file_cached(file_path) if source is None else source
    module = ast.parse(text)
    if trees is not None:
        # Only remember the tree when a discovery scope is active.
        trees[file_path] = module
    return module


# =============================================================================
# Multi-language support helpers
# =============================================================================
Expand Down Expand Up @@ -135,7 +176,9 @@ def get_files_for_language(
return files


def _is_js_ts_function_exported(file_path: Path, function_name: str) -> tuple[bool, str | None]:
def _is_js_ts_function_exported(
file_path: Path, function_name: str, source: str | None = None
) -> tuple[bool, str | None]:
"""Check if a JavaScript/TypeScript function is exported from its module.

For JS/TS, functions that are not exported cannot be imported by tests,
Expand All @@ -144,6 +187,7 @@ def _is_js_ts_function_exported(file_path: Path, function_name: str) -> tuple[bo
Args:
file_path: Path to the source file.
function_name: Name of the function to check.
source: Pre-read file content. If None, reads from disk.

Returns:
Tuple of (is_exported, export_name). export_name may be 'default' for default exports.
Expand All @@ -152,16 +196,16 @@ def _is_js_ts_function_exported(file_path: Path, function_name: str) -> tuple[bo
from codeflash.languages.javascript.treesitter import get_analyzer_for_file

try:
source = file_path.read_text(encoding="utf-8")
if source is None:
source = read_file_cached(file_path)
analyzer = get_analyzer_for_file(file_path)
return analyzer.is_function_exported(source, function_name)
except Exception as e:
logger.debug(f"Failed to check export status for {function_name}: {e}")
# Return True to avoid blocking in case of errors
return True, None


def _is_js_ts_function_exists_but_not_exported(file_path: Path, function_name: str) -> bool:
def _is_js_ts_function_exists_but_not_exported(file_path: Path, function_name: str, source: str | None = None) -> bool:
"""Check if a JS/TS function exists in the file but is not exported.

Returns True only if the function name is found as a defined function
Expand All @@ -170,7 +214,8 @@ def _is_js_ts_function_exists_but_not_exported(file_path: Path, function_name: s
from codeflash.languages.javascript.treesitter import get_analyzer_for_file

try:
source = file_path.read_text(encoding="utf-8")
if source is None:
source = read_file_cached(file_path)
analyzer = get_analyzer_for_file(file_path)
all_funcs = analyzer.find_functions(
source, include_methods=True, include_arrow_functions=True, require_name=True
Expand All @@ -183,27 +228,6 @@ def _is_js_ts_function_exists_but_not_exported(file_path: Path, function_name: s
return False


def _find_all_functions_via_language_support(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
    """Discover optimizable functions in ``file_path`` through its language support.

    The language support object returned by ``get_language_support`` is asked
    to discover functions; a return statement is required for every language
    except Java.

    Returns:
        ``{file_path: discovered_functions}`` on success, or an empty dict when
        any step raises (the failure is logged at debug level).
    """
    from codeflash.languages.base import FunctionFilterCriteria

    try:
        support = get_language_support(file_path)
        needs_return = support.language != Language.JAVA
        filter_criteria = FunctionFilterCriteria(require_return=needs_return)
        discovered = support.discover_functions(file_path, filter_criteria)
    except Exception as e:
        logger.debug(f"Failed to discover functions in {file_path}: {e}")
        return {}

    return {file_path: discovered}


def get_functions_to_optimize(
optimize_all: str | None,
replay_test: list[Path] | None,
Expand All @@ -221,12 +245,14 @@ def get_functions_to_optimize(
functions: dict[Path, list[FunctionToOptimize]]
trace_file_path: Path | None = None
is_lsp = is_LSP_enabled()
with warnings.catch_warnings():
with discovery_cache(), warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=SyntaxWarning)
if optimize_all:
logger.info("!lsp|Finding all functions in the module '%s'…", optimize_all)
console.rule()
functions = get_all_files_and_functions(Path(optimize_all), ignore_paths)
functions = get_all_files_and_functions(
Path(optimize_all), ignore_paths, tests_root=test_cfg.tests_root, module_root=module_root
)
elif replay_test:
functions, trace_file_path = get_all_replay_test_functions(
replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
Expand All @@ -236,6 +262,13 @@ def get_functions_to_optimize(
console.rule()
file = Path(file) if isinstance(file, str) else file
functions = find_all_functions_in_file(file)
# Source already cached by find_all_functions_in_file above
_js_ts_source: str | None = None
if only_get_this_function is not None and is_language_supported(file):
_lang = get_language_support(file)
if _lang.language in (Language.JAVASCRIPT, Language.TYPESCRIPT):
with contextlib.suppress(Exception):
_js_ts_source = read_file_cached(file)
if only_get_this_function is not None:
split_function = only_get_this_function.split(".")
if len(split_function) > 2:
Expand All @@ -260,15 +293,13 @@ def get_functions_to_optimize(
return functions, 0, None

# For JS/TS: check if the function exists but is not exported
if is_language_supported(file):
lang_support = get_language_support(file)
if lang_support.language in (Language.JAVASCRIPT, Language.TYPESCRIPT):
if _is_js_ts_function_exists_but_not_exported(file, only_function_name):
exit_with_message(
f"Function '{only_function_name}' exists in {file} but is not exported.\n"
f"In JavaScript/TypeScript, only exported functions can be optimized.\n"
f"Add: export {{ {only_function_name} }}"
)
if _js_ts_source is not None:
if _is_js_ts_function_exists_but_not_exported(file, only_function_name, source=_js_ts_source):
exit_with_message(
f"Function '{only_function_name}' exists in {file} but is not exported.\n"
f"In JavaScript/TypeScript, only exported functions can be optimized.\n"
f"Add: export {{ {only_function_name} }}"
)

found = closest_matching_file_function_name(only_get_this_function, functions)
if found is not None:
Expand All @@ -295,7 +326,7 @@ def get_functions_to_optimize(
# It's a standalone function - check if the function is exported
name_to_check = found_function.function_name

is_exported, _ = _is_js_ts_function_exported(file, name_to_check)
is_exported, _ = _is_js_ts_function_exported(file, name_to_check, source=_js_ts_source)
if not is_exported:
if found_function.parents:
logger.debug(
Expand All @@ -317,7 +348,12 @@ def get_functions_to_optimize(
logger.info("Finding all functions modified in the current git diff ...")
console.rule()
ph("cli-optimizing-git-diff")
functions = get_functions_within_git_diff(uncommitted_changes=False)
functions = get_functions_within_git_diff(
uncommitted_changes=False,
tests_root=test_cfg.tests_root,
ignore_paths=ignore_paths,
module_root=module_root,
)
filtered_modified_functions, functions_count = filter_functions(
functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions
)
Expand All @@ -326,9 +362,16 @@ def get_functions_to_optimize(
return filtered_modified_functions, functions_count, trace_file_path


def get_functions_within_git_diff(
    uncommitted_changes: bool,
    tests_root: Path | None = None,
    ignore_paths: list[Path] | None = None,
    module_root: Path | None = None,
) -> dict[Path, list[FunctionToOptimize]]:
    """Map the lines reported by ``get_git_diff`` to the functions that contain them.

    Args:
        uncommitted_changes: Forwarded unchanged to ``get_git_diff``.
        tests_root: Forwarded to ``get_functions_within_lines``.
        ignore_paths: Forwarded to ``get_functions_within_lines``.
        module_root: Forwarded to ``get_functions_within_lines``.

    Returns:
        The per-file mapping produced by ``get_functions_within_lines``.

    The three optional arguments default to ``None`` so existing callers that
    pass only ``uncommitted_changes`` keep working.
    """
    modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes)
    return get_functions_within_lines(
        modified_lines, tests_root=tests_root, ignore_paths=ignore_paths, module_root=module_root
    )


def closest_matching_file_function_name(
Expand Down Expand Up @@ -406,12 +449,20 @@ def get_functions_inside_a_commit(commit_hash: str) -> dict[Path, list[FunctionT
return get_functions_within_lines(modified_lines)


def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[Path, list[FunctionToOptimize]]:
def get_functions_within_lines(
modified_lines: dict[str, list[int]],
tests_root: Path | None = None,
ignore_paths: list[Path] | None = None,
module_root: Path | None = None,
) -> dict[Path, list[FunctionToOptimize]]:
functions: dict[Path, list[FunctionToOptimize]] = {}
for path_str, lines_in_file in modified_lines.items():
path = Path(path_str)
if not path.exists():
continue
if tests_root is not None and module_root is not None:
if not filter_files_optimized(path, tests_root, ignore_paths or [], module_root):
continue
all_functions = find_all_functions_in_file(path)
functions[path] = [
func
Expand All @@ -424,21 +475,30 @@ def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[Pat


def get_all_files_and_functions(
module_root_path: Path, ignore_paths: list[Path], language: Language | None = None
module_root_path: Path,
ignore_paths: list[Path],
language: Language | None = None,
tests_root: Path | None = None,
module_root: Path | None = None,
) -> dict[Path, list[FunctionToOptimize]]:
"""Get all optimizable functions from files in the module root.

Args:
module_root_path: Root path to search for source files.
ignore_paths: List of paths to ignore.
language: Optional specific language to filter for. If None, includes all supported languages.
tests_root: Test root path for prefiltering files before reading (avoids unnecessary I/O).
module_root: Module root path for prefiltering files before reading.

Returns:
Dictionary mapping file paths to lists of FunctionToOptimize.

"""
functions: dict[Path, list[FunctionToOptimize]] = {}
for file_path in get_files_for_language(module_root_path, ignore_paths, language):
if tests_root is not None and module_root is not None:
if not filter_files_optimized(file_path, tests_root, ignore_paths, module_root):
continue
functions.update(find_all_functions_in_file(file_path).items())
# Randomize the order of the files to optimize to avoid optimizing the same file in the same order every time.
# Helpful if an optimize-all run is stuck and we restart it.
Expand All @@ -457,7 +517,7 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt
lang_support = get_language_support(file_path)
require_return = lang_support.language != Language.JAVA
criteria = FunctionFilterCriteria(require_return=require_return)
source = file_path.read_text(encoding="utf-8")
source = read_file_cached(file_path)
return {file_path: lang_support.discover_functions(source, file_path, criteria)}
except Exception as e:
logger.debug(f"Failed to discover functions in {file_path}: {e}")
Expand All @@ -474,21 +534,20 @@ def get_all_replay_test_functions(
trace_file_path: Path | None = None
for replay_test_file in replay_test:
try:
with replay_test_file.open("r", encoding="utf8") as f:
tree = ast.parse(f.read())
for node in ast.walk(tree):
if isinstance(node, ast.Assign):
for target in node.targets:
if (
isinstance(target, ast.Name)
and target.id == "trace_file_path"
and isinstance(node.value, ast.Constant)
and isinstance(node.value.value, str)
):
trace_file_path = Path(node.value.value)
break
if trace_file_path:
tree = parse_ast_cached(replay_test_file)
for node in ast.walk(tree):
if isinstance(node, ast.Assign):
for target in node.targets:
if (
isinstance(target, ast.Name)
and target.id == "trace_file_path"
and isinstance(node.value, ast.Constant)
and isinstance(node.value.value, str)
):
trace_file_path = Path(node.value.value)
break
if trace_file_path:
break
if trace_file_path:
break
except Exception as e:
Expand Down Expand Up @@ -602,7 +661,7 @@ def _get_java_replay_test_functions(
from codeflash.languages.registry import get_language_support

lang_support = get_language_support(source_file)
source_code = source_file.read_text(encoding="utf-8")
source_code = read_file_cached(source_file)
all_functions = lang_support.discover_functions(source_code, source_file)

for func in all_functions:
Expand Down Expand Up @@ -730,11 +789,10 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
def inspect_top_level_functions_or_methods(
file_name: Path, function_or_method_name: str, class_name: str | None = None, line_no: int | None = None
) -> FunctionProperties | None:
with file_name.open(encoding="utf8") as file:
try:
ast_module = ast.parse(file.read())
except Exception:
return None
try:
ast_module = parse_ast_cached(file_name)
except Exception:
return None
visitor = TopLevelFunctionOrMethodVisitor(
file_name=file_name, function_or_method_name=function_or_method_name, class_name=class_name, line_no=line_no
)
Expand Down
Loading
Loading