Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 41 additions & 22 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import requests
import structlog
from fastapi import APIRouter, HTTPException, Response
from fastapi import APIRouter, Depends, HTTPException, Response
from fastapi.responses import StreamingResponse
from fastapi.routing import APIRoute
from pydantic import ValidationError

from codegate import __version__
from codegate.api import v1_models, v1_processing
from codegate.db.connection import AlreadyExistsError, DbReader
from codegate.providers import crud as provendcrud
from codegate.workspaces import crud

logger = structlog.get_logger("codegate")
Expand All @@ -26,26 +27,24 @@ def uniq_name(route: APIRoute):


@v1.get("/provider-endpoints", tags=["Providers"], generate_unique_id_function=uniq_name)
async def list_provider_endpoints(name: Optional[str] = None) -> List[v1_models.ProviderEndpoint]:
async def list_provider_endpoints(
name: Optional[str] = None,
pcrud: provendcrud.ProviderCrud = Depends(provendcrud.ProviderCrud),
) -> List[v1_models.ProviderEndpoint]:
"""List all provider endpoints."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that fetches the provider endpoints from the database.
return [
v1_models.ProviderEndpoint(
id=1,
name="dummy",
description="Dummy provider endpoint",
endpoint="http://example.com",
provider_type=v1_models.ProviderType.openai,
auth_type=v1_models.ProviderAuthType.none,
)
]
try:
return pcrud.list_endpoints()
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")


@v1.get(
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
)
async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint:
async def get_provider_endpoint(
provider_id: int,
pcrud: provendcrud.ProviderCrud = Depends(provendcrud.ProviderCrud),
) -> v1_models.ProviderEndpoint:
"""Get a provider endpoint by ID."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that fetches the provider endpoint from the database.
Expand All @@ -65,7 +64,10 @@ async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint:
generate_unique_id_function=uniq_name,
status_code=201,
)
async def add_provider_endpoint(request: v1_models.ProviderEndpoint) -> v1_models.ProviderEndpoint:
async def add_provider_endpoint(
request: v1_models.ProviderEndpoint,
pcrud: provendcrud.ProviderCrud = Depends(provendcrud.ProviderCrud),
) -> v1_models.ProviderEndpoint:
"""Add a provider endpoint."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that adds the provider endpoint to the database.
Expand All @@ -76,7 +78,9 @@ async def add_provider_endpoint(request: v1_models.ProviderEndpoint) -> v1_model
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
)
async def update_provider_endpoint(
provider_id: int, request: v1_models.ProviderEndpoint
provider_id: int,
request: v1_models.ProviderEndpoint,
pcrud: provendcrud.ProviderCrud = Depends(provendcrud.ProviderCrud),
) -> v1_models.ProviderEndpoint:
"""Update a provider endpoint by ID."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
Expand All @@ -87,7 +91,10 @@ async def update_provider_endpoint(
@v1.delete(
"/provider-endpoints/{provider_id}", tags=["Providers"], generate_unique_id_function=uniq_name
)
async def delete_provider_endpoint(provider_id: int):
async def delete_provider_endpoint(
provider_id: int,
pcrud: provendcrud.ProviderCrud = Depends(provendcrud.ProviderCrud),
):
"""Delete a provider endpoint by id."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that deletes the provider endpoint from the database.
Expand All @@ -99,7 +106,10 @@ async def delete_provider_endpoint(provider_id: int):
tags=["Providers"],
generate_unique_id_function=uniq_name,
)
async def list_models_by_provider(provider_name: str) -> List[v1_models.ModelByProvider]:
async def list_models_by_provider(
provider_name: str,
pcrud: provendcrud.ProviderCrud = Depends(provendcrud.ProviderCrud),
) -> List[v1_models.ModelByProvider]:
"""List models by provider."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that fetches the models by provider from the database.
Expand All @@ -111,7 +121,9 @@ async def list_models_by_provider(provider_name: str) -> List[v1_models.ModelByP
tags=["Providers"],
generate_unique_id_function=uniq_name,
)
async def list_all_models_for_all_providers() -> List[v1_models.ModelByProvider]:
async def list_all_models_for_all_providers(
pcrud: provendcrud.ProviderCrud = Depends(provendcrud.ProviderCrud),
) -> List[v1_models.ModelByProvider]:
"""List all models for all providers."""
# NOTE: This is a dummy implementation. In the future, we should have a proper
# implementation that fetches all the models for all providers from the database.
Expand Down Expand Up @@ -394,7 +406,10 @@ async def delete_workspace_custom_instructions(workspace_name: str):
tags=["Workspaces", "Muxes"],
generate_unique_id_function=uniq_name,
)
async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]:
async def get_workspace_muxes(
workspace_name: str,
pcrud: provendcrud.ProviderCrud = Depends(provendcrud.ProviderCrud),
) -> List[v1_models.MuxRule]:
"""Get the mux rules of a workspace.

The list is ordered in order of priority. That is, the first rule in the list
Expand Down Expand Up @@ -422,7 +437,11 @@ async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]:
generate_unique_id_function=uniq_name,
status_code=204,
)
async def set_workspace_muxes(workspace_name: str, request: List[v1_models.MuxRule]):
async def set_workspace_muxes(
workspace_name: str,
request: List[v1_models.MuxRule],
pcrud: provendcrud.ProviderCrud = Depends(provendcrud.ProviderCrud),
):
"""Set the mux rules of a workspace."""
# TODO: This is a dummy implementation. In the future, we should have a proper
# implementation that sets the mux rules in the database.
Expand Down
14 changes: 13 additions & 1 deletion src/codegate/api/v1_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ class ProviderType(str, Enum):
openai = "openai"
anthropic = "anthropic"
vllm = "vllm"
ollama = "ollama"
lm_studio = "lm_studio"


class ProviderAuthType(str, Enum):
Expand All @@ -163,19 +165,29 @@ class ProviderAuthType(str, Enum):
api_key = "api_key"


class ProviderEndpointSource(str, Enum):
"""
Represents the different sources of provider endpoints.
"""

config = "config"
db = "db"
Comment on lines +173 to +174
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably for a future PR but I think it's worth it to persist the provider information from config in DB. That way we can have a single source of truth

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is an excellent point for discussion. Are we planning to keep the provider configs? what do we do when a user changes the API endpoint from the config?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think from now on for muxing we should keep the provider configs

what do we do when a user changes the API endpoint from the config?

I don't know how common it is to change the API endpoint of an LLM provider. Right now, a user would need to stop CodeGate and start it again with new config in order to change the API endpoint.

Keeping the provider configs in DB and be able to change them without stopping the server gives a better UX IMO



class ProviderEndpoint(pydantic.BaseModel):
"""
Represents a provider's endpoint configuration. This
allows us to persist the configuration for each provider,
so we can use this for muxing messages.
"""

id: int
id: Optional[int] = None
name: str
description: str = ""
provider_type: ProviderType
endpoint: str
auth_type: ProviderAuthType
source: ProviderEndpointSource


class ModelByProvider(pydantic.BaseModel):
Expand Down
3 changes: 3 additions & 0 deletions src/codegate/providers/crud/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .crud import ProviderCrud

__all__ = ["ProviderCrud"]
78 changes: 78 additions & 0 deletions src/codegate/providers/crud/crud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import List, Optional
from urllib.parse import urlparse

import structlog
from pydantic import ValidationError

from codegate.api import v1_models as apimodelsv1
from codegate.config import Config
from codegate.db.connection import DbReader, DbRecorder

logger = structlog.get_logger("codegate")


class ProviderCrud:
"""The CRUD operations for the provider endpoint references within
Codegate.

