Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 56 additions & 18 deletions api/analyzers/source_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,36 +41,52 @@ def supported_types(self) -> list[str]:
"""
return list(analyzers.keys())

def create_entity_hierarchy(self, entity: Entity, file: File, analyzer: AbstractAnalyzer,
                            pending_entities: list, pending_rels: list):
    """
    Recursively collect child entities of *entity* into in-memory pending lists
    instead of writing to the graph immediately (batched in first_pass phase 2).

    Args:
        entity: parent entity whose AST subtree is walked.
        file: source file the entities belong to.
        analyzer: language-specific analyzer providing entity metadata.
        pending_entities: accumulator of
            (entity_obj, label, name, docstring, path, src_start, src_end, props)
            tuples; entity ids are assigned later by the batch insert.
        pending_rels: accumulator of (relation, src_entity, dest_entity) tuples.
    """
    types = analyzer.get_entity_types()
    stack = list(entity.node.children)
    while stack:
        node = stack.pop()
        if node.type in types:
            child = Entity(node)
            # Defer DB insertion: record everything needed to create the node later.
            pending_entities.append((
                child, analyzer.get_entity_label(node),
                analyzer.get_entity_name(node),
                analyzer.get_entity_docstring(node),
                str(file.path), node.start_point.row,
                node.end_point.row, {}
            ))
            # Symbols are only resolved for first-party files, not dependencies.
            if not analyzer.is_dependency(str(file.path)):
                analyzer.add_symbols(child)
            file.add_entity(child)
            entity.add_child(child)
            pending_rels.append(("DEFINES", entity, child))
            # Recurse into the matched entity; its own children are handled there.
            self.create_entity_hierarchy(child, file, analyzer,
                                         pending_entities, pending_rels)
        else:
            # Non-entity node: keep scanning its children for entities.
            stack.extend(node.children)

def create_hierarchy(self, file: File, analyzer: AbstractAnalyzer,
                     pending_entities: list, pending_rels: list):
    """
    Walk *file*'s AST from the root and collect its top-level entities into
    in-memory pending lists (batched DB insertion happens in first_pass phase 2).

    Args:
        file: parsed source file (file.tree is the AST).
        analyzer: language-specific analyzer providing entity metadata.
        pending_entities: accumulator of
            (entity_obj, label, name, docstring, path, src_start, src_end, props)
            tuples; entity ids are assigned later by the batch insert.
        pending_rels: accumulator of (relation, src, dest) tuples; here the
            source is the File object itself ("file DEFINES entity").
    """
    types = analyzer.get_entity_types()
    stack = [file.tree.root_node]
    while stack:
        node = stack.pop()
        if node.type in types:
            entity = Entity(node)
            # Defer DB insertion: record everything needed to create the node later.
            pending_entities.append((
                entity, analyzer.get_entity_label(node),
                analyzer.get_entity_name(node),
                analyzer.get_entity_docstring(node),
                str(file.path), node.start_point.row,
                node.end_point.row, {}
            ))
            # Symbols are only resolved for first-party files, not dependencies.
            if not analyzer.is_dependency(str(file.path)):
                analyzer.add_symbols(entity)
            file.add_entity(entity)
            pending_rels.append(("DEFINES", file, entity))
            # Nested entities are collected recursively.
            self.create_entity_hierarchy(entity, file, analyzer,
                                         pending_entities, pending_rels)
        else:
            # Non-entity node: keep scanning its children for entities.
            stack.extend(node.children)

Expand All @@ -87,6 +103,11 @@ def first_pass(self, path: Path, files: list[Path], ignore: list[str], graph: Gr
for ext in set([file.suffix for file in files if file.suffix in supoorted_types]):
analyzers[ext].add_dependencies(path, files)

# Phase 1: Parse files and build in-memory hierarchy
pending_files = []
pending_entities = []
pending_rels = []

files_len = len(files)
for i, file_path in enumerate(files):
# Skip unsupported files
Expand All @@ -95,7 +116,7 @@ def first_pass(self, path: Path, files: list[Path], ignore: list[str], graph: Gr
continue

# Skip ignored files
if any([i in str(file_path) for i in ignore]):
if any(ig in str(file_path) for ig in ignore):
logging.info(f"Skipping ignored file {file_path}")
continue

Expand All @@ -110,10 +131,17 @@ def first_pass(self, path: Path, files: list[Path], ignore: list[str], graph: Gr
# Create file entity
file = File(file_path, tree)
self.files[file_path] = file
pending_files.append(file)

# Walk through the AST and collect entities/relationships
self.create_hierarchy(file, analyzer, pending_entities, pending_rels)

# Walk thought the AST
graph.add_file(file)
self.create_hierarchy(file, analyzer, graph)
# Phase 2: Batch insert files, entities, and relationships
graph.add_files_batch(pending_files)
graph.add_entities_batch(pending_entities)
graph.connect_entities_batch([
(rel, src.id, dest.id, {}) for rel, src, dest in pending_rels
])

def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None:
"""
Expand Down Expand Up @@ -144,8 +172,11 @@ def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None:
else:
lsps[".cs"] = NullLanguageServer()
with lsps[".java"].start_server(), lsps[".py"].start_server(), lsps[".cs"].start_server():
pending_rels = []
files_len = len(self.files)
for i, file_path in enumerate(files):
if file_path not in self.files:
continue
file = self.files[file_path]
logging.info(f'Processing file ({i + 1}/{files_len}): {file_path}')
for _, entity in file.entities.items():
Expand All @@ -155,18 +186,25 @@ def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None:
if len(symbol.resolved_symbol) == 0:
continue
resolved_symbol = next(iter(symbol.resolved_symbol))
rel = None
props = {}
if key == "base_class":
graph.connect_entities("EXTENDS", entity.id, resolved_symbol.id)
rel = "EXTENDS"
elif key == "implement_interface":
graph.connect_entities("IMPLEMENTS", entity.id, resolved_symbol.id)
rel = "IMPLEMENTS"
elif key == "extend_interface":
graph.connect_entities("EXTENDS", entity.id, resolved_symbol.id)
rel = "EXTENDS"
elif key == "call":
graph.connect_entities("CALLS", entity.id, resolved_symbol.id, {"line": symbol.symbol.start_point.row, "text": symbol.symbol.text.decode("utf-8")})
rel = "CALLS"
props = {"line": symbol.symbol.start_point.row, "text": symbol.symbol.text.decode("utf-8")}
elif key == "return_type":
graph.connect_entities("RETURNS", entity.id, resolved_symbol.id)
rel = "RETURNS"
elif key == "parameters":
graph.connect_entities("PARAMETERS", entity.id, resolved_symbol.id)
rel = "PARAMETERS"
if rel:
pending_rels.append((rel, entity.id, resolved_symbol.id, props))

