@@ -21,6 +21,14 @@ class ProviderNotFoundError(Exception):
2121 pass
2222
2323
24+ class ProviderModelsNotFoundError (Exception ):
25+ pass
26+
27+
28+ class ProviderInvalidAuthConfigError (Exception ):
29+ pass
30+
31+
2432class ProviderCrud :
2533 """The CRUD operations for the provider endpoint references within
2634 Codegate.
@@ -87,12 +95,12 @@ async def add_endpoint(
8795
8896 models = []
8997 if endpoint .auth_type == apimodelsv1 .ProviderAuthType .api_key and not endpoint .api_key :
90- raise ValueError ("API key must be provided for API auth type" )
98+ raise ProviderInvalidAuthConfigError ("API key must be provided for API auth type" )
9199 if endpoint .auth_type != apimodelsv1 .ProviderAuthType .passthrough :
92100 try :
93101 models = prov .models (endpoint = endpoint .endpoint , api_key = endpoint .api_key )
94102 except Exception as err :
95- raise ValueError ( "Unable to get models from provider: {}" . format ( str ( err )) )
103+ raise ProviderModelsNotFoundError ( f "Unable to get models from provider: { err } " )
96104
97105 dbendpoint = await self ._db_writer .add_provider_endpoint (dbend )
98106
@@ -143,22 +151,14 @@ async def configure_auth_material(
143151 ):
144152 """Add an API key."""
145153 if config .auth_type == apimodelsv1 .ProviderAuthType .api_key and not config .api_key :
146- raise ValueError ("API key must be provided for API auth type" )
154+ raise ProviderInvalidAuthConfigError ("API key must be provided for API auth type" )
147155 elif config .auth_type != apimodelsv1 .ProviderAuthType .api_key and config .api_key :
148- raise ValueError ("API key provided for non-API auth type" )
156+ raise ProviderInvalidAuthConfigError ("API key provided for non-API auth type" )
149157
150158 dbendpoint = await self ._db_reader .get_provider_endpoint_by_id (str (provider_id ))
151159 if dbendpoint is None :
152160 raise ProviderNotFoundError ("Provider not found" )
153161
154- await self ._db_writer .push_provider_auth_material (
155- dbmodels .ProviderAuthMaterial (
156- provider_endpoint_id = dbendpoint .id ,
157- auth_type = config .auth_type ,
158- auth_blob = config .api_key if config .api_key else "" ,
159- )
160- )
161-
162162 endpoint = apimodelsv1 .ProviderEndpoint .from_db_model (dbendpoint )
163163 endpoint .auth_type = config .auth_type
164164 provider_registry = get_provider_registry ()
@@ -169,7 +169,15 @@ async def configure_auth_material(
169169 try :
170170 models = prov .models (endpoint = endpoint .endpoint , api_key = config .api_key )
171171 except Exception as err :
172- raise ValueError ("Unable to get models from provider: {}" .format (str (err )))
172+ raise ProviderModelsNotFoundError (f"Unable to get models from provider: { err } " )
173+
174+ await self ._db_writer .push_provider_auth_material (
175+ dbmodels .ProviderAuthMaterial (
176+ provider_endpoint_id = dbendpoint .id ,
177+ auth_type = config .auth_type ,
178+ auth_blob = config .api_key if config .api_key else "" ,
179+ )
180+ )
173181
174182 models_set = set (models )
175183
0 commit comments