Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ async def add_provider_endpoint(
status_code=400,
detail=str(e),
)
except provendcrud.ProviderModelsNotFoundError:
raise HTTPException(status_code=401, detail="Provider models could not be found")
except provendcrud.ProviderInvalidAuthConfigError:
raise HTTPException(status_code=400, detail="Invalid auth configuration")
except ValidationError as e:
# TODO: This should be more specific
raise HTTPException(
Expand Down Expand Up @@ -151,6 +155,10 @@ async def configure_auth_material(
await pcrud.configure_auth_material(provider_id, request)
except provendcrud.ProviderNotFoundError:
raise HTTPException(status_code=404, detail="Provider endpoint not found")
except provendcrud.ProviderModelsNotFoundError:
raise HTTPException(status_code=401, detail="Provider models could not be found")
except provendcrud.ProviderInvalidAuthConfigError:
raise HTTPException(status_code=400, detail="Invalid auth configuration")
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")

Expand Down
2 changes: 1 addition & 1 deletion src/codegate/api/v1_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ class ProviderEndpoint(pydantic.BaseModel):
description: str = ""
provider_type: db_models.ProviderType
endpoint: str = "" # Some providers have defaults we can leverage
auth_type: Optional[ProviderAuthType] = ProviderAuthType.none
auth_type: ProviderAuthType = ProviderAuthType.none

@staticmethod
def from_db_model(db_model: db_models.ProviderEndpoint) -> "ProviderEndpoint":
Expand Down
16 changes: 14 additions & 2 deletions src/codegate/providers/crud/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
from .crud import ProviderCrud, ProviderNotFoundError, initialize_provider_endpoints
from .crud import (
ProviderCrud,
ProviderInvalidAuthConfigError,
ProviderModelsNotFoundError,
ProviderNotFoundError,
initialize_provider_endpoints,
)

__all__ = ["ProviderCrud", "initialize_provider_endpoints", "ProviderNotFoundError"]
__all__ = [
"ProviderCrud",
"initialize_provider_endpoints",
"ProviderNotFoundError",
"ProviderModelsNotFoundError",
"ProviderInvalidAuthConfigError",
]
34 changes: 21 additions & 13 deletions src/codegate/providers/crud/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ class ProviderNotFoundError(Exception):
pass


class ProviderModelsNotFoundError(Exception):
pass


class ProviderInvalidAuthConfigError(Exception):
pass


class ProviderCrud:
"""The CRUD operations for the provider endpoint references within
Codegate.
Expand Down Expand Up @@ -87,12 +95,12 @@ async def add_endpoint(

models = []
if endpoint.auth_type == apimodelsv1.ProviderAuthType.api_key and not endpoint.api_key:
raise ValueError("API key must be provided for API auth type")
raise ProviderInvalidAuthConfigError("API key must be provided for API auth type")
if endpoint.auth_type != apimodelsv1.ProviderAuthType.passthrough:
try:
models = prov.models(endpoint=endpoint.endpoint, api_key=endpoint.api_key)
except Exception as err:
raise ValueError("Unable to get models from provider: {}".format(str(err)))
raise ProviderModelsNotFoundError(f"Unable to get models from provider: {err}")

dbendpoint = await self._db_writer.add_provider_endpoint(dbend)

Expand Down Expand Up @@ -143,22 +151,14 @@ async def configure_auth_material(
):
"""Add an API key."""
if config.auth_type == apimodelsv1.ProviderAuthType.api_key and not config.api_key:
raise ValueError("API key must be provided for API auth type")
raise ProviderInvalidAuthConfigError("API key must be provided for API auth type")
elif config.auth_type != apimodelsv1.ProviderAuthType.api_key and config.api_key:
raise ValueError("API key provided for non-API auth type")
raise ProviderInvalidAuthConfigError("API key provided for non-API auth type")

dbendpoint = await self._db_reader.get_provider_endpoint_by_id(str(provider_id))
if dbendpoint is None:
raise ProviderNotFoundError("Provider not found")

await self._db_writer.push_provider_auth_material(
dbmodels.ProviderAuthMaterial(
provider_endpoint_id=dbendpoint.id,
auth_type=config.auth_type,
auth_blob=config.api_key if config.api_key else "",
)
)

endpoint = apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)
endpoint.auth_type = config.auth_type
provider_registry = get_provider_registry()
Expand All @@ -169,7 +169,15 @@ async def configure_auth_material(
try:
models = prov.models(endpoint=endpoint.endpoint, api_key=config.api_key)
except Exception as err:
raise ValueError("Unable to get models from provider: {}".format(str(err)))
raise ProviderModelsNotFoundError(f"Unable to get models from provider: {err}")

await self._db_writer.push_provider_auth_material(
dbmodels.ProviderAuthMaterial(
provider_endpoint_id=dbendpoint.id,
auth_type=config.auth_type,
auth_blob=config.api_key if config.api_key else "",
)
)

models_set = set(models)

Expand Down