Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 45 additions & 51 deletions codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
from functools import lru_cache
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Generator

import tomlkit

Expand Down Expand Up @@ -112,7 +116,7 @@ def normalize_by_max(values: list[float]) -> list[float]:
return [v / mx for v in values]


def create_score_dictionary_from_metrics(weights: list[float], *metrics: list[float]) -> dict[int, int]:
def create_score_dictionary_from_metrics(weights: list[float], *metrics: list[float]) -> dict[int, float]:
"""Combine multiple metrics into a single weighted score dictionary.

Each metric is a list of values (smaller = better).
Expand Down Expand Up @@ -208,67 +212,53 @@ def filter_args(addopts_args: list[str]) -> list[str]:
def modify_addopts(config_file: Path) -> tuple[str, bool]:
file_type = config_file.suffix.lower()
filename = config_file.name
config = None
if file_type not in {".toml", ".ini", ".cfg"} or not config_file.exists():
return "", False
# Read original file
with Path.open(config_file, encoding="utf-8") as f:
content = f.read()
try:
if filename == "pyproject.toml":
# use tomlkit
data = tomlkit.parse(content)
original_addopts = data.get("tool", {}).get("pytest", {}).get("ini_options", {}).get("addopts", "")
# nothing to do if no addopts present
if original_addopts == "":
return content, False
if isinstance(original_addopts, list):
original_addopts = " ".join(original_addopts)
original_addopts = original_addopts.replace("=", " ")
addopts_args = (
original_addopts.split()
) # any number of space characters as delimiter, doesn't look at = which is fine
else:
# use configparser
config = configparser.ConfigParser()
config.read_string(content)
data = {section: dict(config[section]) for section in config.sections()}
if config_file.name in {"pytest.ini", ".pytest.ini", "tox.ini"}:
original_addopts = data.get("pytest", {}).get("addopts", "") # should only be a string
else:
original_addopts = data.get("tool:pytest", {}).get("addopts", "") # should only be a string
original_addopts = original_addopts.replace("=", " ")
addopts_args = original_addopts.split()
new_addopts_args = filter_args(addopts_args)
if new_addopts_args == addopts_args:
return content, False
# change addopts now
if file_type == ".toml":
data["tool"]["pytest"]["ini_options"]["addopts"] = " ".join(new_addopts_args)
# Write modified file
new_addopts_args = filter_args(addopts_args)
if new_addopts_args == addopts_args:
return content, False
data["tool"]["pytest"]["ini_options"]["addopts"] = " ".join(new_addopts_args) # type: ignore[index]
with Path.open(config_file, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(data))
return content, True
elif config_file.name in {"pytest.ini", ".pytest.ini", "tox.ini"}:
config.set("pytest", "addopts", " ".join(new_addopts_args))
# Write modified file
with Path.open(config_file, "w", encoding="utf-8") as f:
config.write(f)
return content, True
return content, True
config = configparser.ConfigParser()
config.read_string(content)
ini_data = {section: dict(config[section]) for section in config.sections()}
if config_file.name in {"pytest.ini", ".pytest.ini", "tox.ini"}:
original_addopts = ini_data.get("pytest", {}).get("addopts", "")
else:
config.set("tool:pytest", "addopts", " ".join(new_addopts_args))
# Write modified file
with Path.open(config_file, "w", encoding="utf-8") as f:
config.write(f)
return content, True
original_addopts = ini_data.get("tool:pytest", {}).get("addopts", "")
original_addopts = original_addopts.replace("=", " ")
addopts_args = original_addopts.split()
new_addopts_args = filter_args(addopts_args)
if new_addopts_args == addopts_args:
return content, False
section = "pytest" if config_file.name in {"pytest.ini", ".pytest.ini", "tox.ini"} else "tool:pytest"
config.set(section, "addopts", " ".join(new_addopts_args))
with Path.open(config_file, "w", encoding="utf-8") as f:
config.write(f)
return content, True

except Exception:
logger.debug("Trouble parsing")
return content, False # not modified
return content, False


@contextmanager
def custom_addopts() -> None:
def custom_addopts() -> Generator[None, None, None]:
closest_config_files = get_all_closest_config_files()

original_content = {}
Expand All @@ -287,18 +277,17 @@ def custom_addopts() -> None:


@contextmanager
def add_addopts_to_pyproject() -> None:
def add_addopts_to_pyproject() -> Generator[None, None, None]:
pyproject_file = find_pyproject_toml()
original_content = None
original_content: str | None = None
try:
# Read original file
if pyproject_file.exists():
with Path.open(pyproject_file, encoding="utf-8") as f:
original_content = f.read()
data = tomlkit.parse(original_content)
data["tool"]["pytest"] = {}
data["tool"]["pytest"]["ini_options"] = {}
data["tool"]["pytest"]["ini_options"]["addopts"] = [
data["tool"]["pytest"] = {} # type: ignore[index]
data["tool"]["pytest"]["ini_options"] = {} # type: ignore[index]
data["tool"]["pytest"]["ini_options"]["addopts"] = [ # type: ignore[index]
"-n=auto",
"-n",
"1",
Expand All @@ -312,9 +301,9 @@ def add_addopts_to_pyproject() -> None:
yield

finally:
# Restore original file
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
f.write(original_content)
if original_content is not None:
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
f.write(original_content)


def encoded_tokens_len(s: str) -> int:
Expand Down Expand Up @@ -418,13 +407,18 @@ def get_all_function_names(code: str) -> tuple[bool, list[str]]:
return True, function_names


_run_tmpdir: TemporaryDirectory[str] | None = None
_run_tmpdir_path: Path | None = None


def get_run_tmp_file(file_path: Path | str) -> Path:
global _run_tmpdir, _run_tmpdir_path
if isinstance(file_path, str):
file_path = Path(file_path)
if not hasattr(get_run_tmp_file, "tmpdir_path"):
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
get_run_tmp_file.tmpdir_path = Path(get_run_tmp_file.tmpdir.name).resolve()
return get_run_tmp_file.tmpdir_path / file_path
if _run_tmpdir_path is None:
_run_tmpdir = TemporaryDirectory(prefix="codeflash_")
_run_tmpdir_path = Path(_run_tmpdir.name).resolve()
return _run_tmpdir_path / file_path


def path_belongs_to_site_packages(file_path: Path) -> bool:
Expand Down
39 changes: 21 additions & 18 deletions codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
from functools import lru_cache
from pathlib import Path
from re import Pattern
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, cast
from typing import TYPE_CHECKING, Any, NamedTuple, Optional

from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError, model_validator
from pydantic.dataclasses import dataclass

from codeflash.models.test_type import TestType

if TYPE_CHECKING:
from collections.abc import Iterator
from collections.abc import Generator

import libcst as cst
from rich.tree import Tree
Expand Down Expand Up @@ -298,11 +298,13 @@ def flat(self) -> str:

"""
if self._cache.get("flat") is not None:
return self._cache["flat"]
self._cache["flat"] = "\n".join(
result: str = self._cache["flat"]
return result
flat: str = "\n".join(
get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings
)
return self._cache["flat"]
self._cache["flat"] = flat
return flat

@property
def markdown(self) -> str:
Expand Down Expand Up @@ -332,7 +334,8 @@ def file_to_path(self) -> dict[str, str]:

"""
try:
return self._cache["file_to_path"]
cached: dict[str, str] = self._cache["file_to_path"]
return cached
except KeyError:
mapping = {str(code_string.file_path): code_string.code for code_string in self.code_strings}
self._cache["file_to_path"] = mapping
Expand Down Expand Up @@ -494,8 +497,8 @@ def _normalize_path_for_comparison(path: Path) -> str:
# Only lowercase on Windows where filesystem is case-insensitive
return resolved.lower() if sys.platform == "win32" else resolved

def __iter__(self) -> Iterator[TestFile]:
return iter(self.test_files)
def __iter__(self) -> Generator[Any, None, None]: # noqa: PYI058
yield from self.test_files

def __len__(self) -> int:
return len(self.test_files)
Expand All @@ -514,9 +517,9 @@ class CandidateEvaluationContext:
optimized_runtimes: dict[str, float | None] = Field(default_factory=dict)
is_correct: dict[str, bool] = Field(default_factory=dict)
optimized_line_profiler_results: dict[str, str] = Field(default_factory=dict)
ast_code_to_id: dict = Field(default_factory=dict)
ast_code_to_id: dict[str, Any] = Field(default_factory=dict)
optimizations_post: dict[str, str] = Field(default_factory=dict)
valid_optimizations: list = Field(default_factory=list)
valid_optimizations: list[Any] = Field(default_factory=list)

def record_failed_candidate(self, optimization_id: str) -> None:
"""Record results for a failed candidate."""
Expand All @@ -543,7 +546,7 @@ def handle_duplicate_candidate(
# Copy results from the previous evaluation (use .get() in case past_opt_id was registered
# but never benchmarked due to an unhandled exception in process_single_candidate)
self.speedup_ratios[candidate.optimization_id] = self.speedup_ratios.get(past_opt_id)
self.is_correct[candidate.optimization_id] = self.is_correct.get(past_opt_id)
self.is_correct[candidate.optimization_id] = self.is_correct.get(past_opt_id, False)
self.optimized_runtimes[candidate.optimization_id] = self.optimized_runtimes.get(past_opt_id)

# Line profiler results only available for successful runs
Expand Down Expand Up @@ -631,7 +634,7 @@ class OriginalCodeBaseline(BaseModel):
behavior_test_results: TestResults
benchmarking_test_results: TestResults
replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None
line_profile_results: dict
line_profile_results: dict[str, Any]
runtime: int
coverage_results: Optional[CoverageData]
async_throughput: Optional[int] = None
Expand Down Expand Up @@ -793,7 +796,7 @@ def get_src_code(self, test_path: Path) -> Optional[str]:
f"// Testing function: {self.function_getting_tested}"
)

if self.test_class_name:
if self.test_class_name and self.test_function_name:
for stmt in module_node.body:
if isinstance(stmt, cst.ClassDef) and stmt.name.value == self.test_class_name:
func_node = self.find_func_in_class(stmt, self.test_function_name)
Expand Down Expand Up @@ -884,7 +887,7 @@ def group_by_benchmarks(
"""Group TestResults by benchmark for calculating improvements for each benchmark."""
from codeflash.code_utils.code_utils import module_name_from_file_path

test_results_by_benchmark = defaultdict(TestResults)
test_results_by_benchmark: defaultdict[BenchmarkKey, TestResults] = defaultdict(TestResults)
benchmark_module_path = {}
for benchmark_key in benchmark_keys:
benchmark_module_path[benchmark_key] = module_name_from_file_path(
Expand Down Expand Up @@ -1015,7 +1018,7 @@ def effective_loop_count(self) -> int:
return max(loop_indices) if loop_indices else 0

def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Path]:
map_gen_test_file_to_no_of_tests = Counter()
map_gen_test_file_to_no_of_tests: Counter[Path] = Counter()
for gen_test_result in self.test_results:
if (
gen_test_result.test_type == TestType.GENERATED_REGRESSION
Expand All @@ -1024,8 +1027,8 @@ def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Pa
map_gen_test_file_to_no_of_tests[gen_test_result.file_name] += 1
return map_gen_test_file_to_no_of_tests

def __iter__(self) -> Iterator[FunctionTestInvocation]:
return iter(self.test_results)
def __iter__(self) -> Generator[Any, None, None]: # noqa: PYI058
yield from self.test_results

def __len__(self) -> int:
return len(self.test_results)
Expand All @@ -1051,7 +1054,7 @@ def __eq__(self, other: object) -> bool:
if len(self) != len(other):
return False
original_recursion_limit = sys.getrecursionlimit()
cast("TestResults", other)
assert isinstance(other, TestResults)
for test_result in self:
other_test_result = other.get_by_unique_invocation_loop_id(test_result.unique_invocation_loop_id)
if other_test_result is None:
Expand Down
Loading