Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit ab15716

Browse files
committed
Add tests for the openrouter provider
1 parent 920d882 commit ab15716

1 file changed

Lines changed: 98 additions & 0 deletions

File tree

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import json
2+
from unittest.mock import AsyncMock, MagicMock
3+
4+
import pytest
5+
from fastapi import HTTPException
6+
from fastapi.requests import Request
7+
8+
from codegate.config import DEFAULT_PROVIDER_URLS
9+
from codegate.pipeline.factory import PipelineFactory
10+
from codegate.providers.openrouter.provider import OpenRouterProvider
11+
12+
13+
@pytest.fixture
14+
def mock_factory():
15+
return MagicMock(spec=PipelineFactory)
16+
17+
18+
@pytest.fixture
19+
def provider(mock_factory):
20+
return OpenRouterProvider(mock_factory)
21+
22+
23+
def test_get_base_url(provider):
24+
"""Test that _get_base_url returns the correct OpenRouter API URL"""
25+
assert provider._get_base_url() == DEFAULT_PROVIDER_URLS["openrouter"]
26+
27+
28+
@pytest.mark.asyncio
29+
async def test_model_prefix_added():
30+
"""Test that model name gets prefixed with openrouter/ when not already present"""
31+
mock_factory = MagicMock(spec=PipelineFactory)
32+
provider = OpenRouterProvider(mock_factory)
33+
provider.process_request = AsyncMock()
34+
35+
# Mock request
36+
mock_request = MagicMock(spec=Request)
37+
mock_request.body = AsyncMock(return_value=json.dumps({"model": "gpt-4"}).encode())
38+
mock_request.url.path = "/openrouter/chat/completions"
39+
mock_request.state.detected_client = "test-client"
40+
41+
# Get the route handler function
42+
route_handlers = [
43+
route for route in provider.router.routes if route.path == "/openrouter/chat/completions"
44+
]
45+
create_completion = route_handlers[0].endpoint
46+
47+
await create_completion(request=mock_request, authorization="Bearer test-token")
48+
49+
# Verify process_request was called with prefixed model
50+
call_args = provider.process_request.call_args[0]
51+
assert call_args[0]["model"] == "openrouter/gpt-4"
52+
53+
54+
@pytest.mark.asyncio
55+
async def test_model_prefix_preserved():
56+
"""Test that model name is not modified when openrouter/ prefix is already present"""
57+
mock_factory = MagicMock(spec=PipelineFactory)
58+
provider = OpenRouterProvider(mock_factory)
59+
provider.process_request = AsyncMock()
60+
61+
# Mock request
62+
mock_request = MagicMock(spec=Request)
63+
mock_request.body = AsyncMock(return_value=json.dumps({"model": "openrouter/gpt-4"}).encode())
64+
mock_request.url.path = "/openrouter/chat/completions"
65+
mock_request.state.detected_client = "test-client"
66+
67+
# Get the route handler function
68+
route_handlers = [
69+
route for route in provider.router.routes if route.path == "/openrouter/chat/completions"
70+
]
71+
create_completion = route_handlers[0].endpoint
72+
73+
await create_completion(request=mock_request, authorization="Bearer test-token")
74+
75+
# Verify process_request was called with unchanged model name
76+
call_args = provider.process_request.call_args[0]
77+
assert call_args[0]["model"] == "openrouter/gpt-4"
78+
79+
80+
@pytest.mark.asyncio
81+
async def test_invalid_auth_header():
82+
"""Test that invalid authorization header format raises HTTPException"""
83+
mock_factory = MagicMock(spec=PipelineFactory)
84+
provider = OpenRouterProvider(mock_factory)
85+
86+
mock_request = MagicMock(spec=Request)
87+
88+
# Get the route handler function
89+
route_handlers = [
90+
route for route in provider.router.routes if route.path == "/openrouter/chat/completions"
91+
]
92+
create_completion = route_handlers[0].endpoint
93+
94+
with pytest.raises(HTTPException) as exc_info:
95+
await create_completion(request=mock_request, authorization="InvalidToken")
96+
97+
assert exc_info.value.status_code == 401
98+
assert exc_info.value.detail == "Invalid authorization header"

0 commit comments

Comments
 (0)