Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions codeflash/languages/python/context/code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,18 @@
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
from codeflash.languages.python.context.unused_definition_remover import (
collect_top_level_defs_with_dependencies,
collect_top_level_defs_with_usages,
get_section_names,
is_assignment_used,
mark_defs_for_functions,
recurse_sections,
remove_unused_definitions_by_function_names,
)
from codeflash.languages.python.static_analysis.code_extractor import (
add_needed_imports_from_module,
find_preexisting_objects,
gather_source_imports,
)
from codeflash.models.models import (
CodeContextType,
Expand Down Expand Up @@ -64,6 +67,7 @@ class FileContextCache:
helper_functions: list[FunctionSource]
file_path: Path
relative_path: Path
gathered_imports: Any = None


@dataclass
Expand Down Expand Up @@ -265,15 +269,23 @@ def extract_all_contexts_from_files(
except ValueError:
relative_path = file_path

# Compute defs once for fto_names and reuse across remove + prune
fto_defs = collect_top_level_defs_with_usages(original_module, fto_names)
# Collect definitions + dependencies once (expensive CST traversal), reuse for both mark passes
base_defs = collect_top_level_defs_with_dependencies(original_module)
fto_defs = mark_defs_for_functions(base_defs, fto_names)
# Clean by fto_names only (for RW)
rw_cleaned = remove_unused_definitions_by_function_names(original_module, fto_names, defs_with_usages=fto_defs)
# Clean by all names (for RO/HASH/TESTGEN) — reuse rw_cleaned if no extra HoH names
# Clean by all names (for RO/HASH/TESTGEN) — reuse base_defs to avoid re-traversal
all_names = fto_names | hoh_names
all_cleaned = (
remove_unused_definitions_by_function_names(original_module, all_names) if hoh_names else rw_cleaned
)
if hoh_names:
all_defs = mark_defs_for_functions(base_defs, all_names)
all_cleaned = remove_unused_definitions_by_function_names(
original_module, all_names, defs_with_usages=all_defs
)
else:
all_cleaned = rw_cleaned

# Pre-compute source imports once for this file (avoids 3x CST traversal of original_module)
src_gathered = gather_source_imports(original_module, file_path, project_root_path)

# READ_WRITABLE
try:
Expand All @@ -293,6 +305,7 @@ def extract_all_contexts_from_files(
dst_path=file_path,
project_root=project_root_path,
helper_functions=rw_helper_functions,
gathered_imports=src_gathered,
)
rw.code_strings.append(CodeString(code=rw_code, file_path=relative_path))
except ValueError as e:
Expand All @@ -311,6 +324,7 @@ def extract_all_contexts_from_files(
dst_path=file_path,
project_root=project_root_path,
helper_functions=all_helper_functions,
gathered_imports=src_gathered,
)
ro.code_strings.append(CodeString(code=ro_code, file_path=relative_path))
except ValueError as e:
Expand Down Expand Up @@ -340,6 +354,7 @@ def extract_all_contexts_from_files(
dst_path=file_path,
project_root=project_root_path,
helper_functions=all_helper_functions,
gathered_imports=src_gathered,
)
testgen.code_strings.append(CodeString(code=testgen_code, file_path=relative_path))
except ValueError as e:
Expand All @@ -354,6 +369,7 @@ def extract_all_contexts_from_files(
helper_functions=all_helper_functions,
file_path=file_path,
relative_path=relative_path,
gathered_imports=src_gathered,
)
)

Expand Down Expand Up @@ -381,6 +397,9 @@ def extract_all_contexts_from_files(

cleaned = remove_unused_definitions_by_function_names(original_module, hoh_names)

# Pre-compute source imports once for this file
src_gathered = gather_source_imports(original_module, file_path, project_root_path)

# READ_ONLY
try:
ro_pruned = parse_code_and_prune_cst(
Expand All @@ -394,6 +413,7 @@ def extract_all_contexts_from_files(
dst_path=file_path,
project_root=project_root_path,
helper_functions=helper_functions,
gathered_imports=src_gathered,
)
ro.code_strings.append(CodeString(code=ro_code, file_path=relative_path))
except ValueError as e:
Expand Down Expand Up @@ -423,6 +443,7 @@ def extract_all_contexts_from_files(
dst_path=file_path,
project_root=project_root_path,
helper_functions=helper_functions,
gathered_imports=src_gathered,
)
testgen.code_strings.append(CodeString(code=testgen_code, file_path=relative_path))
except ValueError as e:
Expand All @@ -437,6 +458,7 @@ def extract_all_contexts_from_files(
helper_functions=helper_functions,
file_path=file_path,
relative_path=relative_path,
gathered_imports=src_gathered,
)
)

Expand Down Expand Up @@ -473,6 +495,7 @@ def re_extract_from_cache(
dst_path=file_cache.file_path,
project_root=project_root_path,
helper_functions=file_cache.helper_functions,
gathered_imports=file_cache.gathered_imports,
)
result.code_strings.append(CodeString(code=code, file_path=file_cache.relative_path))
return result
Expand Down
38 changes: 28 additions & 10 deletions codeflash/languages/python/context/unused_definition_remover.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,23 +408,41 @@ def remove_unused_definitions_recursively(
)


def collect_top_level_defs_with_usages(
code: Union[str, cst.Module], qualified_function_names: set[str]
) -> dict[str, UsageInfo]:
"""Collect all top level definitions (classes, variables or functions) and their usages."""
def collect_top_level_defs_with_dependencies(code: Union[str, cst.Module]) -> dict[str, UsageInfo]:
    """Collect all top-level definitions plus their inter-definition dependency edges.

    Performs the expensive part once: a full CST parse (when given source text)
    followed by a metadata-backed dependency traversal. The returned dict has
    dependencies populated but carries no usage marks, so it can be handed to
    mark_defs_for_functions repeatedly without re-walking the tree.
    """
    # Accept either raw source or an already-parsed module.
    if isinstance(code, cst.Module):
        module = code
    else:
        module = cst.parse_module(code)

    # Every top-level class, function, or variable definition in the module.
    defs = collect_top_level_definitions(module)

    # One metadata-resolved pass records which definitions reference which others.
    collector = DependencyCollector(defs)
    cst.MetadataWrapper(module).visit(collector)
    return defs

# Mark definitions used by specified functions, and their dependencies recursively
usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names)

def mark_defs_for_functions(
    base_defs: dict[str, UsageInfo], qualified_function_names: set[str]
) -> dict[str, UsageInfo]:
    """Create a copy of definitions with usage marks set for the given function names.

    This is cheap (dict copy + graph walk) and can be called multiple times with
    different function name sets on the same base_defs without re-traversing the CST.

    Fix: the block as written contained a stray ``return definitions`` before
    ``return marked`` — ``definitions`` is undefined in this scope (it would raise
    NameError) and it made the intended ``return marked`` unreachable. The copy
    ``marked`` is the correct return value.
    """
    # Fresh UsageInfo objects so marking never mutates the shared base_defs.
    marked = {k: UsageInfo(name=v.name, dependencies=v.dependencies) for k, v in base_defs.items()}
    usage_marker = QualifiedFunctionUsageMarker(marked, qualified_function_names)
    usage_marker.mark_used_definitions()
    return marked


def collect_top_level_defs_with_usages(
    code: Union[str, cst.Module], qualified_function_names: set[str]
) -> dict[str, UsageInfo]:
    """Collect all top level definitions (classes, variables or functions) and their usages."""
    # Thin composition: one expensive dependency collection, one cheap marking pass.
    return mark_defs_for_functions(
        collect_top_level_defs_with_dependencies(code), qualified_function_names
    )


def remove_unused_definitions_by_function_names(
Expand Down
86 changes: 53 additions & 33 deletions codeflash/languages/python/static_analysis/code_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
from codeflash.models.models import FunctionSource


_SENTINEL = object()


class GlobalFunctionCollector(cst.CSTVisitor):
"""Collects all module-level function definitions (not inside classes or other functions)."""

Expand Down Expand Up @@ -540,41 +543,22 @@ def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
return set()


def add_needed_imports_from_module(
src_module_code: str | cst.Module,
dst_module_code: str | cst.Module,
src_path: Path,
dst_path: Path,
project_root: Path,
helper_functions: list[FunctionSource] | None = None,
helper_functions_fqn: set[str] | None = None,
) -> str:
"""Add all needed and used source module code imports to the destination module code, and return it."""
if not helper_functions_fqn:
helper_functions_fqn = {f.fully_qualified_name for f in (helper_functions or [])}

# Cache the fallback early to avoid repeated isinstance checks
if isinstance(dst_module_code, str):
dst_code_fallback = dst_module_code
else:
# Keep Module-input fallback formatting aligned with transformed_module.code.lstrip("\n").
dst_code_fallback = dst_module_code.code.lstrip("\n")
def gather_source_imports(
src_module_code: str | cst.Module, src_path: Path, project_root: Path
) -> GatherImportsVisitor | None:
"""Pre-process source module to gather its imports. Returns None if no imports found.

This is the expensive part of add_needed_imports_from_module (CST traversal of src).
When adding imports from the same source to multiple destinations, call this once
and pass the result to add_needed_imports_from_module via gathered_imports.
"""
src_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, src_path)
dst_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, dst_path)

