34 changes: 16 additions & 18 deletions codeflash/languages/python/parse_xml.py
@@ -9,7 +9,7 @@

import os
import re
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from junitparser.xunit2 import JUnitXml

@@ -48,7 +48,7 @@
)


def _parse_func(file_path: Path):
def _parse_func(file_path: Path) -> Any:
from lxml.etree import XMLParser, parse

xml_parser = XMLParser(huge_tree=True)
@@ -59,13 +59,22 @@ def parse_python_test_xml(
test_xml_file_path: Path,
test_files: TestFiles,
test_config: TestConfig,
run_result: subprocess.CompletedProcess | None = None,
run_result: subprocess.CompletedProcess[str] | None = None,
) -> TestResults:
from codeflash.verification.parse_test_output import resolve_test_file_from_class_path

test_results = TestResults()
if not test_xml_file_path.exists():
logger.warning(f"No test results for {test_xml_file_path} found.")
if run_result is not None and run_result.returncode != 0:
stderr_snippet = (run_result.stderr or "")[:500]
stdout_snippet = (run_result.stdout or "")[:500]
logger.warning(
f"No test results for {test_xml_file_path} found. "
f"Subprocess exited with code {run_result.returncode}.\n"
f"stdout: {stdout_snippet}\nstderr: {stderr_snippet}"
)
else:
logger.warning(f"No test results for {test_xml_file_path} found.")
console.rule()
return test_results
try:
@@ -87,12 +96,7 @@
):
logger.info("Test failed to load, skipping it.")
if run_result is not None:
if isinstance(run_result.stdout, str) and isinstance(run_result.stderr, str):
logger.info(f"Test log - STDOUT : {run_result.stdout} \n STDERR : {run_result.stderr}")
else:
logger.info(
f"Test log - STDOUT : {run_result.stdout.decode()} \n STDERR : {run_result.stderr.decode()}"
)
logger.info(f"Test log - STDOUT : {run_result.stdout} \n STDERR : {run_result.stderr}")
return test_results

test_class_path = testcase.classname
@@ -159,7 +163,7 @@ def parse_python_test_xml(
sys_stdout = testcase.system_out or ""

begin_matches = list(matches_re_start.finditer(sys_stdout))
end_matches: dict[tuple, re.Match] = {}
end_matches: dict[tuple[str, ...], re.Match[str]] = {}
for match in matches_re_end.finditer(sys_stdout):
groups = match.groups()
if len(groups[5].split(":")) > 1:
@@ -234,11 +238,5 @@ def parse_python_test_xml(
f"Tests '{[test_file.original_file_path for test_file in test_files.test_files]}' failed to run, skipping"
)
if run_result is not None:
stdout, stderr = "", ""
try:
stdout = run_result.stdout.decode()
stderr = run_result.stderr.decode()
except AttributeError:
stdout = run_result.stderr
logger.debug(f"Test log - STDOUT : {stdout} \n STDERR : {stderr}")
logger.debug(f"Test log - STDOUT : {run_result.stdout} \n STDERR : {run_result.stderr}")
return test_results
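The annotation change to `subprocess.CompletedProcess[str]` implies the test command is run in text mode, so `stdout`/`stderr` arrive as already-decoded strings and the old `.decode()` fallbacks become dead code. A minimal sketch of that assumption (the pytest command and XML path here are illustrative, not taken from this repo):

```python
import subprocess

# text=True makes subprocess decode output, so the result is CompletedProcess[str]
# and stdout/stderr can be logged directly without isinstance/.decode() branches.
run_result = subprocess.run(
    ["pytest", "--junitxml=results.xml"],
    capture_output=True,
    text=True,
    check=False,
)

if run_result.returncode != 0:
    # Mirrors the new warning path above: truncate output so logs stay readable.
    print(f"Subprocess exited with code {run_result.returncode}")
    print(f"stdout: {(run_result.stdout or '')[:500]}")
    print(f"stderr: {(run_result.stderr or '')[:500]}")
```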
50 changes: 28 additions & 22 deletions codeflash/models/models.py
@@ -9,15 +9,15 @@
from functools import lru_cache
from pathlib import Path
from re import Pattern
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, cast
from typing import TYPE_CHECKING, Any, NamedTuple, Optional

from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError, model_validator
from pydantic.dataclasses import dataclass

from codeflash.models.test_type import TestType

if TYPE_CHECKING:
from collections.abc import Iterator
from collections.abc import Generator

import libcst as cst
from rich.tree import Tree
@@ -298,11 +298,13 @@ def flat(self) -> str:

"""
if self._cache.get("flat") is not None:
return self._cache["flat"]
self._cache["flat"] = "\n".join(
result: str = self._cache["flat"]
return result
flat: str = "\n".join(
get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings
)
return self._cache["flat"]
self._cache["flat"] = flat
return flat

@property
def markdown(self) -> str:
@@ -332,7 +334,8 @@ def file_to_path(self) -> dict[str, str]:

"""
try:
return self._cache["file_to_path"]
cached: dict[str, str] = self._cache["file_to_path"]
return cached
except KeyError:
mapping = {str(code_string.file_path): code_string.code for code_string in self.code_strings}
self._cache["file_to_path"] = mapping
@@ -426,13 +429,16 @@ class TestFile(BaseModel):

class TestFiles(BaseModel):
test_files: list[TestFile]
_seen_paths: set[Path] = PrivateAttr(default_factory=set)

def model_post_init(self, __context: Any, /) -> None:
self._seen_paths = {tf.instrumented_behavior_file_path for tf in self.test_files}

def add(self, test_file: TestFile) -> None:
if test_file not in self.test_files:
key = test_file.instrumented_behavior_file_path
if key not in self._seen_paths:
self._seen_paths.add(key)
self.test_files.append(test_file)
else:
msg = "Test file already exists in the list"
raise ValueError(msg)

def get_by_original_file_path(self, file_path: Path) -> TestFile | None:
normalized = self._normalize_path_for_comparison(file_path)
@@ -494,8 +500,8 @@ def _normalize_path_for_comparison(path: Path) -> str:
# Only lowercase on Windows where filesystem is case-insensitive
return resolved.lower() if sys.platform == "win32" else resolved

def __iter__(self) -> Iterator[TestFile]:
return iter(self.test_files)
def __iter__(self) -> Generator[Any, None, None]: # noqa: PYI058
yield from self.test_files

def __len__(self) -> int:
return len(self.test_files)
@@ -514,9 +520,9 @@ class CandidateEvaluationContext:
optimized_runtimes: dict[str, float | None] = Field(default_factory=dict)
is_correct: dict[str, bool] = Field(default_factory=dict)
optimized_line_profiler_results: dict[str, str] = Field(default_factory=dict)
ast_code_to_id: dict = Field(default_factory=dict)
ast_code_to_id: dict[str, Any] = Field(default_factory=dict)
optimizations_post: dict[str, str] = Field(default_factory=dict)
valid_optimizations: list = Field(default_factory=list)
valid_optimizations: list[Any] = Field(default_factory=list)

def record_failed_candidate(self, optimization_id: str) -> None:
"""Record results for a failed candidate."""
@@ -543,7 +549,7 @@ def handle_duplicate_candidate(
# Copy results from the previous evaluation (use .get() in case past_opt_id was registered
# but never benchmarked due to an unhandled exception in process_single_candidate)
self.speedup_ratios[candidate.optimization_id] = self.speedup_ratios.get(past_opt_id)
self.is_correct[candidate.optimization_id] = self.is_correct.get(past_opt_id)
self.is_correct[candidate.optimization_id] = self.is_correct.get(past_opt_id, False)
self.optimized_runtimes[candidate.optimization_id] = self.optimized_runtimes.get(past_opt_id)

# Line profiler results only available for successful runs
@@ -631,7 +637,7 @@ class OriginalCodeBaseline(BaseModel):
behavior_test_results: TestResults
benchmarking_test_results: TestResults
replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None
line_profile_results: dict
line_profile_results: dict[str, Any]
runtime: int
coverage_results: Optional[CoverageData]
async_throughput: Optional[int] = None
@@ -793,7 +799,7 @@ def get_src_code(self, test_path: Path) -> Optional[str]:
f"// Testing function: {self.function_getting_tested}"
)

if self.test_class_name:
if self.test_class_name and self.test_function_name:
for stmt in module_node.body:
if isinstance(stmt, cst.ClassDef) and stmt.name.value == self.test_class_name:
func_node = self.find_func_in_class(stmt, self.test_function_name)
@@ -884,7 +890,7 @@ def group_by_benchmarks(
"""Group TestResults by benchmark for calculating improvements for each benchmark."""
from codeflash.code_utils.code_utils import module_name_from_file_path

test_results_by_benchmark = defaultdict(TestResults)
test_results_by_benchmark: defaultdict[BenchmarkKey, TestResults] = defaultdict(TestResults)
benchmark_module_path = {}
for benchmark_key in benchmark_keys:
benchmark_module_path[benchmark_key] = module_name_from_file_path(
@@ -1015,7 +1021,7 @@ def effective_loop_count(self) -> int:
return max(loop_indices) if loop_indices else 0

def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Path]:
map_gen_test_file_to_no_of_tests = Counter()
map_gen_test_file_to_no_of_tests: Counter[Path] = Counter()
for gen_test_result in self.test_results:
if (
gen_test_result.test_type == TestType.GENERATED_REGRESSION
@@ -1024,8 +1030,8 @@ def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Path]:
map_gen_test_file_to_no_of_tests[gen_test_result.file_name] += 1
return map_gen_test_file_to_no_of_tests

def __iter__(self) -> Iterator[FunctionTestInvocation]:
return iter(self.test_results)
def __iter__(self) -> Generator[Any, None, None]: # noqa: PYI058
yield from self.test_results

def __len__(self) -> int:
return len(self.test_results)
@@ -1051,7 +1057,7 @@ def __eq__(self, other: object) -> bool:
if len(self) != len(other):
return False
original_recursion_limit = sys.getrecursionlimit()
cast("TestResults", other)
assert isinstance(other, TestResults)
for test_result in self:
other_test_result = other.get_by_unique_invocation_loop_id(test_result.unique_invocation_loop_id)
if other_test_result is None:
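`TestFiles.add` above replaces the `test_file not in self.test_files` scan (which raised `ValueError` on duplicates) with an O(1) lookup in a private set keyed on `instrumented_behavior_file_path`, and duplicates are now skipped silently. A standalone sketch of the same pattern, using a hypothetical `PathRegistry` model in place of `TestFiles`:

```python
from pathlib import Path
from typing import Any

from pydantic import BaseModel, PrivateAttr


class PathRegistry(BaseModel):
    paths: list[Path]
    _seen: set[Path] = PrivateAttr(default_factory=set)

    def model_post_init(self, __context: Any, /) -> None:
        # Seed the dedup set from whatever the model was constructed with.
        self._seen = set(self.paths)

    def add(self, path: Path) -> None:
        # First write wins: repeated paths are skipped instead of raising.
        if path not in self._seen:
            self._seen.add(path)
            self.paths.append(path)


registry = PathRegistry(paths=[])
registry.add(Path("/tmp/a.py"))
registry.add(Path("/tmp/a.py"))  # no-op, the list still holds one entry
assert len(registry.paths) == 1
```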
44 changes: 44 additions & 0 deletions tests/test_test_files_add.py
@@ -0,0 +1,44 @@
from pathlib import Path

from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType


class TestTestFilesAdd:
def test_add_unique_test_file(self) -> None:
tf = TestFiles(test_files=[])
test_file = TestFile(
instrumented_behavior_file_path=Path("/tmp/test_behavior.py"),
benchmarking_file_path=Path("/tmp/test_perf.py"),
test_type=TestType.GENERATED_REGRESSION,
)
tf.add(test_file)
assert len(tf.test_files) == 1
assert tf.test_files[0] is test_file

def test_add_duplicate_is_noop(self) -> None:
tf = TestFiles(test_files=[])
test_file = TestFile(
instrumented_behavior_file_path=Path("/tmp/test_behavior.py"),
benchmarking_file_path=Path("/tmp/test_perf.py"),
test_type=TestType.GENERATED_REGRESSION,
)
tf.add(test_file)
tf.add(test_file) # silent skip — first write wins
assert len(tf.test_files) == 1

def test_add_many_files_performance(self) -> None:
tf = TestFiles(test_files=[])
for i in range(100):
test_file = TestFile(
instrumented_behavior_file_path=Path(f"/tmp/test_behavior_{i}.py"),
benchmarking_file_path=Path(f"/tmp/test_perf_{i}.py"),
test_type=TestType.GENERATED_REGRESSION,
)
tf.add(test_file)

assert len(tf.test_files) == 100
assert len(tf._seen_paths) == 100
# Verify all paths are unique in the set
expected_paths = {Path(f"/tmp/test_behavior_{i}.py") for i in range(100)}
assert tf._seen_paths == expected_paths