1111from codegate .db import models as dbmodels
1212from codegate .db .connection import DbReader , DbRecorder
1313from codegate .providers .base import BaseProvider
14- from codegate .providers .registry import ProviderRegistry
14+ from codegate .providers .registry import ProviderRegistry , get_provider_registry
1515
1616logger = structlog .get_logger ("codegate" )
1717
@@ -62,23 +62,106 @@ async def get_endpoint_by_name(self, name: str) -> Optional[apimodelsv1.Provider
6262 return apimodelsv1 .ProviderEndpoint .from_db_model (dbendpoint )
6363
6464 async def add_endpoint (
65- self , endpoint : apimodelsv1 .ProviderEndpoint
65+ self , endpoint : apimodelsv1 .AddProviderEndpointRequest
6666 ) -> apimodelsv1 .ProviderEndpoint :
6767 """Add an endpoint."""
68+
69+ if not endpoint .endpoint :
70+ endpoint .endpoint = provider_default_endpoints (endpoint .provider_type )
71+
72+ # If we STILL don't have an endpoint, we can't continue
73+ if not endpoint .endpoint :
74+ raise ValueError ("No endpoint provided and no default found for provider type" )
75+
6876 dbend = endpoint .to_db_model ()
77+ provider_registry = get_provider_registry ()
6978
7079 # We override the ID here, as we want to generate it.
7180 dbend .id = str (uuid4 ())
7281
73- dbendpoint = await self ._db_writer .add_provider_endpoint ()
82+ prov = endpoint .get_from_registry (provider_registry )
83+ if prov is None :
84+ raise ValueError ("Unknown provider type: {}" .format (endpoint .provider_type ))
85+
86+ models = []
87+ if endpoint .auth_type == apimodelsv1 .ProviderAuthType .api_key and not endpoint .api_key :
88+ raise ValueError ("API key must be provided for API auth type" )
89+ if endpoint .auth_type != apimodelsv1 .ProviderAuthType .passthrough :
90+ try :
91+ models = prov .models (endpoint = endpoint .endpoint , api_key = endpoint .api_key )
92+ except Exception as err :
93+ raise ValueError ("Unable to get models from provider: {}" .format (str (err )))
94+
95+ dbendpoint = await self ._db_writer .add_provider_endpoint (dbend )
96+
97+ await self ._db_writer .push_provider_auth_material (
98+ dbmodels .ProviderAuthMaterial (
99+ provider_endpoint_id = dbendpoint .id ,
100+ auth_type = endpoint .auth_type ,
101+ auth_blob = endpoint .api_key if endpoint .api_key else "" ,
102+ )
103+ )
104+
105+ for model in models :
106+ await self ._db_writer .add_provider_model (
107+ dbmodels .ProviderModel (
108+ provider_endpoint_id = dbendpoint .id ,
109+ name = model ,
110+ )
111+ )
74112 return apimodelsv1 .ProviderEndpoint .from_db_model (dbendpoint )
75113
76114 async def update_endpoint (
77- self , endpoint : apimodelsv1 .ProviderEndpoint
115+ self , endpoint : apimodelsv1 .AddProviderEndpointRequest
78116 ) -> apimodelsv1 .ProviderEndpoint :
79117 """Update an endpoint."""
80118
119+ if not endpoint .endpoint :
120+ endpoint .endpoint = provider_default_endpoints (endpoint .provider_type )
121+
122+ # If we STILL don't have an endpoint, we can't continue
123+ if not endpoint .endpoint :
124+ raise ValueError ("No endpoint provided and no default found for provider type" )
125+
126+ provider_registry = get_provider_registry ()
127+ prov = endpoint .get_from_registry (provider_registry )
128+ if prov is None :
129+ raise ValueError ("Unknown provider type: {}" .format (endpoint .provider_type ))
130+
131+ founddbe = await self ._db_reader .get_provider_endpoint_by_id (str (endpoint .id ))
132+ if founddbe is None :
133+ raise ProviderNotFoundError ("Provider not found" )
134+
135+ models = []
136+ if endpoint .auth_type == apimodelsv1 .ProviderAuthType .api_key and not endpoint .api_key :
137+ raise ValueError ("API key must be provided for API auth type" )
138+ if endpoint .auth_type != apimodelsv1 .ProviderAuthType .passthrough :
139+ try :
140+ models = prov .models (endpoint = endpoint .endpoint , api_key = endpoint .api_key )
141+ except Exception as err :
142+ raise ValueError ("Unable to get models from provider: {}" .format (str (err )))
143+
144+ # Reset all provider models.
145+ await self ._db_writer .delete_provider_models (str (endpoint .id ))
146+
147+ for model in models :
148+ await self ._db_writer .add_provider_model (
149+ dbmodels .ProviderModel (
150+ provider_endpoint_id = founddbe .id ,
151+ name = model ,
152+ )
153+ )
154+
81155 dbendpoint = await self ._db_writer .update_provider_endpoint (endpoint .to_db_model ())
156+
157+ await self ._db_writer .push_provider_auth_material (
158+ dbmodels .ProviderAuthMaterial (
159+ provider_endpoint_id = dbendpoint .id ,
160+ auth_type = endpoint .auth_type ,
161+ auth_blob = endpoint .api_key if endpoint .api_key else "" ,
162+ )
163+ )
164+
82165 return apimodelsv1 .ProviderEndpoint .from_db_model (dbendpoint )
83166
84167 async def configure_auth_material (
@@ -175,6 +258,13 @@ async def initialize_provider_endpoints(preg: ProviderRegistry):
175258 continue
176259
177260 pimpl = provend .get_from_registry (preg )
261+ if pimpl is None :
262+ logger .warning (
263+ "Provider not found in registry" ,
264+ provider = provend .name ,
265+ endpoint = provend .endpoint ,
266+ )
267+ continue
178268 await try_initialize_provider_endpoints (provend , pimpl , db_writer )
179269
180270
@@ -240,7 +330,7 @@ def __provider_endpoint_from_cfg(
240330 description = ("Endpoint for the {} provided via the CodeGate configuration." ).format (
241331 provider_name
242332 ),
243- provider_type = provider_name ,
333+ provider_type = provider_overrides ( provider_name ) ,
244334 auth_type = apimodelsv1 .ProviderAuthType .passthrough ,
245335 )
246336 except ValidationError as err :
@@ -251,3 +341,24 @@ def __provider_endpoint_from_cfg(
251341 err = str (err ),
252342 )
253343 return None
344+
345+
346+ def provider_default_endpoints (provider_type : str ) -> str :
347+ defaults = {
348+ "openai" : "https://api.openai.com" ,
349+ "anthropic" : "https://api.anthropic.com" ,
350+ }
351+
352+ # If we have a default, we return it
353+ # Otherwise, we return an empty string
354+ return defaults .get (provider_type , "" )
355+
356+
357+ def provider_overrides (provider_type : str ) -> str :
358+ overrides = {
359+ "lm_studio" : "openai" ,
360+ }
361+
362+ # If we have an override, we return it
363+ # Otherwise, we return the type
364+ return overrides .get (provider_type , provider_type )
0 commit comments