Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit e6e34fe

Browse files
committed
Kick off provider endpoint CRUD structure
This structure will handle all the database operations and turn that into the right models. Note that for provider endpoints we already have a way of setting these via configuration, so this is taken into account to output some sample objects that users can leverage. We probably need to set some hardcoded IDs sometime in the near future so we can reference these. But for now, IDs for these are optional. I also added a `source` field which allows us to tell where these providers came from. Signed-off-by: Juan Antonio Osorio <ozz@stacklok.com>
1 parent bc96c0e commit e6e34fe

4 files changed

Lines changed: 138 additions & 23 deletions

File tree

src/codegate/api/v1.py

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
1+
from threading import Lock
12
from typing import List, Optional
23

34
import requests
45
import structlog
5-
from fastapi import APIRouter, HTTPException, Response
6+
from fastapi import APIRouter, Depends, HTTPException, Response
67
from fastapi.responses import StreamingResponse
78
from fastapi.routing import APIRoute
89
from pydantic import ValidationError
910

1011
from codegate import __version__
1112
from codegate.api import v1_models, v1_processing
1213
from codegate.db.connection import AlreadyExistsError, DbReader
14+
from codegate.providers import crud as provendcrud
1315
from codegate.workspaces import crud
1416

1517
logger = structlog.get_logger("codegate")
1618

19+
mtx = Lock()
20+
1721
v1 = APIRouter()
1822
wscrud = crud.WorkspaceCrud()
1923

@@ -26,26 +30,24 @@ def uniq_name(route: APIRoute):
2630

2731

2832
@v1.get("/provider-endpoints", tags=["Providers"], generate_unique_id_function=uniq_name)
29-
async def list_provider_endpoints(name: Optional[str] = None) -> List[v1_models.ProviderEndpoint]:
33+
async def list_provider_endpoints(
34+
name: Optional[str] = None,
35+
pcrud: provendcrud.ProviderCrud = Depends(provendcrud.ProviderCrud),
36+
) -> List[v1_models.ProviderEndpoint]:
3037
"""List all provider endpoints."""
31-
# NOTE: This is a dummy implementation. In the future, we should have a proper
32-
# implementation that fetches the provider endpoints from the database.
33-
return [
34-
v1_models.ProviderEndpoint(
35-
id=1,
36-
name="dummy",
37-
description="Dummy provider endpoint",
38-
endpoint="http://example.com",
39-
provider_type=v1_models.ProviderType.openai,
40-
auth_type=v1_models.ProviderAuthType.none,
41-
)
42-
]
38+
try:
39+
return pcrud.list_endpoints()
40+
except Exception:
41+
raise HTTPException(status_code=500, detail="Internal server error")
4342

4443

