Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 29 additions & 14 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import ast
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -679,26 +680,40 @@ def detect_frameworks_from_code(code: str) -> dict[str, str]:
except SyntaxError:
return frameworks

for node in ast.walk(tree):
# Faster explicit traversal - focus on import-related statements only.
# Use deque for BFS-like traversal to match ast.walk() order
queue = deque([tree])
framework_keys = ("torch", "tensorflow", "jax")

while queue:
node = queue.popleft()
# Only process Import and ImportFrom nodes for efficiency
if isinstance(node, ast.Import):
for alias in node.names:
module_name = alias.name.split(".")[0]
if module_name == "torch":
# Use asname if available, otherwise use the module name
frameworks["torch"] = alias.asname if alias.asname else module_name
elif module_name == "tensorflow":
frameworks["tensorflow"] = alias.asname if alias.asname else module_name
elif module_name == "jax":
frameworks["jax"] = alias.asname if alias.asname else module_name
if module_name in framework_keys:
# Always update to match original behavior (keeps last occurrence)
frameworks[module_name] = alias.asname if alias.asname else module_name
elif isinstance(node, ast.ImportFrom): # noqa: SIM102
if node.module:
module_name = node.module.split(".")[0]
if module_name == "torch" and "torch" not in frameworks:
frameworks["torch"] = module_name
elif module_name == "tensorflow" and "tensorflow" not in frameworks:
frameworks["tensorflow"] = module_name
elif module_name == "jax" and "jax" not in frameworks:
frameworks["jax"] = module_name
if module_name in framework_keys and module_name not in frameworks:
frameworks[module_name] = module_name
# Only descend into bodies for Module, FunctionDef, ClassDef, etc.
# Skip subtrees that can't contain Import or ImportFrom nodes for speed.
child_nodes = []
if hasattr(node, "body"):
child_nodes.extend(node.body)
if hasattr(node, "orelse"):
child_nodes.extend(node.orelse)
if hasattr(node, "finalbody"):
child_nodes.extend(node.finalbody)
# Handle if/while/for except handlers
if hasattr(node, "handlers"):
for handler in node.handlers:
if hasattr(handler, "body"):
child_nodes.extend(handler.body)
queue.extend(child_nodes)

return frameworks

Expand Down
Loading