@KRRT7 KRRT7 commented Jan 7, 2026

Summary

  • Extracts class definitions from project modules for types referenced in import statements
  • Helps LLM understand constructor signatures, base classes, and class structure
  • Fixes issues where LLM generates incorrect constructor calls (e.g., Element(text="...") when Element is abstract)

Problem

When generating tests for functions like will_fit in unstructured, the LLM only saw:

from unstructured.documents.elements import Element

Without the actual Element class definition, the LLM had no way to know:

  • The constructor arguments (e.g., it guessed Element(text="..."), which fails)
  • That Element is abstract and can't be instantiated directly

Solution

Now the testgen context includes the actual class definitions:

import abc

class Element(abc.ABC):
    def __init__(self, element_id: str = None):
        self._element_id = element_id
        self.text = ""
    ...
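
To make the failure mode concrete, here is a minimal sketch using a stand-in class. It is not the real unstructured Element: the abstract method to_dict, the subclass Text, and the helper try_construct are assumptions made for the example. It shows that a guessed call such as Element(text="...") raises TypeError, and that a concrete subclass still rejects the guessed keyword.

```python
from __future__ import annotations

import abc


class Element(abc.ABC):
    """Stand-in abstract base class; NOT the real unstructured class."""

    def __init__(self, element_id: str | None = None):
        self._element_id = element_id
        self.text = ""

    @abc.abstractmethod
    def to_dict(self) -> dict:  # hypothetical abstract method
        ...


class Text(Element):
    """Hypothetical concrete subclass."""

    def to_dict(self) -> dict:
        return {"element_id": self._element_id, "text": self.text}


def try_construct(cls, **kwargs):
    """Return the instance, or the TypeError message if construction fails."""
    try:
        return cls(**kwargs)
    except TypeError as exc:
        return f"TypeError: {exc}"


abstract_result = try_construct(Element, text="hello")  # abstract class: cannot instantiate
bad_kwarg_result = try_construct(Text, text="hello")    # __init__ has no 'text' parameter
ok_result = try_construct(Text, element_id="abc123")    # matches the visible signature
```

With the class definition in the prompt, both mistakes are visible before any test is generated.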

Changes

  • Add get_imported_class_definitions() function to code_context_extractor.py
  • Integrate into get_code_optimization_context() to include extracted classes
  • Handle token limits gracefully (drops class definitions if over limit)
  • Add 4 unit tests
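
The body of get_imported_class_definitions() is not shown in this thread. As a rough sketch of the general idea only, the following uses the standard ast module to pull class source text for names imported via from-imports. The function name extract_imported_classes and the module_sources mapping are hypothetical; a real implementation would resolve modules to files under the project root.

```python
from __future__ import annotations

import ast


def extract_imported_classes(code: str, module_sources: dict[str, str]) -> dict[str, str]:
    """Map each class name imported via `from X import Y` to its source text.

    `module_sources` maps dotted module names to source code (an assumption
    made to keep the sketch self-contained).
    """
    found: dict[str, str] = {}
    for node in ast.walk(ast.parse(code)):
        if not isinstance(node, ast.ImportFrom) or node.module not in module_sources:
            continue
        module_source = module_sources[node.module]
        wanted = {alias.name for alias in node.names}
        for stmt in ast.parse(module_source).body:
            # Only top-level classes, and only the first definition per name.
            if isinstance(stmt, ast.ClassDef) and stmt.name in wanted and stmt.name not in found:
                segment = ast.get_source_segment(module_source, stmt)
                if segment is not None:
                    found[stmt.name] = segment
    return found


elements_source = '''import abc

class Element(abc.ABC):
    def __init__(self, element_id=None):
        self._element_id = element_id

def helper():
    return 1
'''

context = extract_imported_classes(
    "from pkg.elements import Element\nimport os\n",
    {"pkg.elements": elements_source},
)
```

Here context holds a single entry for Element containing only the class body; the unrelated helper function and the non-project import os are left out.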

Test plan

  • All 35 context extractor tests pass
  • Verified extraction works with unstructured codebase (Element class extracted correctly)
  • Manual test: uv run codeflash --file unstructured/chunking/base.py --function will_fit

KRRT7 and others added 4 commits January 7, 2026 16:09
When generating tests, the LLM now receives class definitions for
types imported from project modules. This helps the LLM understand:
- Constructor signatures (avoiding incorrect argument guessing)
- Base classes (e.g., abstract classes that can't be instantiated)
- Class structure for creating proper test instances

Previously, the LLM only saw import statements like:
  from mypackage.elements import Element

Now it also sees the actual class definition with constructor details.

Changes:
- Add get_imported_class_definitions() to extract class definitions
  from project modules referenced in import statements
- Integrate into get_code_optimization_context() to include extracted
  classes in testgen context
- Gracefully handle token limits by dropping class definitions if needed
- Add 4 unit tests covering extraction, deduplication, and filtering
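
The token-limit fallback described above could look roughly like the sketch below. The names count_tokens and build_testgen_context are hypothetical, and the whitespace-based token count is a crude stand-in for whatever tokenizer the project actually uses.

```python
def count_tokens(text: str) -> int:
    """Crude stand-in for a tokenizer: counts whitespace-separated words."""
    return len(text.split())


def build_testgen_context(base_context: str, class_definitions: str, token_limit: int) -> str:
    """Prepend class definitions only if the combined context fits the limit."""
    if not class_definitions:
        return base_context
    combined = f"{class_definitions}\n\n{base_context}"
    if count_tokens(combined) > token_limit:
        # Over budget: fall back to the context without class definitions.
        return base_context
    return combined


base = "def will_fit(element): return True"
classes = "class Element: pass"
within_limit = build_testgen_context(base, classes, token_limit=50)
over_limit = build_testgen_context(base, classes, token_limit=5)
```

The design choice is that class definitions are a best-effort addition: when the budget is exceeded, the context degrades to what it was before this PR rather than failing.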
Comment on lines +640 to +657
    # Find imports that provide these names
    import_lines: list[str] = []
    source_lines = module_source.split("\n")

    for node in module_tree.body:
        if isinstance(node, ast.Import):
            for alias in node.names:
                name = alias.asname if alias.asname else alias.name.split(".")[0]
                if name in needed_names:
                    import_lines.append(source_lines[node.lineno - 1])
                    break
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                name = alias.asname if alias.asname else alias.name
                if name in needed_names:
                    import_lines.append(source_lines[node.lineno - 1])
                    break


⚡️Codeflash found 70% (0.70x) speedup for _extract_imports_for_class in codeflash/context/code_context_extractor.py

⏱️ Runtime : 5.37 milliseconds → 3.16 milliseconds (best of 5 runs)

⚡️ This change will improve the performance of the following benchmarks:

| Benchmark File :: Function | Original Runtime | Expected New Runtime | Speedup |
|---|---|---|---|
| tests.benchmarks.test_benchmark_code_extract_code_context::test_benchmark_extract | 44.3 seconds | 44.3 seconds | 0.00% |


📝 Explanation and details

The optimized code achieves a 70% speedup through three key optimizations that work synergistically:

Key Optimizations

1. Early Exit for Empty Base Classes

if not needed_names:
    return ""

When a class has no base classes, needed_names is empty and the function returns immediately without scanning the AST or splitting the source. This optimization shows dramatic gains in test cases like test_no_base_class (91% faster) and test_performance_with_many_unused_imports (3432% faster). A class whose only base is a built-in such as object does not appear to take this path: test_no_import_needed is 19.8% slower in the generated tests.
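
How needed_names is built is not shown in this comment. The sketch below is a guess that fits the timings reported for the generated tests: plain names and one-level attribute bases contribute a name, while deeper attribute chains and subscripts do not. The helper names collect_needed_names and needed_for are hypothetical.

```python
import ast


def collect_needed_names(class_node: ast.ClassDef) -> set:
    """Collect names that an import would have to provide for the class bases."""
    needed = set()
    for base in class_node.bases:
        if isinstance(base, ast.Name):
            # class Bar(Foo) -> "Foo"
            needed.add(base.id)
        elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name):
            # class Bar(abc.ABC) -> "abc"
            needed.add(base.value.id)
        # Assumption: deeper chains (foo.bar.Baz) and subscripts (Generic[int]) are skipped.
    return needed


def needed_for(source: str) -> set:
    """Parse `source` and return the needed names for its first class."""
    tree = ast.parse(source)
    class_node = next(n for n in tree.body if isinstance(n, ast.ClassDef))
    return collect_needed_names(class_node)


no_bases = needed_for("class A:\n    pass\n")
object_base = needed_for("class A(object):\n    pass\n")
attr_base = needed_for("import abc\nclass A(abc.ABC):\n    pass\n")
nested_attr = needed_for("import foo\nclass A(foo.bar.Baz):\n    pass\n")
```

Under this assumption only no_bases and nested_attr are empty, so only those cases would hit the early return.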

2. Early Termination via Tracking Remaining Names

The optimized version maintains a remaining_names set that shrinks as imports are found:

remaining_names = needed_names.copy()
# ... during iteration ...
if not remaining_names:
    break

This allows the loop to exit as soon as all needed imports are found, rather than scanning the entire AST. For large modules with many imports at the top, this cuts iteration count roughly in half (29,681 → 14,715 hits on the main loop).
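
The effect of the remaining_names check on iteration count can be sketched in isolation. The function nodes_visited below is a hypothetical helper written for this illustration; it mirrors the matching rules from the snippet under review and simply counts how many top-level nodes are examined.

```python
import ast


def nodes_visited(source: str, needed_names: set, early_exit: bool) -> int:
    """Count top-level nodes examined while matching imports to needed names."""
    remaining = set(needed_names)
    visited = 0
    for node in ast.parse(source).body:
        if early_exit and not remaining:
            break
        visited += 1
        if isinstance(node, ast.Import):
            names = [a.asname or a.name.split(".")[0] for a in node.names]
        elif isinstance(node, ast.ImportFrom):
            names = [a.asname or a.name for a in node.names]
        else:
            continue
        for name in names:
            if name in remaining:
                remaining.discard(name)
                break
    return visited


source = "\n".join([f"import Foo{i}" for i in range(50)] + ["class Bar(Foo3):\n    pass"])
full_scan = nodes_visited(source, {"Foo3"}, early_exit=False)
short_scan = nodes_visited(source, {"Foo3"}, early_exit=True)
```

With 50 imports and the needed one near the top, the early-exit variant stops after the fourth node instead of walking all 51.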

3. Deferred Source Splitting

The original code splits module_source into lines immediately, costing ~6% of runtime even when no imports are found. The optimized version:

  • First collects line numbers of matching imports
  • Only splits the source if imports were actually found
  • Uses list comprehension for efficient line extraction

# Only split when needed
if not import_line_numbers:
    return ""
source_lines = module_source.split("\n")
import_lines = [source_lines[line_num] for line_num in import_line_numbers]
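
A small standalone sketch of the two strategies, to show they produce the same result while the deferred one skips the split when nothing matched. The function names are hypothetical, and joining the selected lines with newlines is an assumption about the return format (the generated tests split the result on "\n").

```python
def pick_lines_eager(module_source: str, line_numbers: list) -> str:
    """Split first, then select: the source is split even when nothing matched."""
    source_lines = module_source.split("\n")
    return "\n".join(source_lines[n] for n in line_numbers)


def pick_lines_deferred(module_source: str, line_numbers: list) -> str:
    """Return early when nothing matched; split only if there is something to select."""
    if not line_numbers:
        return ""
    source_lines = module_source.split("\n")
    return "\n".join([source_lines[n] for n in line_numbers])


module_source = "import abc\nimport os\n\nclass A(abc.ABC):\n    pass\n"
eager_hit = pick_lines_eager(module_source, [0])
deferred_hit = pick_lines_deferred(module_source, [0])
eager_miss = pick_lines_eager(module_source, [])
deferred_miss = pick_lines_deferred(module_source, [])
```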

Performance Characteristics

Based on the annotated tests, the optimization excels when:

  • Classes have no base classes (91-3432% faster) - early exit avoids all work
  • Many imports but few used (66-182% faster) - early termination stops scanning quickly
  • Large-scale scenarios (42-120% faster) - cumulative effect of all optimizations

The optimization is slower (roughly 8-35% in the generated tests) for small, simple cases due to the overhead of copying needed_names and checking remaining_names. However, these cases are extremely fast in absolute terms (microseconds), making the slowdown negligible.

Impact on Workloads

Looking at the function_references, this function is called from get_imported_class_definitions(), which processes every imported class name in a code context. The optimization is particularly beneficial because:

  1. Many classes have no bases → early exit provides major gains
  2. Real-world modules often have many imports at the top → early termination saves significant work
  3. The function is called repeatedly in a loop → per-call savings compound

The worst case is unchanged: each call can still scan every top-level statement, so the total cost remains O(imports × classes). In most practical cases, however, the scan now terminates early, which makes the function well-suited for the hot path where it's used to extract class context for code analysis.

Correctness verification report:

| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 344 Passed |
| ⏪ Replay Tests | 11 Passed |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |

🌀 Generated Regression Tests

from __future__ import annotations

import ast

# imports
from codeflash.context.code_context_extractor import _extract_imports_for_class

# unit tests

# Basic Test Cases


def test_single_import_for_base_class():
    # Class inherits from Foo, which is imported
    source = "import Foo\nclass Bar(Foo):\n    pass\n"
    tree = ast.parse(source)
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 4.23μs -> 5.74μs (26.4% slower)


def test_import_from_for_base_class():
    # Class inherits from Baz, which is imported from a module
    source = "from mod import Baz\nclass Qux(Baz):\n    pass\n"
    tree = ast.parse(source)
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 3.09μs -> 4.62μs (33.0% slower)


def test_multiple_base_classes():
    # Class inherits from multiple bases, both imported
    source = "from mod import Baz\nimport Foo\nclass Qux(Foo, Baz):\n    pass\n"
    tree = ast.parse(source)
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 4.05μs -> 5.45μs (25.7% slower)
    lines = result.split("\n")


def test_import_with_asname():
    # Class inherits from Foo, which is imported as Bar
    source = "import Foo as Bar\nclass Baz(Bar):\n    pass\n"
    tree = ast.parse(source)
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 2.96μs -> 4.13μs (28.5% slower)


def test_importfrom_with_asname():
    # Class inherits from Baz, which is imported as Qux
    source = "from mod import Baz as Qux\nclass Foo(Qux):\n    pass\n"
    tree = ast.parse(source)
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 2.95μs -> 4.27μs (30.9% slower)


def test_attribute_base_class():
    # Class inherits from abc.ABC, which is imported
    source = "import abc\nclass MyClass(abc.ABC):\n    pass\n"
    tree = ast.parse(source)
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 3.18μs -> 4.62μs (31.2% slower)


def test_importfrom_attribute_base_class():
    # Class inherits from foo.Bar, foo is imported from somewhere
    source = "from baz import foo\nclass MyClass(foo.Bar):\n    pass\n"
    tree = ast.parse(source)
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 3.17μs -> 4.57μs (30.5% slower)


def test_no_import_needed():
    # Class inherits from built-in object, no import needed
    source = "class MyClass(object):\n    pass\n"
    tree = ast.parse(source)
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 2.14μs -> 2.67μs (19.8% slower)


def test_no_base_class():
    # Class does not inherit from anything
    source = "class MyClass:\n    pass\n"
    tree = ast.parse(source)
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 1.89μs -> 988ns (90.8% faster)


# Edge Test Cases


def test_import_not_included_if_not_base():
    # Foo is imported, but not used as base class
    source = "import Foo\nclass Bar:\n    pass\n"
    tree = ast.parse(source)
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 2.59μs -> 917ns (182% faster)


def test_multiple_imports_with_similar_names():
    # Both Foo and FooBar imported, only Foo used as base
    source = "import Foo\nimport FooBar\nclass Bar(Foo):\n    pass\n"
    tree = ast.parse(source)
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 3.53μs -> 4.56μs (22.7% slower)


def test_importfrom_multiple_names():
    # from mod import Foo, Bar; only Foo used as base
    source = "from mod import Foo, Bar\nclass Baz(Foo):\n    pass\n"
    tree = ast.parse(source)
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 2.79μs -> 4.30μs (35.1% slower)


def test_importfrom_with_asname_and_normal():
    # from mod import Foo as Bar, Baz; Bar used as base
    source = "from mod import Foo as Bar, Baz\nclass Qux(Bar):\n    pass\n"
    tree = ast.parse(source)
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 2.87μs -> 4.28μs (33.0% slower)


def test_import_with_dotted_name():
    # import foo.bar; base class is bar (should not match)
    source = "import foo.bar\nclass Baz(bar):\n    pass\n"
    tree = ast.parse(source)
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 2.83μs -> 3.23μs (12.3% slower)


def test_importfrom_star():
    # from mod import *; base class is Foo (should not match)
    source = "from mod import *\nclass Baz(Foo):\n    pass\n"
    tree = ast.parse(source)
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 2.72μs -> 2.63μs (3.50% faster)


def test_import_with_comment():
    # Import line has a comment, should still match
    source = "import Foo  # comment here\nclass Bar(Foo):\n    pass\n"
    tree = ast.parse(source)
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 3.15μs -> 4.16μs (24.1% slower)


def test_importfrom_with_comment():
    # ImportFrom line has a comment
    source = "from mod import Foo  # another comment\nclass Bar(Foo):\n    pass\n"
    tree = ast.parse(source)
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 2.86μs -> 4.12μs (30.6% slower)


def test_base_class_is_attribute_of_attribute():
    # Class inherits from foo.bar.Baz, only foo imported
    source = "import foo\nclass MyClass(foo.bar.Baz):\n    pass\n"
    tree = ast.parse(source)
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    # Should extract import for 'foo'
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 2.80μs -> 1.18μs (139% faster)


def test_base_class_is_attribute_of_attribute_from_import():
    # Class inherits from foo.bar.Baz, foo imported from somewhere
    source = "from baz import foo\nclass MyClass(foo.bar.Baz):\n    pass\n"
    tree = ast.parse(source)
    class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
    # Should extract import for 'foo'
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 2.71μs -> 1.32μs (106% faster)


# Large Scale Test Cases


def test_many_imports_and_classes():
    # Many imports and classes, only some imports relevant
    imports = [f"import Foo{i}" for i in range(50)]
    classes = [f"class Bar{i}(Foo{i}):\n    pass" for i in range(50)]
    source = "\n".join(imports + classes)
    tree = ast.parse(source)
    for i in range(50):
        class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef) and node.name == f"Bar{i}")
        codeflash_output = _extract_imports_for_class(tree, class_node, source)
        result = codeflash_output  # 915μs -> 529μs (73.1% faster)


def test_many_importfrom_and_classes():
    # Many from-imports and classes, only some imports relevant
    imports = [f"from mod{i} import Baz{i}" for i in range(50)]
    classes = [f"class Qux{i}(Baz{i}):\n    pass" for i in range(50)]
    source = "\n".join(imports + classes)
    tree = ast.parse(source)
    for i in range(50):
        class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef) and node.name == f"Qux{i}")
        codeflash_output = _extract_imports_for_class(tree, class_node, source)
        result = codeflash_output  # 873μs -> 510μs (71.0% faster)


def test_large_mixed_imports_and_classes():
    # Mix of import, importfrom, asname, and many classes
    imports = []
    classes = []
    for i in range(25):
        imports.append(f"import Foo{i} as Bar{i}")
        classes.append(f"class Baz{i}(Bar{i}):\n    pass")
    for i in range(25, 50):
        imports.append(f"from mod{i} import Baz{i} as Qux{i}")
        classes.append(f"class Foo{i}(Qux{i}):\n    pass")
    source = "\n".join(imports + classes)
    tree = ast.parse(source)
    for i in range(25):
        class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef) and node.name == f"Baz{i}")
        codeflash_output = _extract_imports_for_class(tree, class_node, source)
        result = codeflash_output  # 432μs -> 196μs (120% faster)
    for i in range(25, 50):
        class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef) and node.name == f"Foo{i}")
        codeflash_output = _extract_imports_for_class(tree, class_node, source)
        result = codeflash_output  # 421μs -> 296μs (42.0% faster)


def test_large_scale_attribute_base_classes():
    # Many classes inheriting from mod.FooX, mod imported via from-import
    imports = ["from mod import mod"]
    classes = [f"class Bar{i}(mod.Foo{i}):\n    pass" for i in range(50)]
    source = "\n".join(imports + classes)
    tree = ast.parse(source)
    for i in range(50):
        class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef) and node.name == f"Bar{i}")
        codeflash_output = _extract_imports_for_class(tree, class_node, source)
        result = codeflash_output  # 371μs -> 215μs (72.5% faster)


def test_large_scale_no_imports_needed():
    # Many classes inheriting from 'object', no imports needed
    source = "\n".join([f"class Foo{i}(object):\n    pass" for i in range(100)])
    tree = ast.parse(source)
    for i in range(100):
        class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef) and node.name == f"Foo{i}")
        codeflash_output = _extract_imports_for_class(tree, class_node, source)
        result = codeflash_output  # 1.28ms -> 809μs (57.6% faster)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import ast  # used to create AST nodes for testing

# imports
from codeflash.context.code_context_extractor import _extract_imports_for_class

# Basic Test Cases


def test_single_base_class_with_simple_import():
    """Test extracting import for a class with one base class from simple import."""
    # Create source code with import and class definition
    source = "import BaseClass\n\nclass MyClass(BaseClass):\n    pass"
    # Parse the source into an AST
    tree = ast.parse(source)
    # Get the class node (second item in body after import)
    class_node = tree.body[1]

    # Call function and verify it extracts the import
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 3.21μs -> 4.14μs (22.4% slower)


def test_single_base_class_with_from_import():
    """Test extracting import for a class with base class from 'from...import' statement."""
    # Create source with from-import statement
    source = "from module import BaseClass\n\nclass MyClass(BaseClass):\n    pass"
    tree = ast.parse(source)
    class_node = tree.body[1]

    # Verify the from-import is extracted
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 3.00μs -> 4.27μs (29.8% slower)


def test_multiple_base_classes_different_imports():
    """Test extracting multiple imports for class with multiple base classes."""
    # Create source with two imports and class inheriting from both
    source = "import Base1\nfrom module import Base2\n\nclass MyClass(Base1, Base2):\n    pass"
    tree = ast.parse(source)
    class_node = tree.body[2]

    # Both imports should be extracted
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 3.94μs -> 5.37μs (26.5% slower)


def test_no_base_classes():
    """Test that no imports are extracted when class has no base classes."""
    # Create source with import but class doesn't use it
    source = "import SomeModule\n\nclass MyClass:\n    pass"
    tree = ast.parse(source)
    class_node = tree.body[1]

    # Should return empty string since no base classes
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 2.40μs -> 978ns (146% faster)


def test_unused_imports_not_extracted():
    """Test that imports not used by base classes are not extracted."""
    # Create source with multiple imports but only one used
    source = "import Unused1\nimport Unused2\nfrom module import BaseClass\n\nclass MyClass(BaseClass):\n    pass"
    tree = ast.parse(source)
    class_node = tree.body[3]

    # Only the used import should be extracted
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 3.89μs -> 4.86μs (20.0% slower)


# Edge Test Cases


def test_attribute_base_class():
    """Test extracting import for base class specified as attribute (e.g., abc.ABC)."""
    # Create source with module import and class using module.Class syntax
    source = "import abc\n\nclass MyClass(abc.ABC):\n    pass"
    tree = ast.parse(source)
    class_node = tree.body[1]

    # Should extract the module import
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 3.20μs -> 4.36μs (26.6% slower)


def test_import_with_alias():
    """Test extracting import when base class is imported with alias."""
    # Create source with aliased import
    source = "import original_module as om\n\nclass MyClass(om):\n    pass"
    tree = ast.parse(source)
    class_node = tree.body[1]

    # Should extract the aliased import
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 2.94μs -> 3.91μs (24.7% slower)


def test_from_import_with_alias():
    """Test extracting from-import when base class has alias."""
    # Create source with aliased from-import
    source = "from module import OriginalClass as OC\n\nclass MyClass(OC):\n    pass"
    tree = ast.parse(source)
    class_node = tree.body[1]

    # Should extract the aliased from-import
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 3.08μs -> 4.10μs (25.0% slower)


def test_empty_module():
    """Test handling of empty module with no imports."""
    # Create minimal source with just a class
    source = "class MyClass:\n    pass"
    tree = ast.parse(source)
    class_node = tree.body[0]

    # Should return empty string
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 1.89μs -> 1.03μs (82.8% faster)


def test_base_class_without_import():
    """Test class with base class that has no corresponding import."""
    # Create source where base class is not imported (might be builtin or error)
    source = "class MyClass(SomeClass):\n    pass"
    tree = ast.parse(source)
    class_node = tree.body[0]

    # Should return empty string since no matching import
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 2.10μs -> 2.28μs (7.94% slower)


def test_dotted_import_name():
    """Test import with dotted name (e.g., import os.path)."""
    # Create source with dotted import
    source = "import os.path\n\nclass MyClass(os):\n    pass"
    tree = ast.parse(source)
    class_node = tree.body[1]

    # Should extract the import, matching on first part of dotted name
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 3.21μs -> 4.59μs (29.9% slower)


def test_multiple_names_in_single_import():
    """Test import statement with multiple names where only one is used."""
    # Create source with comma-separated imports
    source = "from module import Class1, Class2, Class3\n\nclass MyClass(Class2):\n    pass"
    tree = ast.parse(source)
    class_node = tree.body[1]

    # Should extract the entire import line
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 3.27μs -> 4.38μs (25.3% slower)


def test_nested_attribute_base_class():
    """Test base class with nested attributes (e.g., module.submodule.Class)."""
    # Create source with import and nested attribute usage
    source = "import collections.abc\n\nclass MyClass(collections.abc):\n    pass"
    tree = ast.parse(source)
    class_node = tree.body[1]

    # Should extract import matching the first part of attribute chain
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 3.31μs -> 4.65μs (28.7% slower)


def test_multiple_classes_same_import():
    """Test that same import is extracted for different classes using it."""
    # Create source with one import and two classes using it
    source = "from module import Base\n\nclass Class1(Base):\n    pass\n\nclass Class2(Base):\n    pass"
    tree = ast.parse(source)

    # Test first class
    class_node1 = tree.body[1]
    codeflash_output = _extract_imports_for_class(tree, class_node1, source)
    result1 = codeflash_output  # 3.17μs -> 4.37μs (27.4% slower)

    # Test second class
    class_node2 = tree.body[2]
    codeflash_output = _extract_imports_for_class(tree, class_node2, source)
    result2 = codeflash_output  # 1.44μs -> 1.92μs (25.0% slower)


def test_multiline_source_correct_line_extraction():
    """Test that correct line is extracted from multiline source."""
    # Create source with multiple lines and imports at different positions
    source = "# Comment\nimport Base1\n\nfrom module import Base2\n\nclass MyClass(Base2):\n    pass"
    tree = ast.parse(source)
    class_node = tree.body[2]  # Third statement (after two imports)

    # Should extract the correct line
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 3.48μs -> 4.70μs (26.0% slower)


def test_import_star():
    """Test handling of 'from module import *' statement."""
    # Create source with star import
    source = "from module import *\n\nclass MyClass(SomeBase):\n    pass"
    tree = ast.parse(source)
    class_node = tree.body[1]

    # Star imports have special handling - check if extracted
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 2.77μs -> 2.63μs (5.16% faster)


def test_relative_import():
    """Test handling of relative imports (from . import or from .. import)."""
    # Create source with relative import
    source = "from .module import BaseClass\n\nclass MyClass(BaseClass):\n    pass"
    tree = ast.parse(source)
    class_node = tree.body[1]

    # Should extract relative import
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 2.80μs -> 4.11μs (31.9% slower)


def test_base_class_as_call_not_name():
    """Test that base classes specified as subscripts (not plain names) don't cause errors."""
    # Create source where the base class is a subscript expression (e.g., Generic[int])
    source = "from typing import Generic\n\nclass MyClass(Generic[int]):\n    pass"
    tree = ast.parse(source)
    class_node = tree.body[1]

    # Subscripted bases should be handled gracefully (not extracted)
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 2.56μs -> 1.23μs (109% faster)


# Large Scale Test Cases


def test_many_base_classes():
    """Test class with many base classes requiring many imports."""
    # Create source with 50 different base classes
    num_bases = 50
    imports = [f"from module{i} import Base{i}" for i in range(num_bases)]
    bases = ", ".join([f"Base{i}" for i in range(num_bases)])
    source = "\n".join(imports) + f"\n\nclass MyClass({bases}):\n    pass"

    tree = ast.parse(source)
    class_node = tree.body[num_bases]  # After all imports

    # All imports should be extracted
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 25.1μs -> 29.6μs (15.2% slower)
    result_lines = result.split("\n")
    for i in range(num_bases):
        pass


def test_many_imports_in_module():
    """Test module with many imports where only few are used."""
    # Create source with 100 imports but only 3 used
    num_imports = 100
    imports = [f"from module{i} import Class{i}" for i in range(num_imports)]
    source = "\n".join(imports) + "\n\nclass MyClass(Class10, Class50, Class99):\n    pass"

    tree = ast.parse(source)
    class_node = tree.body[num_imports]

    # Only 3 imports should be extracted
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 29.3μs -> 32.2μs (9.27% slower)
    result_lines = result.split("\n")


def test_large_source_file():
    """Test extraction from large source file with many lines."""
    # Create source with 500 lines including comments, blank lines, and code
    lines = []
    lines.append("import BaseClass")
    for i in range(498):
        lines.append(f"# Comment line {i}")
    lines.append("class MyClass(BaseClass):\n    pass")
    source = "\n".join(lines)

    tree = ast.parse(source)
    # Find the class node (should be last statement)
    class_node = [node for node in tree.body if isinstance(node, ast.ClassDef)][0]

    # Should correctly extract import from line 0
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 38.8μs -> 23.3μs (66.9% faster)


def test_complex_inheritance_hierarchy():
    """Test class with complex mix of import types and base classes."""
    # Create source with various import styles
    source = """import abc
import typing
from collections.abc import Mapping
from typing import Generic, TypeVar
from custom import CustomBase as CB

T = TypeVar('T')

class MyClass(abc.ABC, Mapping, CB):
    pass"""

    tree = ast.parse(source)
    # Find the class node
    class_node = [node for node in tree.body if isinstance(node, ast.ClassDef)][0]

    # Should extract all relevant imports
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 5.91μs -> 7.53μs (21.5% slower)


def test_performance_with_many_unused_imports():
    """Test performance when module has many imports but class uses none."""
    # Create source with 200 imports but class has no bases
    num_imports = 200
    imports = [f"import module{i}" for i in range(num_imports)]
    source = "\n".join(imports) + "\n\nclass MyClass:\n    pass"

    tree = ast.parse(source)
    class_node = tree.body[num_imports]

    # Should quickly return empty string
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 50.7μs -> 1.44μs (3432% faster)


def test_mixed_import_styles_large_scale():
    """Test large module with mixed import styles (import, from-import, aliases)."""
    # Create source with 100 mixed-style imports
    imports = []
    for i in range(100):
        if i % 3 == 0:
            imports.append(f"import module{i}")
        elif i % 3 == 1:
            imports.append(f"from package{i} import Class{i}")
        else:
            imports.append(f"from package{i} import Class{i} as C{i}")

    # Use every 10th import as base class
    bases = []
    expected_imports = []
    for i in range(0, 100, 10):
        if i % 3 == 0:
            bases.append(f"module{i}")
            expected_imports.append(f"import module{i}")
        elif i % 3 == 1:
            bases.append(f"Class{i}")
            expected_imports.append(f"from package{i} import Class{i}")
        else:
            bases.append(f"C{i}")
            expected_imports.append(f"from package{i} import Class{i} as C{i}")

    source = "\n".join(imports) + f"\n\nclass MyClass({', '.join(bases)}):\n    pass"
    tree = ast.parse(source)
    class_node = tree.body[100]

    # Should extract all 10 used imports
    codeflash_output = _extract_imports_for_class(tree, class_node, source)
    result = codeflash_output  # 32.0μs -> 36.6μs (12.7% slower)
    result_lines = result.split("\n")
    for expected in expected_imports:
        pass


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
⏪ Replay Tests

| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
|---|---|---|---|
| benchmarks/codeflash_replay_tests_fuodcj9h/test_tests_benchmarks_test_benchmark_code_extract_code_context__replay_test_0.py::test_codeflash_context_code_context_extractor__extract_imports_for_class_test_benchmark_extract | 786μs | 332μs | 136%✅ |

To test or edit this optimization locally, run: git merge codeflash/optimize-pr1014-2026-01-07T21.53.54

Suggested change
-    # Find imports that provide these names
-    import_lines: list[str] = []
-    source_lines = module_source.split("\n")
-    for node in module_tree.body:
-        if isinstance(node, ast.Import):
-            for alias in node.names:
-                name = alias.asname if alias.asname else alias.name.split(".")[0]
-                if name in needed_names:
-                    import_lines.append(source_lines[node.lineno - 1])
-                    break
-        elif isinstance(node, ast.ImportFrom):
-            for alias in node.names:
-                name = alias.asname if alias.asname else alias.name
-                if name in needed_names:
-                    import_lines.append(source_lines[node.lineno - 1])
-                    break
+    if not needed_names:
+        return ""
+    # Find imports that provide these names
+    import_line_numbers: list[int] = []
+    remaining_names = needed_names.copy()
+    for node in module_tree.body:
+        if not remaining_names:
+            break
+        if isinstance(node, ast.Import):
+            for alias in node.names:
+                name = alias.asname if alias.asname else alias.name.split(".")[0]
+                if name in remaining_names:
+                    import_line_numbers.append(node.lineno - 1)
+                    remaining_names.discard(name)
+                    break
+        elif isinstance(node, ast.ImportFrom):
+            for alias in node.names:
+                name = alias.asname if alias.asname else alias.name
+                if name in remaining_names:
+                    import_line_numbers.append(node.lineno - 1)
+                    remaining_names.discard(name)
+                    break
+    if not import_line_numbers:
+        return ""
+    source_lines = module_source.split("\n")
+    import_lines: list[str] = [source_lines[line_num] for line_num in import_line_numbers]
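
One behavioral difference may be worth checking before merging: because a matched name is discarded, a second top-level import that binds the same name is no longer collected. The sketch below uses standalone re-implementations of the two loops (original_lines and optimized_lines are hypothetical names, needed_names is passed in directly, and the modules fast_impl and fallback_impl are invented for the example). Whether such duplicate top-level imports occur in practice is not established by this thread.

```python
import ast


def _provided_name(node: ast.stmt, names: set):
    """Return the first name in `names` that this import statement binds, else None."""
    if isinstance(node, ast.Import):
        candidates = [a.asname or a.name.split(".")[0] for a in node.names]
    elif isinstance(node, ast.ImportFrom):
        candidates = [a.asname or a.name for a in node.names]
    else:
        return None
    return next((c for c in candidates if c in names), None)


def original_lines(tree: ast.Module, needed_names: set, source: str) -> list:
    """Mirror of the original loop: every matching import line is kept."""
    source_lines = source.split("\n")
    return [source_lines[n.lineno - 1] for n in tree.body if _provided_name(n, needed_names)]


def optimized_lines(tree: ast.Module, needed_names: set, source: str) -> list:
    """Mirror of the suggested loop: each needed name is matched at most once."""
    if not needed_names:
        return []
    remaining = set(needed_names)
    line_numbers = []
    for node in tree.body:
        if not remaining:
            break
        name = _provided_name(node, remaining)
        if name is not None:
            line_numbers.append(node.lineno - 1)
            remaining.discard(name)
    if not line_numbers:
        return []
    source_lines = source.split("\n")
    return [source_lines[i] for i in line_numbers]


source = "from fast_impl import Base\nfrom fallback_impl import Base\nclass A(Base):\n    pass\n"
tree = ast.parse(source)
before = original_lines(tree, {"Base"}, source)
after = optimized_lines(tree, {"Base"}, source)
```

In this contrived input, before contains both import lines while after contains only the first. The generated regression tests above do not appear to include a case with a repeated name.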

codeflash-ai bot commented Jan 7, 2026

⚡️ Codeflash found optimizations for this PR

📄 1,809% (18.09x) speedup for get_test_info_from_stack in codeflash/verification/codeflash_capture.py

⏱️ Runtime : 29.0 milliseconds → 1.52 milliseconds (best of 62 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch feat/extract-imported-class-definitions).


@KRRT7 KRRT7 merged commit 5ef4ed4 into main Jan 8, 2026
22 of 23 checks passed
@KRRT7 KRRT7 deleted the feat/extract-imported-class-definitions branch January 8, 2026 19:44