Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions code_review_graph/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,8 @@ def get_provider(
return None

# Default: local
if not _check_available():
return None
try:
return LocalEmbeddingProvider(model_name=model)
except ImportError:
Expand Down
26 changes: 22 additions & 4 deletions tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,17 +173,34 @@ class TestGetProviderModel:
"""Tests for model parameter in get_provider()."""

@patch("code_review_graph.embeddings.LocalEmbeddingProvider")
def test_local_passes_model(self, mock_cls):
@patch("code_review_graph.embeddings._check_available", return_value=True)
def test_local_passes_model(self, _mock_available, mock_cls):
mock_cls.return_value = MagicMock()
get_provider(provider=None, model="custom/model")
mock_cls.assert_called_once_with(model_name="custom/model")

@patch("code_review_graph.embeddings.LocalEmbeddingProvider")
def test_local_default_passes_none(self, mock_cls):
@patch("code_review_graph.embeddings._check_available", return_value=True)
def test_local_default_passes_none(self, _mock_available, mock_cls):
mock_cls.return_value = MagicMock()
get_provider(provider=None, model=None)
mock_cls.assert_called_once_with(model_name=None)

@patch("code_review_graph.embeddings._check_available", return_value=False)
def test_local_unavailable_returns_none(self, _mock_available):
assert get_provider("local") is None

@patch("code_review_graph.embeddings._check_available", return_value=False)
def test_embedding_store_unavailable_without_local_dependency(
self, _mock_available, tmp_path,
):
db = tmp_path / "embeddings.db"
store = EmbeddingStore(db, provider="local")
try:
assert store.available is False
finally:
store.close()


class TestCloudProviderWarning:
"""Tests for the stderr warning before cloud provider use (#174)."""
Expand Down Expand Up @@ -237,8 +254,9 @@ def test_local_provider_never_warns(self, capsys):
with patch(
"code_review_graph.embeddings.LocalEmbeddingProvider",
) as mock_cls:
mock_cls.return_value = MagicMock()
get_provider(provider=None)
with patch("code_review_graph.embeddings._check_available", return_value=True):
mock_cls.return_value = MagicMock()
get_provider(provider=None)
captured = capsys.readouterr()
assert "cloud" not in captured.err.lower()

Expand Down