This is meant to handle all the transformations in between the
database and the API, as well as other sources of information. All
operations should result in the API models being returned.
"""

def __init__(self):
self._db_reader = DbReader()
self._db_writer = DbRecorder()
config = Config.get_config()
if config is None:
logger.warning("OZZ: No configuration found.")
provided_urls = {}
else:
logger.info("OZZ: Using configuration for provider URLs.")
provided_urls = config.provider_urls

self._provider_urls = provided_urls

def list_endpoints(self) -> List[apimodelsv1.ProviderEndpoint]:
"""List all the endpoints."""

endpoints = []
for provider_name, provider_url in self._provider_urls.items():
provend = self.__provider_endpoint_from_cfg(provider_name, provider_url)
if provend is not None:
endpoints.append(provend)

return endpoints

def __provider_endpoint_from_cfg(
self, provider_name: str, provider_url: str
) -> Optional[apimodelsv1.ProviderEndpoint]:
"""Create a provider endpoint from the config entry."""

try:
_ = urlparse(provider_url)
except Exception:
logger.warning(
"Invalid provider URL", provider_name=provider_name, provider_url=provider_url
)
return None

try:
return apimodelsv1.ProviderEndpoint(
name=provider_name,
endpoint=provider_url,
descrption=("Endpoint for the {} provided via the CodeGate configuration.").format(
provider_name
),
provider_type=provider_name,
auth_type=apimodelsv1.ProviderAuthType.none,
source=apimodelsv1.ProviderEndpointSource.config,
)
except ValidationError as err:
logger.warning(
"Invalid provider name",
provider_name=provider_name,
provider_url=provider_url,
err=str(err),
)
return None
Loading