dst_context: CodemodContext = CodemodContext(
filename=src_path.name,
full_module_name=dst_module_and_package.name,
full_package_name=dst_module_and_package.package,
)
try:
if isinstance(src_module_code, cst.Module):
src_module = src_module_code.visit(FutureAliasedImportTransformer())
else:
src_module = cst.parse_module(src_module_code).visit(FutureAliasedImportTransformer())

# Early exit: check if source has any imports at module level
has_module_level_imports = any(
isinstance(s, (cst.Import, cst.ImportFrom))
for stmt in src_module.body
Expand All @@ -583,7 +567,7 @@ def add_needed_imports_from_module(
)

if not has_module_level_imports:
return dst_code_fallback
return None

gatherer: GatherImportsVisitor = GatherImportsVisitor(
CodemodContext(
Expand All @@ -593,25 +577,61 @@ def add_needed_imports_from_module(
)
)

# Exclude function/class bodies so GatherImportsVisitor only sees module-level imports.
# Nested imports (inside functions) are part of function logic and must not be
# scheduled for add/remove — RemoveImportsVisitor would strip them as "unused".
module_level_only = src_module.with_changes(
body=[stmt for stmt in src_module.body if not isinstance(stmt, (cst.FunctionDef, cst.ClassDef))]
)
module_level_only.visit(gatherer)

# Early exit: if no imports were gathered, return destination as-is
if (
not gatherer.module_imports
and not gatherer.object_mapping
and not gatherer.module_aliases
and not gatherer.alias_mapping
):
return dst_code_fallback
return None

return gatherer
except Exception as e:
logger.error(f"Error parsing source module code: {e}")
return None


def add_needed_imports_from_module(
src_module_code: str | cst.Module,
dst_module_code: str | cst.Module,
src_path: Path,
dst_path: Path,
project_root: Path,
helper_functions: list[FunctionSource] | None = None,
helper_functions_fqn: set[str] | None = None,
gathered_imports: GatherImportsVisitor | None | object = _SENTINEL,
) -> str:
"""Add all needed and used source module code imports to the destination module code, and return it."""
if not helper_functions_fqn:
helper_functions_fqn = {f.fully_qualified_name for f in (helper_functions or [])}

# Cache the fallback early to avoid repeated isinstance checks
if isinstance(dst_module_code, str):
dst_code_fallback = dst_module_code
else:
# Keep Module-input fallback formatting aligned with transformed_module.code.lstrip("\n").
dst_code_fallback = dst_module_code.code.lstrip("\n")

dst_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, dst_path)

dst_context: CodemodContext = CodemodContext(
filename=src_path.name,
full_module_name=dst_module_and_package.name,
full_package_name=dst_module_and_package.package,
)

# Use pre-computed gatherer if provided, otherwise compute on the fly
if gathered_imports is _SENTINEL:
gatherer = gather_source_imports(src_module_code, src_path, project_root)
else:
gatherer = gathered_imports

if gatherer is None:
return dst_code_fallback

dotted_import_collector = DottedImportCollector()
Expand Down
4 changes: 2 additions & 2 deletions codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ class VerificationType(str, Enum):
INIT_STATE_HELPER = "init_state_helper" # Correctness verification for helper class instance attributes after init


@dataclass(frozen=True)
@dataclass(frozen=True, slots=True)
class InvocationId:
test_module_path: str # The fully qualified name of the test module
test_class_name: Optional[str] # The name of the class where the test is defined
Expand Down Expand Up @@ -821,7 +821,7 @@ def from_str_id(string_id: str, iteration_id: str | None = None) -> InvocationId
)


