@@ -114,7 +114,7 @@ async def add_endpoint(
114114 return apimodelsv1 .ProviderEndpoint .from_db_model (dbendpoint )
115115
116116 async def update_endpoint (
117- self , endpoint : apimodelsv1 .AddProviderEndpointRequest
117+ self , endpoint : apimodelsv1 .ProviderEndpoint
118118 ) -> apimodelsv1 .ProviderEndpoint :
119119 """Update an endpoint."""
120120
@@ -134,12 +134,40 @@ async def update_endpoint(
134134 if founddbe is None :
135135 raise ProviderNotFoundError ("Provider not found" )
136136
137- models = []
138- if endpoint .auth_type == apimodelsv1 .ProviderAuthType .api_key and not endpoint .api_key :
137+ dbendpoint = await self ._db_writer .update_provider_endpoint (endpoint .to_db_model ())
138+
139+ return apimodelsv1 .ProviderEndpoint .from_db_model (dbendpoint )
140+
141+ async def configure_auth_material (
142+ self , provider_id : UUID , config : apimodelsv1 .ConfigureAuthMaterial
143+ ):
144+ """Add an API key."""
145+ if config .auth_type == apimodelsv1 .ProviderAuthType .api_key and not config .api_key :
139146 raise ValueError ("API key must be provided for API auth type" )
140- if endpoint .auth_type != apimodelsv1 .ProviderAuthType .passthrough :
147+ elif config .auth_type != apimodelsv1 .ProviderAuthType .api_key and config .api_key :
148+ raise ValueError ("API key provided for non-API auth type" )
149+
150+ dbendpoint = await self ._db_reader .get_provider_endpoint_by_id (str (provider_id ))
151+ if dbendpoint is None :
152+ raise ProviderNotFoundError ("Provider not found" )
153+
154+ await self ._db_writer .push_provider_auth_material (
155+ dbmodels .ProviderAuthMaterial (
156+ provider_endpoint_id = dbendpoint .id ,
157+ auth_type = config .auth_type ,
158+ auth_blob = config .api_key if config .api_key else "" ,
159+ )
160+ )
161+
162+ endpoint = apimodelsv1 .ProviderEndpoint .from_db_model (dbendpoint )
163+ endpoint .auth_type = config .auth_type
164+ provider_registry = get_provider_registry ()
165+ prov = endpoint .get_from_registry (provider_registry )
166+
167+ models = []
168+ if config .auth_type != apimodelsv1 .ProviderAuthType .passthrough :
141169 try :
142- models = prov .models (endpoint = endpoint .endpoint , api_key = endpoint .api_key )
170+ models = prov .models (endpoint = endpoint .endpoint , api_key = config .api_key )
143171 except Exception as err :
144172 raise ValueError ("Unable to get models from provider: {}" .format (str (err )))
145173
@@ -154,56 +182,21 @@ async def update_endpoint(
154182 for model in models_set - models_in_db_set :
155183 await self ._db_writer .add_provider_model (
156184 dbmodels .ProviderModel (
157- provider_endpoint_id = founddbe .id ,
185+ provider_endpoint_id = dbendpoint .id ,
158186 name = model ,
159187 )
160188 )
161189
162190 # Remove the models that are in the DB but not in the provider
163191 for model in models_in_db_set - models_set :
164192 await self ._db_writer .delete_provider_model (
165- founddbe .id ,
193+ dbendpoint .id ,
166194 model ,
167195 )
168196
169- dbendpoint = await self ._db_writer .update_provider_endpoint (endpoint .to_db_model ())
170-
171- # If an API key was provided or we've changed the auth type, we update the auth material
172- if endpoint .auth_type != founddbe .auth_type or endpoint .api_key :
173- await self ._db_writer .push_provider_auth_material (
174- dbmodels .ProviderAuthMaterial (
175- provider_endpoint_id = dbendpoint .id ,
176- auth_type = endpoint .auth_type ,
177- auth_blob = endpoint .api_key if endpoint .api_key else "" ,
178- )
179- )
180-
181197 # a model might have been deleted, let's repopulate the cache
182198 await self ._ws_crud .repopulate_mux_cache ()
183199
184- return apimodelsv1 .ProviderEndpoint .from_db_model (dbendpoint )
185-
186- async def configure_auth_material (
187- self , provider_id : UUID , config : apimodelsv1 .ConfigureAuthMaterial
188- ):
189- """Add an API key."""
190- if config .auth_type == apimodelsv1 .ProviderAuthType .api_key and not config .api_key :
191- raise ValueError ("API key must be provided for API auth type" )
192- elif config .auth_type != apimodelsv1 .ProviderAuthType .api_key and config .api_key :
193- raise ValueError ("API key provided for non-API auth type" )
194-
195- dbendpoint = await self ._db_reader .get_provider_endpoint_by_id (str (provider_id ))
196- if dbendpoint is None :
197- raise ProviderNotFoundError ("Provider not found" )
198-
199- await self ._db_writer .push_provider_auth_material (
200- dbmodels .ProviderAuthMaterial (
201- provider_endpoint_id = dbendpoint .id ,
202- auth_type = config .auth_type ,
203- auth_blob = config .api_key if config .api_key else "" ,
204- )
205- )
206-
207200 async def delete_endpoint (self , provider_id : UUID ):
208201 """Delete an endpoint."""
209202
0 commit comments