@codeflash-ai codeflash-ai bot commented Jan 7, 2026

⚡️ This pull request contains optimizations for PR #1015

If you approve this dependent PR, these changes will be merged into the original PR branch gpu-sync-instrumentation.

This PR will be automatically closed if the original PR is merged.


📄 55% (0.55x) speedup for detect_frameworks_from_code in codeflash/code_utils/instrument_existing_tests.py

⏱️ Runtime: 16.8 milliseconds → 10.8 milliseconds (best of 5 runs)

📝 Explanation and details

The optimized code achieves a 54% speedup by replacing ast.walk() with a more efficient manual tree traversal using a deque. Here's why this is faster:

Key Optimizations

  1. Selective Node Traversal: Instead of visiting every node in the AST (which ast.walk() does), the optimized version only descends into node attributes that can contain statements (body, orelse, finalbody, handlers). This skips irrelevant subtrees like expression nodes, literals, and operators that can never contain import statements.

  2. Early Type Checking: By checking isinstance(node, ast.Import) and isinstance(node, ast.ImportFrom) upfront and handling only these node types, the code avoids the overhead of traversing and type-checking thousands of irrelevant nodes that ast.walk() would visit.

  3. Framework Key Lookup Optimization: Using if module_name in framework_keys (tuple membership test) instead of three separate string equality checks reduces comparison overhead, especially when most modules are not frameworks.
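The three optimizations above can be sketched in one self-contained function. This is a hypothetical re-implementation for illustration only; the actual `detect_frameworks_from_code` in `codeflash/code_utils/instrument_existing_tests.py` may differ in its return format and alias-precedence rules:

```python
import ast
from collections import deque

# Hypothetical framework list, mirroring the torch/tensorflow/jax checks
# described above. Tuple membership replaces three equality comparisons.
FRAMEWORK_KEYS = ("torch", "tensorflow", "jax")


def detect_frameworks(code: str) -> dict:
    """Sketch of the selective AST traversal: visit only statement-bearing
    attributes instead of every node (as ast.walk() would)."""
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return {}
    found: dict = {}
    queue = deque([tree])
    while queue:
        node = queue.popleft()
        if isinstance(node, ast.Import):
            for alias in node.names:
                top = alias.name.split(".")[0]  # top-level module only
                if top in FRAMEWORK_KEYS and top not in found:
                    found[top] = alias.asname or top
        elif isinstance(node, ast.ImportFrom):
            if node.module:
                top = node.module.split(".")[0]
                if top in FRAMEWORK_KEYS and top not in found:
                    found[top] = top
        # Descend only into attributes that can contain statements,
        # skipping expression subtrees that can never hold an import.
        for field in ("body", "orelse", "finalbody", "handlers"):
            queue.extend(getattr(node, field, []) or [])
    return found
```

For example, `detect_frameworks("def f():\n    import torch as th")` still finds the nested import because function bodies are statement-bearing, while string literals and comments are never visited.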

Performance Impact by Test Type

  • Basic imports (20-40μs): 33-46% faster - The selective traversal immediately finds import nodes at the module level
  • Nested imports (functions/classes): 35-76% faster - Dramatic gains because we skip traversing expression nodes inside function/class bodies
  • Large codebases (1-5ms): 48-76% faster - The benefits compound significantly with more code, as we avoid visiting thousands of irrelevant nodes

The line profiler confirms the win: ast.walk() consumed 68% of runtime (56ms) in the original, while the optimized manual traversal spends only 3.1% (0.9ms) on the queue loop itself.

Context from Function References

The function is called from inject_profiling_into_existing_test(), which processes test files during optimization workflows. Since this runs on every test file being analyzed, the 54% speedup directly reduces overall test instrumentation time, making the optimization valuable for the profiling pipeline.
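The kind of gap the line profiler reports can be reproduced with a small micro-benchmark comparing the two traversal strategies. This is an illustrative sketch, not the project's benchmark harness; absolute timings depend on the machine, only the relative gap matters:

```python
import ast
import timeit
from collections import deque

# 100 functions, each hiding an import in its body, similar to the
# "nested imports" test cases above.
src = "\n".join(
    f"def f{i}():\n    import torch as t{i}\n    return t{i}" for i in range(100)
)
tree = ast.parse(src)


def walk_count():
    # Baseline: ast.walk() visits every node, including expression subtrees.
    return sum(isinstance(n, (ast.Import, ast.ImportFrom)) for n in ast.walk(tree))


def selective_count():
    # Optimized strategy: descend only into statement-bearing attributes.
    count, queue = 0, deque([tree])
    while queue:
        node = queue.popleft()
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            count += 1
        for field in ("body", "orelse", "finalbody", "handlers"):
            queue.extend(getattr(node, field, []) or [])
    return count


# Both strategies must agree before timing them.
assert walk_count() == selective_count() == 100
print("ast.walk:  ", timeit.timeit(walk_count, number=200))
print("selective: ", timeit.timeit(selective_count, number=200))
```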

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 63 Passed
🌀 Generated Regression Tests 39 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
⚙️ Click to see Existing Unit Tests
| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
|---|---|---|---|
| test_instrument_tests.py::TestDetectFrameworksFromCode.test_alias_import_takes_precedence_over_from_import | 30.9μs | 23.6μs | 31.1% ✅ |
| test_instrument_tests.py::TestDetectFrameworksFromCode.test_detect_from_import_no_alias | 23.8μs | 17.9μs | 33.0% ✅ |
| test_instrument_tests.py::TestDetectFrameworksFromCode.test_detect_jax_standard_import | 20.3μs | 14.9μs | 36.2% ✅ |
| test_instrument_tests.py::TestDetectFrameworksFromCode.test_detect_jax_with_alias | 21.3μs | 15.5μs | 37.2% ✅ |
| test_instrument_tests.py::TestDetectFrameworksFromCode.test_detect_multiple_frameworks | 31.4μs | 23.3μs | 34.8% ✅ |
| test_instrument_tests.py::TestDetectFrameworksFromCode.test_detect_no_frameworks | 27.7μs | 20.0μs | 38.3% ✅ |
| test_instrument_tests.py::TestDetectFrameworksFromCode.test_detect_syntax_error_returns_empty | 27.4μs | 26.4μs | 3.75% ✅ |
| test_instrument_tests.py::TestDetectFrameworksFromCode.test_detect_tensorflow_standard_import | 20.3μs | 15.3μs | 32.1% ✅ |
| test_instrument_tests.py::TestDetectFrameworksFromCode.test_detect_tensorflow_with_alias | 21.5μs | 16.2μs | 32.6% ✅ |
| test_instrument_tests.py::TestDetectFrameworksFromCode.test_detect_torch_standard_import | 23.8μs | 18.0μs | 31.7% ✅ |
| test_instrument_tests.py::TestDetectFrameworksFromCode.test_detect_torch_with_alias | 22.5μs | 17.0μs | 32.6% ✅ |
🌀 Click to see Generated Regression Tests
from codeflash.code_utils.instrument_existing_tests import detect_frameworks_from_code

# -------------------------
# Basic Test Cases
# -------------------------


def test_no_imports_returns_empty_dict():
    """Test that code with no imports returns an empty dict."""
    code = "print('hello world')"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 35.1μs -> 25.6μs (37.2% faster)


def test_import_torch_basic():
    """Test detection of torch import."""
    code = "import torch"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 22.9μs -> 17.2μs (33.2% faster)


def test_import_tensorflow_basic():
    """Test detection of tensorflow import."""
    code = "import tensorflow"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 21.0μs -> 15.2μs (38.2% faster)


def test_import_jax_basic():
    """Test detection of jax import."""
    code = "import jax"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 20.9μs -> 14.6μs (43.1% faster)


def test_import_torch_with_alias():
    """Test detection of torch import with alias."""
    code = "import torch as th"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 22.1μs -> 15.3μs (44.4% faster)


def test_import_tensorflow_with_alias():
    """Test detection of tensorflow import with alias."""
    code = "import tensorflow as tf"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 21.0μs -> 15.0μs (40.0% faster)


def test_import_jax_with_alias():
    """Test detection of jax import with alias."""
    code = "import jax as j"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 21.8μs -> 14.9μs (46.1% faster)


def test_import_multiple_frameworks():
    """Test detection of multiple framework imports."""
    code = "import torch as th\nimport tensorflow as tf\nimport jax"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 32.0μs -> 22.8μs (40.2% faster)


def test_import_from_torch():
    """Test detection of 'from torch import ...'."""
    code = "from torch import nn"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 23.1μs -> 16.8μs (37.4% faster)


def test_import_from_tensorflow():
    """Test detection of 'from tensorflow import ...'."""
    code = "from tensorflow import keras"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 23.2μs -> 16.4μs (41.5% faster)


def test_import_from_jax():
    """Test detection of 'from jax import random'."""
    code = "from jax import random"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 22.8μs -> 16.1μs (41.4% faster)


def test_import_from_submodule():
    """Test detection of submodule import (from torch.nn import ...)."""
    code = "from torch.nn import Linear"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 24.5μs -> 17.3μs (41.6% faster)


def test_import_multiple_names_in_one_import():
    """Test detection when multiple names are imported in a single statement."""
    code = "import torch, tensorflow as tf, jax"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 27.3μs -> 19.1μs (43.1% faster)


# -------------------------
# Edge Test Cases
# -------------------------


def test_invalid_python_syntax():
    """Test that invalid python code returns empty dict."""
    code = "import torch as th\nthis is not valid python"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 39.4μs -> 38.4μs (2.64% faster)


def test_import_irrelevant_module():
    """Test that irrelevant imports are ignored."""
    code = "import numpy as np\nimport os"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 27.5μs -> 20.0μs (37.7% faster)


def test_import_from_irrelevant_module():
    """Test that from-imports from irrelevant modules are ignored."""
    code = "from numpy import array"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 23.4μs -> 16.7μs (40.4% faster)


def test_import_framework_with_submodule_and_alias():
    """Test import with submodule and alias (should only use top-level module)."""
    code = "import torch.cuda as tc"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 24.1μs -> 17.7μs (36.5% faster)


def test_import_framework_with_submodule_no_alias():
    """Test import with submodule and no alias."""
    code = "import tensorflow.keras"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 22.4μs -> 16.4μs (37.0% faster)


def test_import_framework_multiple_times():
    """Test that only the first from-import is used if multiple from-imports."""
    code = "from torch import nn\nfrom torch import optim"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 29.1μs -> 21.1μs (37.6% faster)


def test_import_framework_and_from_import():
    """Test import torch as th and from torch import nn."""
    code = "import torch as th\nfrom torch import nn"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 28.4μs -> 19.8μs (43.8% faster)


def test_import_framework_with_comments_and_whitespace():
    """Test import with comments and leading/trailing whitespace."""
    code = "  import torch as th  # This is torch\n\n"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 21.1μs -> 20.7μs (1.84% faster)


def test_import_framework_in_function_scope():
    """Test import inside a function (should still be detected)."""
    code = """
def foo():
    import torch as th
    return th
"""
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 41.7μs -> 31.0μs (34.6% faster)


def test_import_framework_in_class_scope():
    """Test import inside a class (should still be detected)."""
    code = """
class Foo:
    def __init__(self):
        import tensorflow as tf
        self.tf = tf
"""
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 59.8μs -> 40.7μs (47.1% faster)


def test_import_framework_with_weird_casing():
    """Test that import with weird casing is NOT detected (case-sensitive)."""
    code = "import Torch as th"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 21.6μs -> 16.4μs (32.2% faster)


def test_import_framework_as_keyword():
    """Test import where alias is a Python keyword."""
    code = "import torch as for"
    # This is a SyntaxError, so should return {}
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 23.0μs -> 23.1μs (0.432% slower)


def test_import_framework_as_none():
    """Test import with alias 'None' (should be a SyntaxError)."""
    code = "import torch as None"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 22.6μs -> 22.0μs (2.82% faster)


def test_import_framework_with_unicode():
    """Test import with unicode alias."""
    code = "import torch as трч"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 27.9μs -> 22.2μs (25.8% faster)


def test_import_framework_with_line_continuation():
    """Test import statement split across lines with backslash."""
    code = "import torch as th, \\\n    tensorflow as tf"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 25.8μs -> 17.4μs (48.9% faster)


def test_import_framework_with_semicolon():
    """Test import statements separated by semicolon."""
    code = "import torch as th; import tensorflow as tf"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 26.9μs -> 19.5μs (38.2% faster)


def test_import_framework_with_multiline_string():
    """Test code with a multiline string containing 'import torch' (should not be detected)."""
    code = '''
"""
import torch
"""
print("hello")
'''
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 38.1μs -> 27.4μs (39.0% faster)


def test_import_framework_with_inline_string():
    """Test code with an inline string containing 'import torch' (should not be detected)."""
    code = 'print("import torch")'
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 30.4μs -> 21.3μs (42.6% faster)


def test_import_framework_with_comment():
    """Test code with import in a comment (should not be detected)."""
    code = "# import torch as th"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 12.7μs -> 9.40μs (35.0% faster)


# -------------------------
# Large Scale Test Cases
# -------------------------


def test_large_number_of_irrelevant_imports():
    """Test code with many irrelevant imports and one relevant import."""
    code = "\n".join([f"import module{i}" for i in range(500)]) + "\nimport torch as th"
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 2.14ms -> 1.45ms (47.6% faster)


def test_large_number_of_framework_imports():
    """Test code with many repeated framework imports (should only detect first alias for each)."""
    code_lines = []
    for i in range(100):
        code_lines.append(f"import torch as th{i}")
        code_lines.append(f"import tensorflow as tf{i}")
        code_lines.append(f"import jax as j{i}")
    code = "\n".join(code_lines)
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 1.21ms -> 815μs (48.6% faster)


def test_large_code_with_imports_in_functions():
    """Test large code base with framework imports inside many functions."""
    code_lines = []
    for i in range(100):
        code_lines.append(f"def foo_{i}():\n    import torch as th{i}\n    return th{i}")
    code = "\n".join(code_lines)
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 1.45ms -> 863μs (67.6% faster)


def test_large_code_with_imports_in_classes():
    """Test large code base with framework imports inside many classes."""
    code_lines = []
    for i in range(100):
        code_lines.append(
            f"class Foo{i}:\n    def __init__(self):\n        import tensorflow as tf{i}\n        self.tf = tf{i}"
        )
    code = "\n".join(code_lines)
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 2.56ms -> 1.45ms (76.3% faster)


def test_large_code_with_frameworks_and_submodules():
    """Test large code base with mixture of import and from-import for frameworks and submodules."""
    code_lines = []
    for i in range(100):
        code_lines.append(f"import torch.nn as nn{i}")
        code_lines.append("from tensorflow.keras import Model")
        code_lines.append("from jax.random import PRNGKey")
    code = "\n".join(code_lines)
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 1.37ms -> 918μs (48.7% faster)


def test_large_code_with_no_frameworks():
    """Test large code base with no relevant framework imports."""
    code_lines = [f"import module{i} as m{i}" for i in range(1000)]
    code = "\n".join(code_lines)
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 4.79ms -> 3.14ms (52.5% faster)


def test_large_code_with_mixed_content():
    """Test large code base with irrelevant imports, framework imports, functions, and classes."""
    code_lines = [f"import module{i} as m{i}" for i in range(500)]
    code_lines += [
        "def foo():",
        "    import torch as th",
        "    return th",
        "class Bar:",
        "    def __init__(self):",
        "        import tensorflow as tf",
        "        self.tf = tf",
        "from jax import random",
        "print('done')",
    ]
    code = "\n".join(code_lines)
    codeflash_output = detect_frameworks_from_code(code)
    result = codeflash_output  # 2.12ms -> 1.33ms (59.3% faster)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run `git checkout codeflash/optimize-pr1015-2026-01-07T22.54.48` and push.

