Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ async def configure_auth_material(
)
async def update_provider_endpoint(
provider_id: UUID,
request: v1_models.AddProviderEndpointRequest,
request: v1_models.ProviderEndpoint,
) -> v1_models.ProviderEndpoint:
"""Update a provider endpoint by ID."""
try:
Expand Down
77 changes: 35 additions & 42 deletions src/codegate/providers/crud/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ async def add_endpoint(
return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)

async def update_endpoint(
self, endpoint: apimodelsv1.AddProviderEndpointRequest
self, endpoint: apimodelsv1.ProviderEndpoint
) -> apimodelsv1.ProviderEndpoint:
"""Update an endpoint."""

Expand All @@ -134,12 +134,40 @@ async def update_endpoint(
if founddbe is None:
raise ProviderNotFoundError("Provider not found")

models = []
if endpoint.auth_type == apimodelsv1.ProviderAuthType.api_key and not endpoint.api_key:
dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model())

return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)

async def configure_auth_material(
self, provider_id: UUID, config: apimodelsv1.ConfigureAuthMaterial
):
"""Add an API key."""
if config.auth_type == apimodelsv1.ProviderAuthType.api_key and not config.api_key:
raise ValueError("API key must be provided for API auth type")
if endpoint.auth_type != apimodelsv1.ProviderAuthType.passthrough:
elif config.auth_type != apimodelsv1.ProviderAuthType.api_key and config.api_key:
raise ValueError("API key provided for non-API auth type")

dbendpoint = await self._db_reader.get_provider_endpoint_by_id(str(provider_id))
if dbendpoint is None:
raise ProviderNotFoundError("Provider not found")

await self._db_writer.push_provider_auth_material(
dbmodels.ProviderAuthMaterial(
provider_endpoint_id=dbendpoint.id,
auth_type=config.auth_type,
auth_blob=config.api_key if config.api_key else "",
)
)

endpoint = apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)
endpoint.auth_type = config.auth_type
provider_registry = get_provider_registry()
prov = endpoint.get_from_registry(provider_registry)

models = []
if config.auth_type != apimodelsv1.ProviderAuthType.passthrough:
try:
models = prov.models(endpoint=endpoint.endpoint, api_key=endpoint.api_key)
models = prov.models(endpoint=endpoint.endpoint, api_key=config.api_key)
except Exception as err:
raise ValueError("Unable to get models from provider: {}".format(str(err)))

Expand All @@ -154,56 +182,21 @@ async def update_endpoint(
for model in models_set - models_in_db_set:
await self._db_writer.add_provider_model(
dbmodels.ProviderModel(
provider_endpoint_id=founddbe.id,
provider_endpoint_id=dbendpoint.id,
name=model,
)
)

# Remove the models that are in the DB but not in the provider
for model in models_in_db_set - models_set:
await self._db_writer.delete_provider_model(
founddbe.id,
dbendpoint.id,
model,
)

dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model())

# If an API key was provided or we've changed the auth type, we update the auth material
if endpoint.auth_type != founddbe.auth_type or endpoint.api_key:
await self._db_writer.push_provider_auth_material(
dbmodels.ProviderAuthMaterial(
provider_endpoint_id=dbendpoint.id,
auth_type=endpoint.auth_type,
auth_blob=endpoint.api_key if endpoint.api_key else "",
)
)

# a model might have been deleted, let's repopulate the cache
await self._ws_crud.repopulate_mux_cache()

return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint)

async def configure_auth_material(
self, provider_id: UUID, config: apimodelsv1.ConfigureAuthMaterial
):
"""Add an API key."""
if config.auth_type == apimodelsv1.ProviderAuthType.api_key and not config.api_key:
raise ValueError("API key must be provided for API auth type")
elif config.auth_type != apimodelsv1.ProviderAuthType.api_key and config.api_key:
raise ValueError("API key provided for non-API auth type")

dbendpoint = await self._db_reader.get_provider_endpoint_by_id(str(provider_id))
if dbendpoint is None:
raise ProviderNotFoundError("Provider not found")

await self._db_writer.push_provider_auth_material(
dbmodels.ProviderAuthMaterial(
provider_endpoint_id=dbendpoint.id,
auth_type=config.auth_type,
auth_blob=config.api_key if config.api_key else "",
)
)

async def delete_endpoint(self, provider_id: UUID):
"""Delete an endpoint."""

Expand Down