3 changes: 2 additions & 1 deletion pyproject.toml
@@ -43,7 +43,7 @@ dependencies = [
"hatch-vcs>=0.4.0",
"hatchling>=1.25.0",
"pyinstrument>=5.0.0",
"pip>=24.3.1", # This is needed for some NPM/YARN/PNPM post-install scripts to work!
"pip>=24.3.1", # This is needed for some NPM/YARN/PNPM post-install scripts to work!
"rich-click>=1.8.5",
"python-dotenv>=1.0.1",
"giturlparse",
@@ -70,6 +70,7 @@ dependencies = [
"datasets",
"colorlog>=6.9.0",
"codegen-sdk-pink>=0.1.0",
"scikit-learn>=1.7.2",
]

# renovate: datasource=python-version depName=python
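The new index pulls exactly one API from this dependency, TfidfVectorizer. A quick sanity check that an environment satisfies the new pin (illustrative, not part of the PR; assumes the packaging distribution is available, as it is in most environments):

    from importlib.metadata import version

    from packaging.version import Version
    from sklearn.feature_extraction.text import TfidfVectorizer  # the only scikit-learn API the new index uses

    # pyproject.toml now requires scikit-learn>=1.7.2
    assert Version(version("scikit-learn")) >= Version("1.7.2")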
126 changes: 126 additions & 0 deletions src/graph_sitter/extensions/index/scikit_index.py
@@ -0,0 +1,126 @@
"""File-level semantic code search index using scikit-learn."""

import pickle
from pathlib import Path
from typing import Any, override

from sklearn.feature_extraction.text import TfidfVectorizer

from graph_sitter.core.codebase import Codebase
from graph_sitter.extensions.index.code_index import CodeIndex


class ScikitCodeIndex(CodeIndex):
"""Local code index using TF-IDF vectorization for semantic search.

    This CodeIndex implementation builds a local vector index with scikit-learn, so it does not require OpenAI API access.
"""

def __init__(self, codebase: Codebase, vectorizer: TfidfVectorizer | None = None) -> None:
super().__init__(codebase)
        self.vectorizer: TfidfVectorizer = vectorizer or TfidfVectorizer(stop_words="english", max_features=5000, ngram_range=(1, 2))
self._fitted: bool = False

@property
@override
def save_file_name(self) -> str:
return "local_index_{commit}.pkl"

@override
def _get_embeddings(self, items: list[Any]) -> list[list[float]]:
"""Get TF-IDF embeddings for content."""
if not self._fitted:
all_items = [content for _, content in self._get_items_to_index()]
if all_items:
_ = self.vectorizer.fit(all_items)
self._fitted = True

if not items:
return []

# Extract content strings from items if they are tuples
content_items = []
for item in items:
if isinstance(item, tuple) and len(item) >= 2:
content_items.append(item[1]) # Get content from tuple
elif isinstance(item, str):
content_items.append(item)
else:
content_items.append(str(item))

vectors = self.vectorizer.transform(content_items)
return vectors.toarray().tolist() # pyright: ignore [reportAttributeAccessIssue]

@override
def _get_items_to_index(self) -> list[tuple[Any, str]]:
"""Get all files and their content."""
items = []
for file in self.codebase.files():
try:
content = file.content
if content.strip(): # Only index non-empty files
items.append((file, content))
            # pylint: disable-next=broad-exception-caught
            except Exception:
                continue  # Can't do much here; skip files that can't be read
return items

@override
def _get_changed_items(self) -> set[Any]:
"""Get files that have changed since last commit."""
if not self.commit_hash:
return set()

changed = set()
try:
current_commit = self._get_current_commit()
if current_commit != self.commit_hash:
# For simplicity, consider all files as potentially changed
changed = set(self.codebase.files())
        # pylint: disable-next=broad-exception-caught
        except Exception:
            pass  # Can't do much here; fall back to reporting no changes

return changed

@override
def _save_index(self, path: Path) -> None:
"""Save index data to disk."""
data = {
"E": self.E,
"items": self.items,
"commit_hash": self.commit_hash,
"vectorizer": self.vectorizer,
"fitted": self._fitted,
}
with open(path, "wb") as f:
pickle.dump(data, f)

@override
def _load_index(self, path: Path) -> None:
"""Load index data from disk."""
with open(path, "rb") as f:
data = pickle.load(f)

self.E = data["E"]
self.items = data["items"]
self.commit_hash = data["commit_hash"]
self.vectorizer = data["vectorizer"]
self._fitted = data["fitted"]

@override
def similarity_search(self, query: str, k: int = 5) -> list[tuple[Any, float]]:
"""Find the k most similar files to a query."""
raw_results = self._similarity_search_raw(query, k)

        # Build the lookup table once instead of re-scanning every file per result
        files_by_name = {str(file): file for file in self.codebase.files()}
        return [(files_by_name[item_str], score) for item_str, score in raw_results if item_str in files_by_name]
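A minimal usage sketch for the new class. create(), save(), and similarity_search() are assumed to come from the CodeIndex base class (the tests below exercise exactly these entry points); the Codebase construction is illustrative:

    from graph_sitter.core.codebase import Codebase
    from graph_sitter.extensions.index.scikit_index import ScikitCodeIndex

    codebase = Codebase("path/to/repo")  # hypothetical local checkout
    index = ScikitCodeIndex(codebase)
    index.create()  # fits the TF-IDF vectorizer and embeds every non-empty file
    index.save()    # persists local_index_<commit>.pkl

    # Returns up to k (file, similarity score) pairs
    for file, score in index.similarity_search("parse git urls", k=3):
        print(f"{file.filepath}: {score:.3f}")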
137 changes: 137 additions & 0 deletions tests/integration/test_scikit_index.py
@@ -0,0 +1,137 @@
from pathlib import Path

import numpy as np
import pytest

from graph_sitter.codebase.factory.get_session import get_codebase_session
from graph_sitter.extensions.index.scikit_index import ScikitCodeIndex


def test_scikit_index_lifecycle(tmpdir) -> None:
# language=python
content1 = """
def hello():
print("Hello, world!")

def goodbye():
print("Goodbye, world!")
"""

# language=python
content2 = """
def greet(name: str):
print(f"Hi {name}!")
"""

with get_codebase_session(tmpdir=tmpdir, files={"greetings.py": content1, "hello.py": content2}) as codebase:
# Test construction and initial indexing
index = ScikitCodeIndex(codebase=codebase)
index.create()

# Verify initial state
assert index.E is not None
assert index.items is not None
assert len(index.items) == 2 # Both files should be indexed
assert index.commit_hash is not None

# Test similarity search
results = index.similarity_search("greeting someone", k=2)
assert len(results) == 2
# The greet function should be most relevant to greeting
assert any("hello.py" in file.filepath for file, _ in results)

# Test saving
save_dir = Path(tmpdir) / ".codegen"
index.save()
assert save_dir.exists()
        saved_files = list(save_dir.glob("local_index_*.pkl"))
assert len(saved_files) == 1

# Test loading
        new_index = ScikitCodeIndex(codebase)
new_index.load(saved_files[0])
assert np.array_equal(index.E, new_index.E)
assert np.array_equal(index.items, new_index.items)
assert index.commit_hash == new_index.commit_hash

# Test updating after file changes
# Add a new function to greetings.py
greetings_file = codebase.get_file("greetings.py")
new_content = greetings_file.content + "\n\ndef welcome():\n print('Welcome!')\n"
greetings_file.edit(new_content)

# Update the index
index.update()

# Verify the update
assert len(index.items) >= 2 # Should have at least the original files

# Search for the new content
results = index.similarity_search("welcome message", k=2)
assert len(results) == 2
# The updated greetings.py should be relevant now
assert any("greetings.py" in file.filepath for file, _ in results)


def test_scikit_index_empty_file(tmpdir) -> None:
    """Test that the index handles empty files gracefully."""
with get_codebase_session(tmpdir=tmpdir, files={"empty.py": ""}) as codebase:
        index = ScikitCodeIndex(codebase)
index.create()
assert len(index.items) == 0 # Empty file should be skipped


def test_scikit_index_large_file(tmpdir) -> None:
    """Test that the index handles large files (TF-IDF has no token limit, so no chunking is needed)."""
# Create a large file by repeating a simple function many times
large_content = "def f():\n print('test')\n\n" * 10000

with get_codebase_session(tmpdir=tmpdir, files={"large.py": large_content}) as codebase:
        index = ScikitCodeIndex(codebase)
index.create()

        # The index is file-level, so the large file is stored as a single item
        assert len([item for item in index.items if "large.py" in item]) == 1

# Test searching in large file
results = index.similarity_search("function that prints test", k=1)
assert len(results) == 1
assert "large.py" in results[0][0].filepath


def test_scikit_index_invalid_operations(tmpdir) -> None:
    """Test that the index properly handles invalid operations."""
with get_codebase_session(tmpdir=tmpdir, files={"test.py": "print('test')"}) as codebase:
        index = ScikitCodeIndex(codebase)

# Test searching before creating index
with pytest.raises(ValueError, match="No embeddings available"):
index.similarity_search("test")

# Test saving before creating index
with pytest.raises(ValueError, match="No embeddings to save"):
index.save()

# Test updating before creating index
with pytest.raises(ValueError, match="No index to update"):
index.update()

# Test loading from non-existent path
with pytest.raises(FileNotFoundError):
index.load("nonexistent.pkl")


def test_scikit_index_binary_files(tmpdir) -> None:
    """Test that the index properly handles binary files."""
# Create a binary file
binary_content = bytes([0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]) # PNG header
binary_path = Path(tmpdir) / "test.png"
binary_path.write_bytes(binary_content)

with get_codebase_session(tmpdir=tmpdir, files={"test.py": "print('test')", "test.png": binary_content}) as codebase:
        index = ScikitCodeIndex(codebase)
index.create()

# Should only index the Python file
assert len(index.items) == 1
assert all("test.py" in item for item in index.items)
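For reference, the retrieval these tests exercise can be reproduced with scikit-learn alone. The sketch below mirrors ScikitCodeIndex's default vectorizer configuration and assumes the base class ranks by cosine similarity over the stored embeddings, which is the standard pairing with TF-IDF (file names and query are made up):

    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity

    # Toy corpus standing in for file contents
    docs = {
        "greetings.py": "def hello():\n    print('Hello, world!')",
        "hello.py": "def greet(name):\n    print(f'Hi {name}!')",
    }

    # Same configuration as ScikitCodeIndex's default vectorizer
    vectorizer = TfidfVectorizer(stop_words="english", max_features=5000, ngram_range=(1, 2))
    doc_matrix = vectorizer.fit_transform(docs.values())

    # Embed the query with the fitted vocabulary, then rank files by cosine similarity
    query_vec = vectorizer.transform(["hello world"])
    scores = cosine_similarity(query_vec, doc_matrix)[0]

    for name, score in sorted(zip(docs, scores), key=lambda pair: -pair[1]):
        print(f"{name}: {score:.3f}")

Note that plain TF-IDF does no stemming, so "greet" and "greeting" do not match; queries must share literal tokens (or bigrams) with the indexed files.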