4544
@v1.get(
4645
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
4746
)
48-
async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint:
47+
async def get_provider_endpoint(
48+
provider_id: int,
49+
pcrud: provendcrud.ProviderCrud = Depends(provendcrud.ProviderCrud),
50+
) -> v1_models.ProviderEndpoint:
4951
"""Get a provider endpoint by ID."""
5052
# NOTE: This is a dummy implementation. In the future, we should have a proper
5153
# implementation that fetches the provider endpoint from the database.
@@ -65,7 +67,10 @@ async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint:
6567
generate_unique_id_function=uniq_name,
6668
status_code=201,
6769
)
68-
async def add_provider_endpoint(request: v1_models.ProviderEndpoint) -> v1_models.ProviderEndpoint:
70+
async def add_provider_endpoint(
71+
request: v1_models.ProviderEndpoint,
72+
pcrud: provendcrud.ProviderCrud = Depends(provendcrud.ProviderCrud),
73+
) -> v1_models.ProviderEndpoint:
6974
"""Add a provider endpoint."""
7075
# NOTE: This is a dummy implementation. In the future, we should have a proper
7176
# implementation that adds the provider endpoint to the database.
@@ -76,7 +81,9 @@ async def add_provider_endpoint(request: v1_models.ProviderEndpoint) -> v1_model
7681
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
7782
)
7883
async def update_provider_endpoint(
79-
provider_id: int, request: v1_models.ProviderEndpoint
84+
provider_id: int,
85+
request: v1_models.ProviderEndpoint,
86+
pcrud: provendcrud.ProviderCrud = Depends(provendcrud.ProviderCrud),
8087
) -> v1_models.ProviderEndpoint:
8188
"""Update a provider endpoint by ID."""
8289
# NOTE: This is a dummy implementation. In the future, we should have a proper
@@ -87,7 +94,10 @@ async def update_provider_endpoint(
8794
@v1.delete(
8895
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
8996
)
90-
async def delete_provider_endpoint(provider_id: int):
97+
async def delete_provider_endpoint(
98+
provider_id: int,
99+
pcrud: provendcrud.ProviderCrud = Depends(provendcrud.ProviderCrud),
100+
):
91101
"""Delete a provider endpoint by id."""
92102
# NOTE: This is a dummy implementation. In the future, we should have a proper
93103
# implementation that deletes the provider endpoint from the database.
@@ -99,7 +109,10 @@ async def delete_provider_endpoint(provider_id: int):
99109
tags=["Providers"],
100110
generate_unique_id_function=uniq_name,
101111
)
102-
async def list_models_by_provider(provider_name: str) -> List[v1_models.ModelByProvider]:
112+
async def list_models_by_provider(
113+
provider_name: str,
114+
pcrud: provendcrud.ProviderCrud = Depends(provendcrud.ProviderCrud),
115+
) -> List[v1_models.ModelByProvider]:
103116
"""List models by provider."""
104117
# NOTE: This is a dummy implementation. In the future, we should have a proper
105118
# implementation that fetches the models by provider from the database.
@@ -111,7 +124,9 @@ async def list_models_by_provider(provider_name: str) -> List[v1_models.ModelByP
111124
tags=["Providers"],
112125
generate_unique_id_function=uniq_name,
113126
)
114-
async def list_all_models_for_all_providers() -> List[v1_models.ModelByProvider]:
127+
async def list_all_models_for_all_providers(
128+
pcrud: provendcrud.ProviderCrud = Depends(provendcrud.ProviderCrud),
129+
) -> List[v1_models.ModelByProvider]:
115130
"""List all models for all providers."""
116131
# NOTE: This is a dummy implementation. In the future, we should have a proper
117132
# implementation that fetches all the models for all providers from the database.
@@ -394,7 +409,10 @@ async def delete_workspace_custom_instructions(workspace_name: str):
394409
tags=["Workspaces", "Muxes"],
395410
generate_unique_id_function=uniq_name,
396411
)
397-
async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]:
412+
async def get_workspace_muxes(
413+
workspace_name: str,
414+
pcrud: provendcrud.ProviderCrud = Depends(provendcrud.ProviderCrud),
415+
) -> List[v1_models.MuxRule]:
398416
"""Get the mux rules of a workspace.
399417
400418
The list is ordered in order of priority. That is, the first rule in the list
@@ -422,7 +440,11 @@ async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]:
422440
generate_unique_id_function=uniq_name,
423441
status_code=204,
424442
)
425-
async def set_workspace_muxes(workspace_name: str, request: List[v1_models.MuxRule]):
443+
async def set_workspace_muxes(
444+
workspace_name: str,
445+
request: List[v1_models.MuxRule],
446+
pcrud: provendcrud.ProviderCrud = Depends(provendcrud.ProviderCrud),
447+
):
426448
"""Set the mux rules of a workspace."""
427449
# TODO: This is a dummy implementation. In the future, we should have a proper
428450
# implementation that sets the mux rules in the database.

src/codegate/api/v1_models.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ class ProviderType(str, Enum):
148148
openai = "openai"
149149
anthropic = "anthropic"
150150
vllm = "vllm"
151+
ollama = "ollama"
152+
lm_studio = "lm_studio"
151153

152154

153155
class ProviderAuthType(str, Enum):
@@ -163,19 +165,29 @@ class ProviderAuthType(str, Enum):
163165
api_key = "api_key"
164166

165167

168+
class ProviderEndpointSource(str, Enum):
169+
"""
170+
Represents the different sources of provider endpoints.
171+
"""
172+
173+
config = "config"
174+
db = "db"
175+
176+
166177
class ProviderEndpoint(pydantic.BaseModel):
167178
"""
168179
Represents a provider's endpoint configuration. This
169180
allows us to persist the configuration for each provider,
170181
so we can use this for muxing messages.
171182
"""
172183

173-
id: int
184+
id: Optional[int] = None
174185
name: str
175186
description: str = ""
176187
provider_type: ProviderType
177188
endpoint: str
178189
auth_type: ProviderAuthType
190+
source: ProviderEndpointSource
179191

180192

181193
class ModelByProvider(pydantic.BaseModel):
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .crud import ProviderCrud
2+
3+
__all__ = ["ProviderCrud"]
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from typing import List, Optional
2+
from urllib.parse import urlparse
3+
4+
import structlog
5+
from pydantic import ValidationError
6+
7+
from codegate.api import v1_models as apimodelsv1
8+
from codegate.config import Config
9+
from codegate.db.connection import DbReader, DbRecorder
10+
11+
logger = structlog.get_logger("codegate")
12+
13+
14+
class ProviderCrud:
15+
"""The CRUD operations for the provider endpoint references within
16+
Codegate.
17+
18+
This is meant to handle all the transformations in between the
19+
database and the API, as well as other sources of information. All
20+
operations should result in the API models being returned.
21+
"""
22+
23+
def __init__(self):
24+
self._db_reader = DbReader()
25+
self._db_writer = DbRecorder()
26+
config = Config.get_config()
27+
if config is None:
28+
logger.warning("OZZ: No configuration found.")
29+
provided_urls = {}
30+
else:
31+
logger.info("OZZ: Using configuration for provider URLs.")
32+
provided_urls = config.provider_urls
33+
34+
self._provider_urls = provided_urls
35+
36+
def list_endpoints(self) -> List[apimodelsv1.ProviderEndpoint]:
37+
"""List all the endpoints."""
38+
39+
endpoints = []
40+
for provider_name, provider_url in self._provider_urls.items():
41+
provend = self.__provider_endpoint_from_cfg(provider_name, provider_url)
42+
if provend is not None:
43+
endpoints.append(provend)
44+
45+
return endpoints
46+
47+
def __provider_endpoint_from_cfg(
48+
self, provider_name: str, provider_url: str
49+
) -> Optional[apimodelsv1.ProviderEndpoint]:
50+
"""Create a provider endpoint from the config entry."""
51+
52+
try:
53+
_ = urlparse(provider_url)
54+
except Exception:
55+
logger.warning(
56+
"Invalid provider URL", provider_name=provider_name, provider_url=provider_url
57+
)
58+
return None
59+
60+
try:
61+
return apimodelsv1.ProviderEndpoint(
62+
name=provider_name,
63+
endpoint=provider_url,
64+
descrption=("Endpoint for the {} provided via the CodeGate configuration.").format(
65+
provider_name
66+
),
67+
provider_type=provider_name,
68+
auth_type=apimodelsv1.ProviderAuthType.none,
69+
source=apimodelsv1.ProviderEndpointSource.config,
70+
)
71+
except ValidationError as err:
72+
logger.warning(
73+
"Invalid provider name",
74+
provider_name=provider_name,
75+
provider_url=provider_url,
76+
err=str(err),
77+
)
78+
return None

0 commit comments

Comments
 (0)