Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions codeflash/code_utils/line_profile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import re
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING, Union
Expand All @@ -15,6 +16,28 @@
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext

# Regex pattern to detect JIT compilation decorators from numba, torch, tensorflow, and jax
JIT_DECORATOR_PATTERN = re.compile(
r"@(?:"
# numba decorators
r"(?:numba\.)?(?:jit|njit|vectorize|guvectorize|stencil|cfunc|generated_jit)"
r"|numba\.cuda\.jit"
r"|cuda\.jit"
# torch decorators
r"|torch\.compile"
r"|torch\.jit\.(?:script|trace)"
# tensorflow decorators
r"|(?:tf|tensorflow)\.function"
# jax decorators
r"|jax\.jit"
r")"
)


def contains_jit_decorator(code: str) -> bool:
"""Check if the code contains JIT compilation decorators from numba, torch, tensorflow, or jax."""
return bool(JIT_DECORATOR_PATTERN.search(code))


class LineProfilerDecoratorAdder(cst.CSTTransformer):
"""Transformer that adds a decorator to a function with a specific qualified name."""
Expand Down
19 changes: 18 additions & 1 deletion codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports
from codeflash.code_utils.git_utils import git_root_dir
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
from codeflash.code_utils.line_profile_utils import add_decorator_imports
from codeflash.code_utils.line_profile_utils import add_decorator_imports, contains_jit_decorator
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.context import code_context_extractor
Expand Down Expand Up @@ -2412,6 +2412,23 @@ def get_test_env(
def line_profiler_step(
self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], candidate_index: int
) -> dict:
# Check if candidate code contains JIT decorators - line profiler doesn't work with JIT compiled code
candidate_fto_code = Path(self.function_to_optimize.file_path).read_text("utf-8")
if contains_jit_decorator(candidate_fto_code):
logger.info(
f"Skipping line profiler for {self.function_to_optimize.function_name} - code contains JIT decorator"
)
return {"timings": {}, "unit": 0, "str_out": ""}

# Check helper code for JIT decorators
for module_abspath in original_helper_code:
candidate_helper_code = Path(module_abspath).read_text("utf-8")
if contains_jit_decorator(candidate_helper_code):
logger.info(
f"Skipping line profiler for {self.function_to_optimize.function_name} - helper code contains JIT decorator"
)
return {"timings": {}, "unit": 0, "str_out": ""}

try:
console.rule()

Expand Down
Loading