graph.connect_entities_batch(pending_rels)

def analyze_files(self, files: list[Path], path: Path, graph: Graph) -> None:
self.first_pass(path, files, [], graph)
Expand Down
114 changes: 114 additions & 0 deletions api/graph.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import os
import re
import time
from collections import defaultdict
from .entities import *
from typing import Optional
from falkordb import FalkorDB, Path, Node, QueryResult
from falkordb.asyncio import FalkorDB as AsyncFalkorDB

# Maximum items per UNWIND batch to avoid overwhelming FalkorDB/Redis
BATCH_SIZE = 500

# Regex to validate graph labels/relation types (alphanumeric + underscore only)
_VALID_LABEL_RE = re.compile(r'^[A-Za-z_][A-Za-z0-9_]*$')

# Configure the logger
import logging
logging.basicConfig(level=logging.DEBUG,
Expand Down Expand Up @@ -248,6 +256,9 @@ def add_entity(self, label: str, name: str, doc: str, path: str, src_start: int,
Args:
"""

if not _VALID_LABEL_RE.match(label):
raise ValueError(f"Invalid entity label: {label!r}")

q = f"""MERGE (c:{label}:Searchable {{name: $name, path: $path, src_start: $src_start,
src_end: $src_end}})
SET c.doc = $doc
Expand All @@ -267,6 +278,47 @@ def add_entity(self, label: str, name: str, doc: str, path: str, src_start: int,
node = res.result_set[0][0]
return node.id

def add_entities_batch(self, entities_data: list) -> None:
    """
    Batch add entity nodes to the graph database using UNWIND.
    Groups by label (labels cannot be query parameters in Cypher), then
    processes each group in chunks of BATCH_SIZE.

    Args:
        entities_data: list of tuples
            (entity_obj, label, name, doc, path, src_start, src_end, props).
            entity_obj.id will be set after insertion.

    Raises:
        ValueError: if a label fails the identifier allowlist check
            (defends against Cypher injection, since the label is
            interpolated into the query string).
    """

    if not entities_data:
        return

    # Group rows by label so each label gets one parameterized query.
    by_label = defaultdict(list)
    for item in entities_data:
        by_label[item[1]].append(item)

    for label, group in by_label.items():
        # Labels are string-interpolated below; validate strictly first.
        if not _VALID_LABEL_RE.match(label):
            raise ValueError(f"Invalid entity label: {label!r}")

        q = f"""UNWIND $entities AS e
                MERGE (c:{label}:Searchable {{name: e['name'], path: e['path'],
                                              src_start: e['src_start'],
                                              src_end: e['src_end']}})
                SET c.doc = e['doc']
                SET c += e['props']
                RETURN c"""

        for start in range(0, len(group), BATCH_SIZE):
            chunk = group[start:start + BATCH_SIZE]
            data = [{
                'name': item[2], 'doc': item[3], 'path': item[4],
                'src_start': item[5], 'src_end': item[6], 'props': item[7]
            } for item in chunk]

            res = self._query(q, {'entities': data})
            # UNWIND preserves input order, so result row j corresponds to
            # chunk[j]; propagate the created/merged node id back to the caller.
            for j, item in enumerate(chunk):
                item[0].id = res.result_set[j][0].id

def get_class_by_name(self, class_name: str) -> Optional[Node]:
q = "MATCH (c:Class) WHERE c.name = $name RETURN c LIMIT 1"
res = self._query(q, {'name': class_name}).result_set
Expand Down Expand Up @@ -406,6 +458,30 @@ def add_file(self, file: File) -> None:
node = res.result_set[0][0]
file.id = node.id

def add_files_batch(self, files: list[File]) -> None:
    """
    Batch add file nodes to the graph database using UNWIND.
    Processes in chunks of BATCH_SIZE to avoid oversized queries.

    Args:
        files: list of File objects. Each file.id will be set after insertion.
    """

    if not files:
        return

    q = """UNWIND $files AS fd
           MERGE (f:File:Searchable {path: fd['path'], name: fd['name'], ext: fd['ext']})
           RETURN f"""

    for offset in range(0, len(files), BATCH_SIZE):
        batch = files[offset:offset + BATCH_SIZE]
        payload = []
        for f in batch:
            payload.append({'path': str(f.path), 'name': f.path.name, 'ext': f.path.suffix})
        res = self._query(q, {'files': payload})
        # Result rows come back in UNWIND input order; write ids back in place.
        for file_obj, row in zip(batch, res.result_set):
            file_obj.id = row[0].id

def delete_files(self, files: list[Path]) -> tuple[str, dict, list[int]]:
"""
Deletes file(s) from the graph in addition to any other entity
Expand Down Expand Up @@ -485,6 +561,44 @@ def connect_entities(self, relation: str, src_id: int, dest_id: int, properties:
params = {'src_id': src_id, 'dest_id': dest_id, "properties": properties}
self._query(q, params)

def connect_entities_batch(self, relationships: list[tuple[str, int, int, dict]]) -> None:
    """
    Batch create relationships between entities using UNWIND.
    Groups by relation type (relation types cannot be query parameters in
    Cypher), then processes each group in chunks of BATCH_SIZE.

    Rows whose source or destination id is None are skipped with a warning
    rather than failing the whole batch.

    Args:
        relationships: list of (relation, src_id, dest_id, properties).

    Raises:
        ValueError: if a relation type fails the identifier allowlist check
            (defends against Cypher injection, since the relation type is
            interpolated into the query string).
    """

    if not relationships:
        return

    by_relation = defaultdict(list)
    for rel in relationships:
        if rel[1] is None or rel[2] is None:
            logging.warning(f"Skipping relationship {rel[0]} with None ID: src={rel[1]}, dest={rel[2]}")
            continue
        by_relation[rel[0]].append(rel)

    for relation, group in by_relation.items():
        # Relation types are string-interpolated below; validate strictly first.
        if not _VALID_LABEL_RE.match(relation):
            raise ValueError(f"Invalid relation type: {relation!r}")

        # NOTE(review): MERGE collapses parallel edges between the same node
        # pair — e.g. multiple CALLS from one function to another end up as a
        # single edge whose properties are overwritten by the last row. This
        # appears to match the single-edge behavior of connect_entities;
        # confirm, or switch to CREATE / a property-keyed MERGE if multiple
        # edges per pair must be preserved.
        q = f"""UNWIND $rels AS r
                MATCH (src)
                WHERE ID(src) = r['src_id']
                MATCH (dest)
                WHERE ID(dest) = r['dest_id']
                MERGE (src)-[e:{relation}]->(dest)
                SET e += r['properties']
                RETURN e"""

        for start in range(0, len(group), BATCH_SIZE):
            chunk = group[start:start + BATCH_SIZE]
            data = [{'src_id': r[1], 'dest_id': r[2], 'properties': r[3]}
                    for r in chunk]
            self._query(q, {'rels': data})

def function_calls_function(self, caller_id: int, callee_id: int, pos: int) -> None:
"""
Establish a 'CALLS' relationship between two function nodes.
Expand Down
86 changes: 86 additions & 0 deletions tests/test_graph_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,91 @@ def test_function_calls_function(self):
res = self.g.query(query, params).result_set
self.assertTrue(res[0][0])

def test_add_files_batch(self):
    """Batch-inserted files get ids assigned and are retrievable by path/name/ext."""
    files = [File(Path(f'/batch/file{i}.py'), None) for i in range(5)]
    self.graph.add_files_batch(files)

    for i, f in enumerate(files):
        # add_files_batch must write the DB node id back onto each File object.
        self.assertIsNotNone(f.id)
        result = self.graph.get_file(f'/batch/file{i}.py', f'file{i}.py', '.py')
        self.assertIsNotNone(result)
        self.assertEqual(result.properties['name'], f'file{i}.py')

def test_add_files_batch_empty(self):
    """An empty batch is a no-op and must not raise."""
    self.graph.add_files_batch([])

def test_add_entities_batch(self):
    """Batch entity insertion assigns an id to each entity object in the input tuples."""
    from unittest.mock import MagicMock

    entities_data = []
    for i in range(3):
        # Mock stands in for the Entity object; only its .id attribute is used here.
        mock_entity = MagicMock()
        mock_entity.id = None
        entities_data.append((
            mock_entity, 'Function', f'func_{i}', f'doc {i}',
            '/batch/path', i * 10, i * 10 + 5, {}
        ))

    self.graph.add_entities_batch(entities_data)

    for item in entities_data:
        # item[0] is the entity object; its id must be populated after insertion.
        self.assertIsNotNone(item[0].id)

def test_connect_entities_batch(self):
    """Mixed-type relationship batch creates all edges with their properties."""
    file = File(Path('/batch/connect_test.py'), None)
    self.graph.add_file(file)

    func_a_id = self.graph.add_entity(
        'Function', 'batch_a', '', '/batch/connect_test.py', 1, 5, {}
    )
    func_b_id = self.graph.add_entity(
        'Function', 'batch_b', '', '/batch/connect_test.py', 6, 10, {}
    )
    func_c_id = self.graph.add_entity(
        'Function', 'batch_c', '', '/batch/connect_test.py', 11, 15, {}
    )

    # One batch containing two relation types exercises the group-by-relation path.
    self.graph.connect_entities_batch([
        ("DEFINES", file.id, func_a_id, {}),
        ("DEFINES", file.id, func_b_id, {}),
        ("DEFINES", file.id, func_c_id, {}),
        ("CALLS", func_a_id, func_b_id, {"line": 3, "text": "batch_b()"}),
    ])

    # Verify DEFINES relationships
    q = """MATCH (f:File)-[:DEFINES]->(fn:Function)
           WHERE ID(f) = $file_id
           RETURN count(fn)"""
    res = self.g.query(q, {'file_id': file.id}).result_set
    self.assertEqual(res[0][0], 3)

    # Verify CALLS relationship with properties
    q = """MATCH (a:Function)-[c:CALLS]->(b:Function)
           WHERE ID(a) = $a_id AND ID(b) = $b_id
           RETURN c.line, c.text"""
    res = self.g.query(q, {'a_id': func_a_id, 'b_id': func_b_id}).result_set
    self.assertEqual(res[0][0], 3)
    self.assertEqual(res[0][1], "batch_b()")

def test_connect_entities_batch_empty(self):
    """An empty relationship batch is a no-op and must not raise."""
    self.graph.connect_entities_batch([])

def test_batch_chunking(self):
    """Verify batches are correctly chunked when exceeding BATCH_SIZE."""
    import api.graph as graph_module
    # Shrink BATCH_SIZE so 7 files span three chunks; restore it afterwards.
    original = graph_module.BATCH_SIZE
    try:
        graph_module.BATCH_SIZE = 3
        files = [File(Path(f'/chunked/f{i}.py'), None) for i in range(7)]
        self.graph.add_files_batch(files)
        for f in files:
            self.assertIsNotNone(f.id)
        # Verify all 7 files are actually in the DB
        for i in range(7):
            result = self.graph.get_file(f'/chunked/f{i}.py', f'f{i}.py', '.py')
            self.assertIsNotNone(result)
    finally:
        graph_module.BATCH_SIZE = original

if __name__ == '__main__':
unittest.main()
Loading