|
11 | 11 | from codegate.db import models as dbmodels |
12 | 12 | from codegate.db.connection import DbReader, DbRecorder |
13 | 13 | from codegate.providers.base import BaseProvider |
14 | | -from codegate.providers.registry import ProviderRegistry |
| 14 | +from codegate.providers.registry import ProviderRegistry, get_provider_registry |
15 | 15 |
|
16 | 16 | logger = structlog.get_logger("codegate") |
17 | 17 |
|
@@ -62,22 +62,60 @@ async def get_endpoint_by_name(self, name: str) -> Optional[apimodelsv1.Provider |
62 | 62 | return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) |
63 | 63 |
|
64 | 64 | async def add_endpoint( |
65 | | - self, endpoint: apimodelsv1.ProviderEndpoint |
| 65 | + self, endpoint: apimodelsv1.AddProviderEndpointRequest |
66 | 66 | ) -> apimodelsv1.ProviderEndpoint: |
67 | 67 | """Add an endpoint.""" |
68 | 68 | dbend = endpoint.to_db_model() |
| 69 | + provider_registry = get_provider_registry() |
69 | 70 |
|
70 | 71 | # We override the ID here, as we want to generate it. |
71 | 72 | dbend.id = str(uuid4()) |
72 | 73 |
|
73 | | - dbendpoint = await self._db_writer.add_provider_endpoint() |
| 74 | + prov = endpoint.get_from_registry(provider_registry) |
| 75 | + if prov is None: |
| 76 | + raise ValueError("Unknown provider type: {}".format(endpoint.provider_type)) |
| 77 | + |
| 78 | + models = [] |
| 79 | + if endpoint.auth_type == apimodelsv1.ProviderAuthType.api_key and not endpoint.api_key: |
| 80 | + raise ValueError("API key must be provided for API auth type") |
| 81 | + if endpoint.auth_type != apimodelsv1.ProviderAuthType.passthrough: |
| 82 | + try: |
| 83 | + models = prov.models(endpoint.api_key) |
| 84 | + except Exception as err: |
| 85 | + raise ValueError("Unable to get models from provider: {}".format(str(err))) |
| 86 | + |
| 87 | + dbendpoint = await self._db_writer.add_provider_endpoint(dbend) |
| 88 | + |
| 89 | + await self._db_writer.push_provider_auth_material( |
| 90 | + dbmodels.ProviderAuthMaterial( |
| 91 | + provider_endpoint_id=dbendpoint.id, |
| 92 | + auth_type=endpoint.auth_type, |
| 93 | + auth_blob=endpoint.api_key if endpoint.api_key else "", |
| 94 | + ) |
| 95 | + ) |
| 96 | + |
| 97 | + for model in models: |
| 98 | + await self._db_writer.add_provider_model( |
| 99 | + dbmodels.ProviderModel( |
| 100 | + provider_endpoint_id=dbendpoint.id, |
| 101 | + name=model, |
| 102 | + ) |
| 103 | + ) |
74 | 104 | return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) |
75 | 105 |
|
76 | 106 | async def update_endpoint( |
77 | 107 | self, endpoint: apimodelsv1.ProviderEndpoint |
78 | 108 | ) -> apimodelsv1.ProviderEndpoint: |
79 | 109 | """Update an endpoint.""" |
80 | 110 |
|
| 111 | + founddbe = await self._db_reader.get_provider_endpoint_by_id(endpoint.id) |
| 112 | + if founddbe is None: |
| 113 | + raise ProviderNotFoundError("Provider not found") |
| 114 | + |
| 115 | + # TODO: We should probably allow this and reinitialize the provider |
| 116 | + if endpoint.auth_type != founddbe.auth_type: |
| 117 | + raise ValueError("Cannot change auth type for provider through this endpoint.") |
| 118 | + |
81 | 119 | dbendpoint = await self._db_writer.update_provider_endpoint(endpoint.to_db_model()) |
82 | 120 | return apimodelsv1.ProviderEndpoint.from_db_model(dbendpoint) |
83 | 121 |
|
@@ -175,6 +213,13 @@ async def initialize_provider_endpoints(preg: ProviderRegistry): |
175 | 213 | continue |
176 | 214 |
|
177 | 215 | pimpl = provend.get_from_registry(preg) |
| 216 | + if pimpl is None: |
| 217 | + logger.warning( |
| 218 | + "Provider not found in registry", |
| 219 | + provider=provend.name, |
| 220 | + endpoint=provend.endpoint, |
| 221 | + ) |
| 222 | + continue |
178 | 223 | await try_initialize_provider_endpoints(provend, pimpl, db_writer) |
179 | 224 |
|
180 | 225 |
|
|
0 commit comments