Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions code_review_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,12 @@ def upsert_edge(self, edge: EdgeInfo) -> int:
now = time.time()
extra = json.dumps(edge.extra) if edge.extra else "{}"

# Check for existing edge
# Check for existing edge (include line so multiple call sites are preserved)
existing = self._conn.execute(
"""SELECT id FROM edges
WHERE kind=? AND source_qualified=? AND target_qualified=? AND file_path=?""",
(edge.kind, edge.source, edge.target, edge.file_path),
WHERE kind=? AND source_qualified=? AND target_qualified=?
AND file_path=? AND line=?""",
(edge.kind, edge.source, edge.target, edge.file_path, edge.line),
).fetchone()

if existing:
Expand Down
203 changes: 201 additions & 2 deletions code_review_graph/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ class CodeParser:

def __init__(self) -> None:
self._parsers: dict[str, object] = {}
self._module_file_cache: dict[str, Optional[str]] = {}

def _get_parser(self, language: str): # type: ignore[arg-type]
if language not in self._parsers:
Expand Down Expand Up @@ -242,9 +243,15 @@ def parse_bytes(self, path: Path, source: bytes) -> tuple[list[NodeInfo], list[E
language=language,
))

# Pre-scan for import mappings and defined names
import_map, defined_names = self._collect_file_scope(
tree.root_node, language, source,
)

# Walk the tree
self._extract_from_tree(
tree.root_node, source, language, file_path_str, nodes, edges
tree.root_node, source, language, file_path_str, nodes, edges,
import_map=import_map, defined_names=defined_names,
)

return nodes, edges
Expand All @@ -259,6 +266,8 @@ def _extract_from_tree(
edges: list[EdgeInfo],
enclosing_class: Optional[str] = None,
enclosing_func: Optional[str] = None,
import_map: Optional[dict[str, str]] = None,
defined_names: Optional[set[str]] = None,
) -> None:
"""Recursively walk the AST and extract nodes/edges."""
class_types = set(_CLASS_TYPES.get(language, []))
Expand Down Expand Up @@ -308,6 +317,7 @@ def _extract_from_tree(
self._extract_from_tree(
child, source, language, file_path, nodes, edges,
enclosing_class=name, enclosing_func=None,
import_map=import_map, defined_names=defined_names,
)
continue

Expand Down Expand Up @@ -353,6 +363,7 @@ def _extract_from_tree(
self._extract_from_tree(
child, source, language, file_path, nodes, edges,
enclosing_class=enclosing_class, enclosing_func=name,
import_map=import_map, defined_names=defined_names,
)
continue

Expand All @@ -374,10 +385,14 @@ def _extract_from_tree(
call_name = self._get_call_name(child, language, source)
if call_name and enclosing_func:
caller = self._qualify(enclosing_func, file_path, enclosing_class)
target = self._resolve_call_target(
call_name, file_path, language,
import_map or {}, defined_names or set(),
)
edges.append(EdgeInfo(
kind="CALLS",
source=caller,
target=call_name,
target=target,
file_path=file_path,
line=child.start_point[0] + 1,
))
Expand All @@ -386,7 +401,191 @@ def _extract_from_tree(
self._extract_from_tree(
child, source, language, file_path, nodes, edges,
enclosing_class=enclosing_class, enclosing_func=enclosing_func,
import_map=import_map, defined_names=defined_names,
)

def _collect_file_scope(
    self, root, language: str, source: bytes,
) -> tuple[dict[str, str], set[str]]:
    """Pre-scan the top level of the AST for imports and defined names.

    Returns:
        A pair ``(import_map, defined_names)``: ``import_map`` maps each
        locally imported name to its source module/path string, and
        ``defined_names`` holds the function/class names defined directly
        at file scope.
    """
    imports: dict[str, str] = {}
    names: set[str] = set()

    classes = set(_CLASS_TYPES.get(language, []))
    funcs = set(_FUNCTION_TYPES.get(language, []))
    import_kinds = set(_IMPORT_TYPES.get(language, []))
    # Wrapper node types that hold a decorated class/function definition.
    wrappers = {"decorated_definition", "decorator"}

    for top in root.children:
        kind = top.type

        # Look through a decorator wrapper to reach the wrapped definition.
        inner = top
        if kind in wrappers:
            wrapped = next(
                (c for c in top.children
                 if c.type in funcs or c.type in classes),
                None,
            )
            if wrapped is not None:
                inner = wrapped

        # Record file-scope function/class names.
        if inner.type in funcs or inner.type in classes:
            role = "class" if inner.type in classes else "function"
            ident = self._get_name(inner, language, role)
            if ident:
                names.add(ident)

        # Record import mappings: local name -> module path.
        if kind in import_kinds:
            self._collect_import_names(top, language, source, imports)

    return imports, names

def _collect_import_names(
    self, node, language: str, source: bytes, import_map: dict[str, str],
) -> None:
    """Extract imported names and their source modules into import_map.

    Mutates *import_map* in place, mapping each locally-bound name to the
    module string it was imported from.  Only Python ``from ... import``
    statements and JS/TS ``import`` clauses are understood; anything else
    is silently ignored.
    """
    if language == "python":
        if node.type == "import_from_statement":
            # from X.Y import A, B → {A: X.Y, B: X.Y}
            module = None
            seen_import_keyword = False
            for child in node.children:
                if child.type == "dotted_name" and not seen_import_keyword:
                    # The dotted_name before the `import` keyword is the module.
                    module = child.text.decode("utf-8", errors="replace")
                elif child.type == "import":
                    seen_import_keyword = True
                elif seen_import_keyword and module:
                    # NOTE(review): relative imports (`from . import x`) have a
                    # relative_import node, so `module` stays None and the names
                    # are skipped — confirm that is intended.
                    if child.type in ("identifier", "dotted_name"):
                        name = child.text.decode("utf-8", errors="replace")
                        import_map[name] = module
                    elif child.type == "aliased_import":
                        # from X import A as B → {B: X}
                        names = [
                            sub.text.decode("utf-8", errors="replace")
                            for sub in child.children
                            if sub.type in ("identifier", "dotted_name")
                        ]
                        # Last name is the alias (local name)
                        if names:
                            import_map[names[-1]] = module

    elif language in ("javascript", "typescript", "tsx"):
        # import { A, B } from './path' → {A: ./path, B: ./path}
        module = None
        for child in node.children:
            if child.type == "string":
                # Strip the surrounding quote characters from the source string.
                module = child.text.decode("utf-8", errors="replace").strip("'\"")
        if module:
            for child in node.children:
                if child.type == "import_clause":
                    self._collect_js_import_names(child, module, import_map)

def _collect_js_import_names(
    self, clause_node, module: str, import_map: dict[str, str],
) -> None:
    """Record JS/TS default and named imports from an ``import_clause``.

    Maps each locally-bound name to *module* in ``import_map``.
    """
    for part in clause_node.children:
        kind = part.type
        if kind == "identifier":
            # `import Foo from '...'` — a bare identifier is the default import.
            local = part.text.decode("utf-8", errors="replace")
            import_map[local] = module
        elif kind == "named_imports":
            for spec in part.children:
                if spec.type != "import_specifier":
                    continue
                # A specifier is either `name` or `name as alias`.
                idents = [
                    s.text.decode("utf-8", errors="replace")
                    for s in spec.children
                    if s.type in ("identifier", "property_identifier")
                ]
                # The final identifier is always the local binding.
                if idents:
                    import_map[idents[-1]] = module

def _resolve_module_to_file(
    self, module: str, file_path: str, language: str,
) -> Optional[str]:
    """Resolve a module/import path to an absolute file path.

    Results are memoized in ``self._module_file_cache``.  The cache key
    includes the importing file's directory because resolution is relative
    to the caller (JS relative imports, the Python walk-up search), so the
    same module string can legitimately resolve to different files when
    imported from different directories.

    Args:
        module: Import string (``pkg.mod`` or ``./path``).
        file_path: Path of the importing file.
        language: Source language of the importing file.

    Returns:
        The resolved absolute path, or ``None`` when no file matches.
    """
    # BUG FIX: the old key was f"{language}:{module}", which let a result
    # computed for one caller directory be wrongly reused for another.
    cache_key = f"{language}:{Path(file_path).parent}:{module}"
    if cache_key in self._module_file_cache:
        return self._module_file_cache[cache_key]

    resolved = self._do_resolve_module(module, file_path, language)
    # Negative results (None) are cached too, to skip repeated fs probing.
    self._module_file_cache[cache_key] = resolved
    return resolved

def _do_resolve_module(
    self, module: str, file_path: str, language: str,
) -> Optional[str]:
    """Language-aware module-to-file resolution.

    Args:
        module: Import string (``pkg.mod`` for Python, ``./path`` for JS/TS).
        file_path: Path of the importing file; resolution starts at its parent.
        language: Source language of the importing file.

    Returns:
        Absolute path of the resolved file, or ``None`` if not found.
    """
    caller_dir = Path(file_path).parent

    if language == "python":
        rel_path = module.replace(".", "/")
        candidates = [rel_path + ".py", rel_path + "/__init__.py"]
        # Walk up from the caller's directory toward the filesystem root,
        # trying both module.py and package/__init__.py at each level.
        current = caller_dir
        while True:
            for candidate in candidates:
                target = current / candidate
                if target.is_file():
                    return str(target.resolve())
            if current == current.parent:  # reached the root
                break
            current = current.parent

    elif language in ("javascript", "typescript", "tsx"):
        # Only relative imports are resolved; bare package names are left alone.
        if module.startswith("."):
            base = caller_dir / module
            extensions = [".ts", ".tsx", ".js", ".jsx"]
            # Try the exact path first (it may already carry an extension).
            if base.is_file():
                return str(base.resolve())
            # BUG FIX: ``base.with_suffix(ext)`` *replaces* a dotted final
            # segment ("./config.schema" became "config.ts"), so dotted module
            # names never resolved.  Append the extension to the name instead.
            for ext in extensions:
                target = base.parent / (base.name + ext)
                if target.is_file():
                    return str(target.resolve())
            # Fall back to a directory index file.
            if base.is_dir():
                for ext in extensions:
                    target = base / f"index{ext}"
                    if target.is_file():
                        return str(target.resolve())

    return None

def _resolve_call_target(
    self,
    call_name: str,
    file_path: str,
    language: str,
    import_map: dict[str, str],
    defined_names: set[str],
) -> str:
    """Map a bare call name to a qualified target when possible.

    Resolution order: a name defined in this file wins; otherwise an
    imported name is resolved to its defining file.  If neither applies,
    the bare name is returned unchanged as a fallback.
    """
    # Same-file definition: qualify against the current file.
    if call_name in defined_names:
        return self._qualify(call_name, file_path, None)

    # Imported name: try to locate the defining file on disk.
    module = import_map.get(call_name)
    if module is not None:
        source_file = self._resolve_module_to_file(module, file_path, language)
        if source_file:
            return self._qualify(call_name, source_file, None)

    # Unresolvable (method call, builtin, unknown) — keep the bare name.
    return call_name

def _qualify(self, name: str, file_path: str, enclosing_class: Optional[str]) -> str:
"""Create a qualified name: file_path::ClassName.name or file_path::name."""
Expand Down
8 changes: 8 additions & 0 deletions tests/fixtures/caller_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Fixture that imports and calls functions from sample_python."""

from sample_python import create_auth_service


def setup_and_run():
service = create_auth_service()
return service
13 changes: 13 additions & 0 deletions tests/fixtures/multi_call_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Fixture with multiple calls to the same function from one caller."""


async def _internal_request(url: str, data: bytes) -> dict:
return {"url": url}


async def process_document(content: bytes) -> str:
"""Calls _internal_request twice on different lines."""
first = await _internal_request("http://localhost/fast", content)
text = first.get("body", "")
second = await _internal_request("http://localhost/slow", content)
return text or second.get("body", "")
12 changes: 12 additions & 0 deletions tests/fixtures/sample_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,15 @@ def process_request(service: AuthService, token: str) -> dict:
if service.authenticate(token):
return {"status": "ok"}
return {"status": "denied"}


def _log_action(func):
    """Simple decorator that transparently forwards to *func*."""
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


@_log_action
def guarded_process(service: AuthService, token: str) -> dict:
    # Decorated definition: exercises the parser's decorated_definition
    # unwrapping when it collects file-scope names.
    return process_request(service, token)
19 changes: 19 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,25 @@ def test_impact_radius(self):
impacted_qns = {n.qualified_name for n in result["impacted_nodes"]}
assert "/b.py::func_b" in impacted_qns or "/b.py" in impacted_qns

def test_upsert_edge_preserves_multiple_call_sites(self):
    """Two CALLS edges differing only by line must both survive upsert."""
    caller = "/test/file.py::caller"
    for line_no in (10, 20):
        self.store.upsert_edge(EdgeInfo(
            kind="CALLS",
            source=caller,
            target="/test/file.py::helper",
            file_path="/test/file.py",
            line=line_no,
        ))
    self.store.commit()

    stored = self.store.get_edges_by_source(caller)
    assert len(stored) == 2
    assert {e.line for e in stored} == {10, 20}

def test_metadata(self):
self.store.set_metadata("test_key", "test_value")
assert self.store.get_metadata("test_key") == "test_value"
Expand Down
52 changes: 52 additions & 0 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,58 @@ def test_parse_test_file(self):
assert "test_authenticate_valid" in test_names
assert "test_process_request_ok" in test_names

def test_calls_edge_same_file_resolution(self):
    """Call targets defined in the same file should be qualified."""
    _, edges = self.parser.parse_file(FIXTURES / "sample_python.py")
    expected = f"{FIXTURES / 'sample_python.py'}::AuthService"
    # create_auth_service() instantiates AuthService, defined in this file.
    matches = [
        e for e in edges if e.kind == "CALLS" and e.target == expected
    ]
    assert len(matches) >= 1

def test_calls_edge_cross_file_resolution(self):
    """Imported call targets should resolve to the defining file."""
    _, edges = self.parser.parse_file(FIXTURES / "caller_example.py")
    wanted = f"{(FIXTURES / 'sample_python.py').resolve()}::create_auth_service"
    # setup_and_run() calls create_auth_service(), imported from sample_python.
    hits = [e for e in edges if e.kind == "CALLS" and e.target == wanted]
    assert len(hits) == 1

def test_unresolved_calls_stay_bare(self):
    """Method calls and unknown calls should remain as bare names."""
    _, edges = self.parser.parse_file(FIXTURES / "sample_python.py")
    # self._validate_token() is a method call — the target file can't be known.
    bare = [
        e for e in edges
        if e.kind == "CALLS" and e.target == "_validate_token"
    ]
    assert len(bare) >= 1

def test_calls_edge_decorated_function_resolution(self):
    """Decorated functions should still resolve as same-file call targets."""
    _, edges = self.parser.parse_file(FIXTURES / "sample_python.py")
    wanted = f"{FIXTURES / 'sample_python.py'}::process_request"
    # guarded_process is wrapped in a decorated_definition node, but its
    # call to process_request must still be qualified to this file.
    hits = [
        e for e in edges
        if e.kind == "CALLS"
        and e.target == wanted
        and "guarded_process" in e.source
    ]
    assert len(hits) == 1

def test_multiple_calls_to_same_function(self):
    """Each call site on its own line should produce a separate edge."""
    _, edges = self.parser.parse_file(FIXTURES / "multi_call_example.py")
    call_sites = [
        e for e in edges
        if e.kind == "CALLS" and "_internal_request" in e.target
    ]
    assert len(call_sites) == 2
    # The two edges must carry distinct line numbers.
    assert len({e.line for e in call_sites}) == 2

def test_parse_nonexistent_file(self):
nodes, edges = self.parser.parse_file(Path("/nonexistent/file.py"))
assert nodes == []
Expand Down