feat: extract imported class definitions for testgen context #1014
When generating tests, the LLM now receives class definitions for types imported from project modules. This helps the LLM understand:
- Constructor signatures (avoiding incorrect argument guessing)
- Base classes (e.g., abstract classes that can't be instantiated)
- Class structure for creating proper test instances

Previously, the LLM only saw import statements like:
from mypackage.elements import Element
Now it also sees the actual class definition with constructor details.

Changes:
- Add get_imported_class_definitions() to extract class definitions from project modules referenced in import statements
- Integrate into get_code_optimization_context() to include extracted classes in testgen context
- Gracefully handle token limits by dropping class definitions if needed
- Add 4 unit tests covering extraction, deduplication, and filtering
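As an illustration of the idea described above, here is a minimal sketch of how class definitions might be pulled in for imported names. It is not the PR's `get_imported_class_definitions()`: the function name `extract_imported_class_sources` and the `module_sources` mapping (module name to source text) are hypothetical stand-ins for whatever module resolution the project actually uses.

```python
from __future__ import annotations

import ast


def extract_imported_class_sources(code: str, module_sources: dict[str, str]) -> list[str]:
    """Return the source text of classes that `code` imports from known modules.

    `module_sources` is a hypothetical mapping of module name -> source text.
    """
    found: list[str] = []
    seen: set[tuple[str, str]] = set()  # (module, class name) pairs, for deduplication
    for node in ast.walk(ast.parse(code)):
        # Only `from X import Y` statements that point at a known project module.
        if not isinstance(node, ast.ImportFrom) or node.module not in module_sources:
            continue
        module_source = module_sources[node.module]
        module_tree = ast.parse(module_source)
        classes = {n.name: n for n in module_tree.body if isinstance(n, ast.ClassDef)}
        for alias in node.names:
            key = (node.module, alias.name)
            if alias.name in classes and key not in seen:
                seen.add(key)
                segment = ast.get_source_segment(module_source, classes[alias.name])
                if segment is not None:
                    found.append(segment)
    return found


# Usage: the module and class below are made up for the example.
project = {
    "mypackage.elements": (
        "from abc import ABC\n"
        "\n"
        "class Element(ABC):\n"
        "    def __init__(self, element_id: str, text: str = '') -> None:\n"
        "        self.element_id = element_id\n"
        "        self.text = text\n"
    )
}
code_under_test = "from mypackage.elements import Element\nimport os\n"
definitions = extract_imported_class_sources(code_under_test, project)
```

With the definition in the prompt, the constructor signature and the `ABC` base are visible to the model instead of being guessed from the import line alone.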
    # Find imports that provide these names
    import_lines: list[str] = []
    source_lines = module_source.split("\n")

    for node in module_tree.body:
        if isinstance(node, ast.Import):
            for alias in node.names:
                name = alias.asname if alias.asname else alias.name.split(".")[0]
                if name in needed_names:
                    import_lines.append(source_lines[node.lineno - 1])
                    break
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                name = alias.asname if alias.asname else alias.name
                if name in needed_names:
                    import_lines.append(source_lines[node.lineno - 1])
                    break
⚡️Codeflash found 70% (0.70x) speedup for _extract_imports_for_class in codeflash/context/code_context_extractor.py
⏱️ Runtime : 5.37 milliseconds → 3.16 milliseconds (best of 5 runs)
⚡️ This change will improve the performance of the following benchmarks:
| Benchmark File :: Function | Original Runtime | Expected New Runtime | Speedup |
|---|---|---|---|
| tests.benchmarks.test_benchmark_code_extract_code_context::test_benchmark_extract | 44.3 seconds | 44.3 seconds | 0.00% |
📝 Explanation and details
The optimized code achieves a 70% speedup through three key optimizations that work synergistically:
Key Optimizations
1. Early Exit for Empty Base Classes
if not needed_names:
    return ""
When a class has no base classes (or only built-in bases like object), the function immediately returns without scanning the AST or splitting the source. This optimization shows dramatic gains in test cases like test_no_base_class (91% faster) and test_performance_with_many_unused_imports (3432% faster).
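The early exit depends on how `needed_names` is built, which this comment does not show. The sketch below is one plausible way to collect it, inferred from the generated tests (plain names such as `Foo`, and single-level attributes such as `abc.ABC` resolving to `abc`); the helper name `collect_needed_names` is hypothetical and the real implementation may differ.

```python
from __future__ import annotations

import ast


def collect_needed_names(class_node: ast.ClassDef) -> set[str]:
    """Collect the top-level names a class's bases refer to (assumed behaviour)."""
    names: set[str] = set()
    for base in class_node.bases:
        if isinstance(base, ast.Name):
            # class A(Foo) -> needs "Foo"
            names.add(base.id)
        elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name):
            # class A(abc.ABC) -> needs "abc"
            names.add(base.value.id)
        # Other base forms (subscripts, deeper attribute chains) are ignored here.
    return names


def first_class(source: str) -> ast.ClassDef:
    """Return the first class definition in `source`."""
    return next(n for n in ast.parse(source).body if isinstance(n, ast.ClassDef))


no_bases = collect_needed_names(first_class("class A:\n    pass\n"))
two_bases = collect_needed_names(first_class("class A(Foo, abc.ABC):\n    pass\n"))
```

When the returned set is empty, the `if not needed_names` guard lets the caller skip both the AST scan and the source split.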
2. Early Termination via Tracking Remaining Names
The optimized version maintains a remaining_names set that shrinks as imports are found:
remaining_names = needed_names.copy()
# ... during iteration ...
if not remaining_names:
    break
This allows the loop to exit as soon as all needed imports are found, rather than scanning the entire AST. For large modules with many imports at the top, this cuts iteration count roughly in half (29,681 → 14,715 hits on the main loop).
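To make the effect of the shrinking set concrete, the following sketch counts how many top-level nodes are visited with and without the early break. The function `count_nodes_visited` and the synthetic module are illustrative only and are not part of the PR.

```python
from __future__ import annotations

import ast


def count_nodes_visited(module_tree: ast.Module, needed_names: set[str], early_exit: bool) -> int:
    """Count top-level nodes examined while resolving `needed_names`."""
    remaining = set(needed_names)
    visited = 0
    for node in module_tree.body:
        if early_exit and not remaining:
            break  # every needed import has been found
        visited += 1
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                if isinstance(node, ast.Import):
                    name = alias.asname or alias.name.split(".")[0]
                else:
                    name = alias.asname or alias.name
                if name in remaining:
                    remaining.discard(name)
                    break
    return visited


# 100 imports followed by one class that only needs the fourth import.
source = "\n".join([f"import mod{i}" for i in range(100)] + ["class C(mod3):", "    pass"])
tree = ast.parse(source)
full_scan = count_nodes_visited(tree, {"mod3"}, early_exit=False)
short_scan = count_nodes_visited(tree, {"mod3"}, early_exit=True)
```

The saving depends on where the needed imports sit: if the last needed import is near the end of the module, the early break removes very little work.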
3. Deferred Source Splitting
The original code splits module_source into lines immediately, costing ~6% of runtime even when no imports are found. The optimized version:
- First collects line numbers of matching imports
- Only splits the source if imports were actually found
- Uses list comprehension for efficient line extraction
# Only split when needed
if not import_line_numbers:
    return ""
source_lines = module_source.split("\n")
import_lines = [source_lines[line_num] for line_num in import_line_numbers]
Performance Characteristics
Based on the annotated tests, the optimization excels when:
- Classes have no base classes (91-3432% faster) - early exit avoids all work
- Many imports but few used (66-182% faster) - early termination stops scanning quickly
- Large-scale scenarios (42-120% faster) - cumulative effect of all optimizations
The optimization is slightly slower (12-35%) for small, simple cases due to the overhead of copying needed_names and checking remaining_names. However, these cases are extremely fast in absolute terms (microseconds), making the slowdown negligible.
Impact on Workloads
Looking at the function_references, this function is called from get_imported_class_definitions(), which processes every imported class name in a code context. The optimization is particularly beneficial because:
- Many classes inherit from built-ins or have no bases → early exit provides major gains
- Real-world modules often have many imports at the top → early termination saves significant work
- The function is called repeatedly in a loop → per-call savings compound
The optimization transforms this from an O(imports × classes) operation to one that terminates early in most practical cases, making it well-suited for the hot path where it's used to extract class context for code analysis.
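The three optimizations can be read together in the self-contained sketch below. It takes `needed_names` as a parameter and returns the matching lines joined with newlines; both the signature and the return format are assumptions made for illustration, since the real `_extract_imports_for_class` receives a class node and its full body is not shown here.

```python
from __future__ import annotations

import ast


def extract_import_lines(module_tree: ast.Module, needed_names: set[str], module_source: str) -> str:
    """Return the import lines that provide `needed_names` (assumed output format)."""
    # 1. Early exit: nothing to look for.
    if not needed_names:
        return ""

    line_numbers: list[int] = []
    remaining = set(needed_names)
    for node in module_tree.body:
        # 2. Early termination: stop once every name has been matched.
        if not remaining:
            break
        if isinstance(node, ast.Import):
            names = [a.asname or a.name.split(".")[0] for a in node.names]
        elif isinstance(node, ast.ImportFrom):
            names = [a.asname or a.name for a in node.names]
        else:
            continue
        matched = next((n for n in names if n in remaining), None)
        if matched is not None:
            line_numbers.append(node.lineno - 1)
            remaining.discard(matched)

    # 3. Deferred splitting: only split the source when something matched.
    if not line_numbers:
        return ""
    source_lines = module_source.split("\n")
    return "\n".join(source_lines[i] for i in line_numbers)


source = "import os\nfrom mod import Base as B\nimport sys\n\nclass C(B):\n    pass\n"
tree = ast.parse(source)
result = extract_import_lines(tree, {"B"}, source)
```

One behavioural difference is worth checking during review: if the same name is imported by two separate statements, a loop that tests against `needed_names` records both lines, whereas a loop that tests against the shrinking `remaining` set records only the first.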
✅ Correctness verification report:
| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 344 Passed |
| ⏪ Replay Tests | ✅ 11 Passed |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Click to see Generated Regression Tests
from __future__ import annotations
import ast
# imports
from codeflash.context.code_context_extractor import _extract_imports_for_class
# unit tests
# Basic Test Cases
def test_single_import_for_base_class():
# Class inherits from Foo, which is imported
source = "import Foo\nclass Bar(Foo):\n pass\n"
tree = ast.parse(source)
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 4.23μs -> 5.74μs (26.4% slower)
def test_import_from_for_base_class():
# Class inherits from Baz, which is imported from a module
source = "from mod import Baz\nclass Qux(Baz):\n pass\n"
tree = ast.parse(source)
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 3.09μs -> 4.62μs (33.0% slower)
def test_multiple_base_classes():
# Class inherits from multiple bases, both imported
source = "from mod import Baz\nimport Foo\nclass Qux(Foo, Baz):\n pass\n"
tree = ast.parse(source)
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 4.05μs -> 5.45μs (25.7% slower)
lines = result.split("\n")
def test_import_with_asname():
# Class inherits from Foo, which is imported as Bar
source = "import Foo as Bar\nclass Baz(Bar):\n pass\n"
tree = ast.parse(source)
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 2.96μs -> 4.13μs (28.5% slower)
def test_importfrom_with_asname():
# Class inherits from Baz, which is imported as Qux
source = "from mod import Baz as Qux\nclass Foo(Qux):\n pass\n"
tree = ast.parse(source)
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 2.95μs -> 4.27μs (30.9% slower)
def test_attribute_base_class():
# Class inherits from abc.ABC, which is imported
source = "import abc\nclass MyClass(abc.ABC):\n pass\n"
tree = ast.parse(source)
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 3.18μs -> 4.62μs (31.2% slower)
def test_importfrom_attribute_base_class():
# Class inherits from foo.Bar, foo is imported from somewhere
source = "from baz import foo\nclass MyClass(foo.Bar):\n pass\n"
tree = ast.parse(source)
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 3.17μs -> 4.57μs (30.5% slower)
def test_no_import_needed():
# Class inherits from built-in object, no import needed
source = "class MyClass(object):\n pass\n"
tree = ast.parse(source)
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 2.14μs -> 2.67μs (19.8% slower)
def test_no_base_class():
# Class does not inherit from anything
source = "class MyClass:\n pass\n"
tree = ast.parse(source)
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 1.89μs -> 988ns (90.8% faster)
# Edge Test Cases
def test_import_not_included_if_not_base():
# Foo is imported, but not used as base class
source = "import Foo\nclass Bar:\n pass\n"
tree = ast.parse(source)
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 2.59μs -> 917ns (182% faster)
def test_multiple_imports_with_similar_names():
# Both Foo and FooBar imported, only Foo used as base
source = "import Foo\nimport FooBar\nclass Bar(Foo):\n pass\n"
tree = ast.parse(source)
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 3.53μs -> 4.56μs (22.7% slower)
def test_importfrom_multiple_names():
# from mod import Foo, Bar; only Foo used as base
source = "from mod import Foo, Bar\nclass Baz(Foo):\n pass\n"
tree = ast.parse(source)
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 2.79μs -> 4.30μs (35.1% slower)
def test_importfrom_with_asname_and_normal():
# from mod import Foo as Bar, Baz; Bar used as base
source = "from mod import Foo as Bar, Baz\nclass Qux(Bar):\n pass\n"
tree = ast.parse(source)
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 2.87μs -> 4.28μs (33.0% slower)
def test_import_with_dotted_name():
# import foo.bar; base class is bar (should not match)
source = "import foo.bar\nclass Baz(bar):\n pass\n"
tree = ast.parse(source)
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 2.83μs -> 3.23μs (12.3% slower)
def test_importfrom_star():
# from mod import *; base class is Foo (should not match)
source = "from mod import *\nclass Baz(Foo):\n pass\n"
tree = ast.parse(source)
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 2.72μs -> 2.63μs (3.50% faster)
def test_import_with_comment():
# Import line has a comment, should still match
source = "import Foo # comment here\nclass Bar(Foo):\n pass\n"
tree = ast.parse(source)
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 3.15μs -> 4.16μs (24.1% slower)
def test_importfrom_with_comment():
# ImportFrom line has a comment
source = "from mod import Foo # another comment\nclass Bar(Foo):\n pass\n"
tree = ast.parse(source)
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 2.86μs -> 4.12μs (30.6% slower)
def test_base_class_is_attribute_of_attribute():
# Class inherits from foo.bar.Baz, only foo imported
source = "import foo\nclass MyClass(foo.bar.Baz):\n pass\n"
tree = ast.parse(source)
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
# Should extract import for 'foo'
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 2.80μs -> 1.18μs (139% faster)
def test_base_class_is_attribute_of_attribute_from_import():
# Class inherits from foo.bar.Baz, foo imported from somewhere
source = "from baz import foo\nclass MyClass(foo.bar.Baz):\n pass\n"
tree = ast.parse(source)
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef))
# Should extract import for 'foo'
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 2.71μs -> 1.32μs (106% faster)
# Large Scale Test Cases
def test_many_imports_and_classes():
# Many imports and classes, only some imports relevant
imports = [f"import Foo{i}" for i in range(50)]
classes = [f"class Bar{i}(Foo{i}):\n pass" for i in range(50)]
source = "\n".join(imports + classes)
tree = ast.parse(source)
for i in range(50):
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef) and node.name == f"Bar{i}")
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 915μs -> 529μs (73.1% faster)
def test_many_importfrom_and_classes():
# Many from-imports and classes, only some imports relevant
imports = [f"from mod{i} import Baz{i}" for i in range(50)]
classes = [f"class Qux{i}(Baz{i}):\n pass" for i in range(50)]
source = "\n".join(imports + classes)
tree = ast.parse(source)
for i in range(50):
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef) and node.name == f"Qux{i}")
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 873μs -> 510μs (71.0% faster)
def test_large_mixed_imports_and_classes():
# Mix of import, importfrom, asname, and many classes
imports = []
classes = []
for i in range(25):
imports.append(f"import Foo{i} as Bar{i}")
classes.append(f"class Baz{i}(Bar{i}):\n pass")
for i in range(25, 50):
imports.append(f"from mod{i} import Baz{i} as Qux{i}")
classes.append(f"class Foo{i}(Qux{i}):\n pass")
source = "\n".join(imports + classes)
tree = ast.parse(source)
for i in range(25):
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef) and node.name == f"Baz{i}")
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 432μs -> 196μs (120% faster)
for i in range(25, 50):
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef) and node.name == f"Foo{i}")
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 421μs -> 296μs (42.0% faster)
def test_large_scale_attribute_base_classes():
# Many classes inheriting from mod.FooX, mod imported via from-import
imports = ["from mod import mod"]
classes = [f"class Bar{i}(mod.Foo{i}):\n pass" for i in range(50)]
source = "\n".join(imports + classes)
tree = ast.parse(source)
for i in range(50):
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef) and node.name == f"Bar{i}")
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 371μs -> 215μs (72.5% faster)
def test_large_scale_no_imports_needed():
# Many classes inheriting from 'object', no imports needed
source = "\n".join([f"class Foo{i}(object):\n pass" for i in range(100)])
tree = ast.parse(source)
for i in range(100):
class_node = next(node for node in tree.body if isinstance(node, ast.ClassDef) and node.name == f"Foo{i}")
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 1.28ms -> 809μs (57.6% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import ast # used to create AST nodes for testing
# imports
from codeflash.context.code_context_extractor import _extract_imports_for_class
# Basic Test Cases
def test_single_base_class_with_simple_import():
"""Test extracting import for a class with one base class from simple import."""
# Create source code with import and class definition
source = "import BaseClass\n\nclass MyClass(BaseClass):\n pass"
# Parse the source into an AST
tree = ast.parse(source)
# Get the class node (second item in body after import)
class_node = tree.body[1]
# Call function and verify it extracts the import
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 3.21μs -> 4.14μs (22.4% slower)
def test_single_base_class_with_from_import():
"""Test extracting import for a class with base class from 'from...import' statement."""
# Create source with from-import statement
source = "from module import BaseClass\n\nclass MyClass(BaseClass):\n pass"
tree = ast.parse(source)
class_node = tree.body[1]
# Verify the from-import is extracted
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 3.00μs -> 4.27μs (29.8% slower)
def test_multiple_base_classes_different_imports():
"""Test extracting multiple imports for class with multiple base classes."""
# Create source with two imports and class inheriting from both
source = "import Base1\nfrom module import Base2\n\nclass MyClass(Base1, Base2):\n pass"
tree = ast.parse(source)
class_node = tree.body[2]
# Both imports should be extracted
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 3.94μs -> 5.37μs (26.5% slower)
def test_no_base_classes():
"""Test that no imports are extracted when class has no base classes."""
# Create source with import but class doesn't use it
source = "import SomeModule\n\nclass MyClass:\n pass"
tree = ast.parse(source)
class_node = tree.body[1]
# Should return empty string since no base classes
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 2.40μs -> 978ns (146% faster)
def test_unused_imports_not_extracted():
"""Test that imports not used by base classes are not extracted."""
# Create source with multiple imports but only one used
source = "import Unused1\nimport Unused2\nfrom module import BaseClass\n\nclass MyClass(BaseClass):\n pass"
tree = ast.parse(source)
class_node = tree.body[3]
# Only the used import should be extracted
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 3.89μs -> 4.86μs (20.0% slower)
# Edge Test Cases
def test_attribute_base_class():
"""Test extracting import for base class specified as attribute (e.g., abc.ABC)."""
# Create source with module import and class using module.Class syntax
source = "import abc\n\nclass MyClass(abc.ABC):\n pass"
tree = ast.parse(source)
class_node = tree.body[1]
# Should extract the module import
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 3.20μs -> 4.36μs (26.6% slower)
def test_import_with_alias():
"""Test extracting import when base class is imported with alias."""
# Create source with aliased import
source = "import original_module as om\n\nclass MyClass(om):\n pass"
tree = ast.parse(source)
class_node = tree.body[1]
# Should extract the aliased import
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 2.94μs -> 3.91μs (24.7% slower)
def test_from_import_with_alias():
"""Test extracting from-import when base class has alias."""
# Create source with aliased from-import
source = "from module import OriginalClass as OC\n\nclass MyClass(OC):\n pass"
tree = ast.parse(source)
class_node = tree.body[1]
# Should extract the aliased from-import
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 3.08μs -> 4.10μs (25.0% slower)
def test_empty_module():
"""Test handling of empty module with no imports."""
# Create minimal source with just a class
source = "class MyClass:\n pass"
tree = ast.parse(source)
class_node = tree.body[0]
# Should return empty string
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 1.89μs -> 1.03μs (82.8% faster)
def test_base_class_without_import():
"""Test class with base class that has no corresponding import."""
# Create source where base class is not imported (might be builtin or error)
source = "class MyClass(SomeClass):\n pass"
tree = ast.parse(source)
class_node = tree.body[0]
# Should return empty string since no matching import
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 2.10μs -> 2.28μs (7.94% slower)
def test_dotted_import_name():
"""Test import with dotted name (e.g., import os.path)."""
# Create source with dotted import
source = "import os.path\n\nclass MyClass(os):\n pass"
tree = ast.parse(source)
class_node = tree.body[1]
# Should extract the import, matching on first part of dotted name
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 3.21μs -> 4.59μs (29.9% slower)
def test_multiple_names_in_single_import():
"""Test import statement with multiple names where only one is used."""
# Create source with comma-separated imports
source = "from module import Class1, Class2, Class3\n\nclass MyClass(Class2):\n pass"
tree = ast.parse(source)
class_node = tree.body[1]
# Should extract the entire import line
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 3.27μs -> 4.38μs (25.3% slower)
def test_nested_attribute_base_class():
"""Test base class with nested attributes (e.g., module.submodule.Class)."""
# Create source with import and nested attribute usage
source = "import collections.abc\n\nclass MyClass(collections.abc):\n pass"
tree = ast.parse(source)
class_node = tree.body[1]
# Should extract import matching the first part of attribute chain
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 3.31μs -> 4.65μs (28.7% slower)
def test_multiple_classes_same_import():
"""Test that same import is extracted for different classes using it."""
# Create source with one import and two classes using it
source = "from module import Base\n\nclass Class1(Base):\n pass\n\nclass Class2(Base):\n pass"
tree = ast.parse(source)
# Test first class
class_node1 = tree.body[1]
codeflash_output = _extract_imports_for_class(tree, class_node1, source)
result1 = codeflash_output # 3.17μs -> 4.37μs (27.4% slower)
# Test second class
class_node2 = tree.body[2]
codeflash_output = _extract_imports_for_class(tree, class_node2, source)
result2 = codeflash_output # 1.44μs -> 1.92μs (25.0% slower)
def test_multiline_source_correct_line_extraction():
"""Test that correct line is extracted from multiline source."""
# Create source with multiple lines and imports at different positions
source = "# Comment\nimport Base1\n\nfrom module import Base2\n\nclass MyClass(Base2):\n pass"
tree = ast.parse(source)
class_node = tree.body[2] # Third statement (after two imports)
# Should extract the correct line
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 3.48μs -> 4.70μs (26.0% slower)
def test_import_star():
"""Test handling of 'from module import *' statement."""
# Create source with star import
source = "from module import *\n\nclass MyClass(SomeBase):\n pass"
tree = ast.parse(source)
class_node = tree.body[1]
# Star imports have special handling - check if extracted
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 2.77μs -> 2.63μs (5.16% faster)
def test_relative_import():
"""Test handling of relative imports (from . import or from .. import)."""
# Create source with relative import
source = "from .module import BaseClass\n\nclass MyClass(BaseClass):\n pass"
tree = ast.parse(source)
class_node = tree.body[1]
# Should extract relative import
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 2.80μs -> 4.11μs (31.9% slower)
def test_base_class_as_call_not_name():
"""Test that base classes specified as calls (not names) don't cause errors."""
# Create source where base class is a function call (e.g., Generic[T])
source = "from typing import Generic\n\nclass MyClass(Generic[int]):\n pass"
tree = ast.parse(source)
class_node = tree.body[1]
# Function calls in bases should be handled gracefully (not extracted)
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 2.56μs -> 1.23μs (109% faster)
# Large Scale Test Cases
def test_many_base_classes():
"""Test class with many base classes requiring many imports."""
# Create source with 50 different base classes
num_bases = 50
imports = [f"from module{i} import Base{i}" for i in range(num_bases)]
bases = ", ".join([f"Base{i}" for i in range(num_bases)])
source = "\n".join(imports) + f"\n\nclass MyClass({bases}):\n pass"
tree = ast.parse(source)
class_node = tree.body[num_bases] # After all imports
# All imports should be extracted
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 25.1μs -> 29.6μs (15.2% slower)
result_lines = result.split("\n")
for i in range(num_bases):
pass
def test_many_imports_in_module():
"""Test module with many imports where only few are used."""
# Create source with 100 imports but only 3 used
num_imports = 100
imports = [f"from module{i} import Class{i}" for i in range(num_imports)]
source = "\n".join(imports) + "\n\nclass MyClass(Class10, Class50, Class99):\n pass"
tree = ast.parse(source)
class_node = tree.body[num_imports]
# Only 3 imports should be extracted
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 29.3μs -> 32.2μs (9.27% slower)
result_lines = result.split("\n")
def test_large_source_file():
"""Test extraction from large source file with many lines."""
# Create source with 500 lines including comments, blank lines, and code
lines = []
lines.append("import BaseClass")
for i in range(498):
lines.append(f"# Comment line {i}")
lines.append("class MyClass(BaseClass):\n pass")
source = "\n".join(lines)
tree = ast.parse(source)
# Find the class node (should be last statement)
class_node = [node for node in tree.body if isinstance(node, ast.ClassDef)][0]
# Should correctly extract import from line 0
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 38.8μs -> 23.3μs (66.9% faster)
def test_complex_inheritance_hierarchy():
"""Test class with complex mix of import types and base classes."""
# Create source with various import styles
source = """import abc
import typing
from collections.abc import Mapping
from typing import Generic, TypeVar
from custom import CustomBase as CB
T = TypeVar('T')
class MyClass(abc.ABC, Mapping, CB):
pass"""
tree = ast.parse(source)
# Find the class node
class_node = [node for node in tree.body if isinstance(node, ast.ClassDef)][0]
# Should extract all relevant imports
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 5.91μs -> 7.53μs (21.5% slower)
def test_performance_with_many_unused_imports():
"""Test performance when module has many imports but class uses none."""
# Create source with 200 imports but class has no bases
num_imports = 200
imports = [f"import module{i}" for i in range(num_imports)]
source = "\n".join(imports) + "\n\nclass MyClass:\n pass"
tree = ast.parse(source)
class_node = tree.body[num_imports]
# Should quickly return empty string
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 50.7μs -> 1.44μs (3432% faster)
def test_mixed_import_styles_large_scale():
"""Test large module with mixed import styles (import, from-import, aliases)."""
# Create source with 100 mixed-style imports
imports = []
for i in range(100):
if i % 3 == 0:
imports.append(f"import module{i}")
elif i % 3 == 1:
imports.append(f"from package{i} import Class{i}")
else:
imports.append(f"from package{i} import Class{i} as C{i}")
# Use every 10th import as base class
bases = []
expected_imports = []
for i in range(0, 100, 10):
if i % 3 == 0:
bases.append(f"module{i}")
expected_imports.append(f"import module{i}")
elif i % 3 == 1:
bases.append(f"Class{i}")
expected_imports.append(f"from package{i} import Class{i}")
else:
bases.append(f"C{i}")
expected_imports.append(f"from package{i} import Class{i} as C{i}")
source = "\n".join(imports) + f"\n\nclass MyClass({', '.join(bases)}):\n pass"
tree = ast.parse(source)
class_node = tree.body[100]
# Should extract all 10 used imports
codeflash_output = _extract_imports_for_class(tree, class_node, source)
result = codeflash_output # 32.0μs -> 36.6μs (12.7% slower)
result_lines = result.split("\n")
for expected in expected_imports:
pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
⏪ Click to see Replay Tests
| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
|---|---|---|---|
| benchmarks/codeflash_replay_tests_fuodcj9h/test_tests_benchmarks_test_benchmark_code_extract_code_context__replay_test_0.py::test_codeflash_context_code_context_extractor__extract_imports_for_class_test_benchmark_extract | 786μs | 332μs | 136%✅ |
To test or edit this optimization locally, run: git merge codeflash/optimize-pr1014-2026-01-07T21.53.54
Click to see suggested changes
-    # Find imports that provide these names
-    import_lines: list[str] = []
-    source_lines = module_source.split("\n")
-    for node in module_tree.body:
-        if isinstance(node, ast.Import):
-            for alias in node.names:
-                name = alias.asname if alias.asname else alias.name.split(".")[0]
-                if name in needed_names:
-                    import_lines.append(source_lines[node.lineno - 1])
-                    break
-        elif isinstance(node, ast.ImportFrom):
-            for alias in node.names:
-                name = alias.asname if alias.asname else alias.name
-                if name in needed_names:
-                    import_lines.append(source_lines[node.lineno - 1])
-                    break
+    if not needed_names:
+        return ""
+    # Find imports that provide these names
+    import_line_numbers: list[int] = []
+    remaining_names = needed_names.copy()
+    for node in module_tree.body:
+        if not remaining_names:
+            break
+        if isinstance(node, ast.Import):
+            for alias in node.names:
+                name = alias.asname if alias.asname else alias.name.split(".")[0]
+                if name in remaining_names:
+                    import_line_numbers.append(node.lineno - 1)
+                    remaining_names.discard(name)
+                    break
+        elif isinstance(node, ast.ImportFrom):
+            for alias in node.names:
+                name = alias.asname if alias.asname else alias.name
+                if name in remaining_names:
+                    import_line_numbers.append(node.lineno - 1)
+                    remaining_names.discard(name)
+                    break
+    if not import_line_numbers:
+        return ""
+    source_lines = module_source.split("\n")
+    import_lines: list[str] = [source_lines[line_num] for line_num in import_line_numbers]
⚡️ Codeflash found optimizations for this PR📄 1,809% (18.09x) speedup for
Summary
Element(text="...") when Element is abstract)
Problem
When generating tests for functions like will_fit in unstructured, the LLM only saw:
Without the actual Element class definition, the LLM incorrectly guessed:
Element(text="...") fails)
Element is abstract and can't be instantiated directly
Solution
Now the testgen context includes the actual class definitions:
Changes
- Add get_imported_class_definitions() function to code_context_extractor.py
- Integrate into get_code_optimization_context() to include extracted classes
Test plan
uv run codeflash --file unstructured/chunking/base.py --function will_fit