@dataclass(frozen=True)
@dataclass(frozen=True, slots=True)
class FunctionTestInvocation:
loop_index: int # The loop index of the function invocation, starts at 1
id: InvocationId # The fully qualified name of the function invocation (id)
Expand Down
30 changes: 14 additions & 16 deletions codeflash/verification/parse_test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,8 +586,8 @@ def merge_test_results(
) -> TestResults:
merged_test_results = TestResults()

grouped_xml_results: defaultdict[str, TestResults] = defaultdict(TestResults)
grouped_bin_results: defaultdict[str, TestResults] = defaultdict(TestResults)
grouped_xml_results: defaultdict[tuple[str, str, str, int], TestResults] = defaultdict(TestResults)
grouped_bin_results: defaultdict[tuple[str, str, str, int], TestResults] = defaultdict(TestResults)

# This is done to match the right iteration_id which might not be available in the xml
for result in xml_test_results:
Expand All @@ -606,24 +606,22 @@ def merge_test_results(
test_function_name = result.id.test_function_name

grouped_xml_results[
(result.id.test_module_path or "")
+ ":"
+ (result.id.test_class_name or "")
+ ":"
+ (test_function_name or "")
+ ":"
+ str(result.loop_index)
(
result.id.test_module_path or "",
result.id.test_class_name or "",
test_function_name or "",
result.loop_index,
)
].add(result)

for result in bin_test_results:
grouped_bin_results[
(result.id.test_module_path or "")
+ ":"
+ (result.id.test_class_name or "")
+ ":"
+ (result.id.test_function_name or "")
+ ":"
+ str(result.loop_index)
(
result.id.test_module_path or "",
result.id.test_class_name or "",
result.id.test_function_name or "",
result.loop_index,
)
].add(result)

for result_id in grouped_xml_results:
Expand Down
Loading