11from typing import List , Optional
2+ from uuid import UUID
23
34import requests
45import structlog
5- from fastapi import APIRouter , HTTPException , Response
6+ from fastapi import APIRouter , Depends , HTTPException , Response
67from fastapi .responses import StreamingResponse
78from fastapi .routing import APIRoute
8- from pydantic import ValidationError
9+ from pydantic import BaseModel , ValidationError
910
1011from codegate import __version__
1112from codegate .api import v1_models , v1_processing
1213from codegate .db .connection import AlreadyExistsError , DbReader
14+ from codegate .providers import crud as provendcrud
1315from codegate .workspaces import crud
1416
1517logger = structlog .get_logger ("codegate" )
1618
1719v1 = APIRouter ()
1820wscrud = crud .WorkspaceCrud ()
21+ pcrud = provendcrud .ProviderCrud ()
1922
2023# This is a singleton object
2124dbreader = DbReader ()
@@ -25,38 +28,78 @@ def uniq_name(route: APIRoute):
2528 return f"v1_{ route .name } "
2629
2730
31+ class FilterByNameParams (BaseModel ):
32+ name : Optional [str ] = None
33+
34+
2835@v1 .get ("/provider-endpoints" , tags = ["Providers" ], generate_unique_id_function = uniq_name )
29- async def list_provider_endpoints (name : Optional [str ] = None ) -> List [v1_models .ProviderEndpoint ]:
36+ async def list_provider_endpoints (
37+ filter_query : FilterByNameParams = Depends (),
38+ ) -> List [v1_models .ProviderEndpoint ]:
3039 """List all provider endpoints."""
31- # NOTE: This is a dummy implementation. In the future, we should have a proper
32- # implementation that fetches the provider endpoints from the database.
33- return [
34- v1_models .ProviderEndpoint (
35- id = 1 ,
36- name = "dummy" ,
37- description = "Dummy provider endpoint" ,
38- endpoint = "http://example.com" ,
39- provider_type = v1_models .ProviderType .openai ,
40- auth_type = v1_models .ProviderAuthType .none ,
41- )
42- ]
40+ if filter_query .name is None :
41+ try :
42+ return await pcrud .list_endpoints ()
43+ except Exception :
44+ raise HTTPException (status_code = 500 , detail = "Internal server error" )
45+
46+ try :
47+ provend = await pcrud .get_endpoint_by_name (filter_query .name )
48+ except Exception :
49+ raise HTTPException (status_code = 500 , detail = "Internal server error" )
50+
51+ if provend is None :
52+ raise HTTPException (status_code = 404 , detail = "Provider endpoint not found" )
53+ return [provend ]
54+
55+
56+ # This needs to be above /provider-endpoints/{provider_id} to avoid conflict
57+ @v1 .get (
58+ "/provider-endpoints/models" ,
59+ tags = ["Providers" ],
60+ generate_unique_id_function = uniq_name ,
61+ )
62+ async def list_all_models_for_all_providers () -> List [v1_models .ModelByProvider ]:
63+ """List all models for all providers."""
64+ try :
65+ return await pcrud .get_all_models ()
66+ except Exception :
67+ raise HTTPException (status_code = 500 , detail = "Internal server error" )
68+
69+
70+ @v1 .get (
71+ "/provider-endpoints/{provider_id}/models" ,
72+ tags = ["Providers" ],
73+ generate_unique_id_function = uniq_name ,
74+ )
75+ async def list_models_by_provider (
76+ provider_id : UUID ,
77+ ) -> List [v1_models .ModelByProvider ]:
78+ """List models by provider."""
79+
80+ try :
81+ return await pcrud .models_by_provider (provider_id )
82+ except provendcrud .ProviderNotFoundError :
83+ raise HTTPException (status_code = 404 , detail = "Provider not found" )
84+ except Exception as e :
85+ raise HTTPException (status_code = 500 , detail = str (e ))
4386
4487
4588@v1 .get (
4689 "/provider-endpoints/{provider_id}" , tags = ["Providers" ], generate_unique_id_function = uniq_name
4790)
48- async def get_provider_endpoint (provider_id : int ) -> v1_models .ProviderEndpoint :
91+ async def get_provider_endpoint (
92+ provider_id : UUID ,
93+ ) -> v1_models .ProviderEndpoint :
4994 """Get a provider endpoint by ID."""
50- # NOTE: This is a dummy implementation. In the future, we should have a proper
51- # implementation that fetches the provider endpoint from the database.
52- return v1_models .ProviderEndpoint (
53- id = provider_id ,
54- name = "dummy" ,
55- description = "Dummy provider endpoint" ,
56- endpoint = "http://example.com" ,
57- provider_type = v1_models .ProviderType .openai ,
58- auth_type = v1_models .ProviderAuthType .none ,
59- )
95+ try :
96+ provend = await pcrud .get_endpoint_by_id (provider_id )
97+ except Exception :
98+ raise HTTPException (status_code = 500 , detail = "Internal server error" )
99+
100+ if provend is None :
101+ raise HTTPException (status_code = 404 , detail = "Provider endpoint not found" )
102+ return provend
60103
61104
62105@v1 .post (
@@ -65,59 +108,65 @@ async def get_provider_endpoint(provider_id: int) -> v1_models.ProviderEndpoint:
65108 generate_unique_id_function = uniq_name ,
66109 status_code = 201 ,
67110)
68- async def add_provider_endpoint (request : v1_models .ProviderEndpoint ) -> v1_models .ProviderEndpoint :
111+ async def add_provider_endpoint (
112+ request : v1_models .ProviderEndpoint ,
113+ ) -> v1_models .ProviderEndpoint :
69114 """Add a provider endpoint."""
70- # NOTE: This is a dummy implementation. In the future, we should have a proper
71- # implementation that adds the provider endpoint to the database.
72- return request
115+ try :
116+ provend = await pcrud .add_endpoint (request )
117+ except AlreadyExistsError :
118+ raise HTTPException (status_code = 409 , detail = "Provider endpoint already exists" )
119+ except ValidationError as e :
120+ # TODO: This should be more specific
121+ raise HTTPException (
122+ status_code = 400 ,
123+ detail = str (e ),
124+ )
125+ except Exception :
126+ raise HTTPException (status_code = 500 , detail = "Internal server error" )
127+
128+ return provend
73129
74130
75131@v1 .put (
76132 "/provider-endpoints/{provider_id}" , tags = ["Providers" ], generate_unique_id_function = uniq_name
77133)
78134async def update_provider_endpoint (
79- provider_id : int , request : v1_models .ProviderEndpoint
135+ provider_id : UUID ,
136+ request : v1_models .ProviderEndpoint ,
80137) -> v1_models .ProviderEndpoint :
81138 """Update a provider endpoint by ID."""
82- # NOTE: This is a dummy implementation. In the future, we should have a proper
83- # implementation that updates the provider endpoint in the database.
84- return request
139+ try :
140+ request .id = provider_id
141+ provend = await pcrud .update_endpoint (request )
142+ except ValidationError as e :
143+ # TODO: This should be more specific
144+ raise HTTPException (
145+ status_code = 400 ,
146+ detail = str (e ),
147+ )
148+ except Exception :
149+ raise HTTPException (status_code = 500 , detail = "Internal server error" )
150+
151+ return provend
85152
86153
87154@v1 .delete (
88155 "/provider-endpoints/{provider_id}" , tags = ["Providers" ], generate_unique_id_function = uniq_name
89156)
90- async def delete_provider_endpoint (provider_id : int ):
157+ async def delete_provider_endpoint (
158+ provider_id : UUID ,
159+ ):
91160 """Delete a provider endpoint by id."""
92- # NOTE: This is a dummy implementation. In the future, we should have a proper
93- # implementation that deletes the provider endpoint from the database.
161+ try :
162+ await pcrud .delete_endpoint (provider_id )
163+ except provendcrud .ProviderNotFoundError :
164+ raise HTTPException (status_code = 404 , detail = "Provider endpoint not found" )
165+ except Exception :
166+ raise HTTPException (status_code = 500 , detail = "Internal server error" )
94167 return Response (status_code = 204 )
95168
96169
97- @v1 .get (
98- "/provider-endpoints/{provider_name}/models" ,
99- tags = ["Providers" ],
100- generate_unique_id_function = uniq_name ,
101- )
102- async def list_models_by_provider (provider_name : str ) -> List [v1_models .ModelByProvider ]:
103- """List models by provider."""
104- # NOTE: This is a dummy implementation. In the future, we should have a proper
105- # implementation that fetches the models by provider from the database.
106- return [v1_models .ModelByProvider (name = "dummy" , provider = "dummy" )]
107-
108-
109- @v1 .get (
110- "/provider-endpoints/models" ,
111- tags = ["Providers" ],
112- generate_unique_id_function = uniq_name ,
113- )
114- async def list_all_models_for_all_providers () -> List [v1_models .ModelByProvider ]:
115- """List all models for all providers."""
116- # NOTE: This is a dummy implementation. In the future, we should have a proper
117- # implementation that fetches all the models for all providers from the database.
118- return [v1_models .ModelByProvider (name = "dummy" , provider = "dummy" )]
119-
120-
121170@v1 .get ("/workspaces" , tags = ["Workspaces" ], generate_unique_id_function = uniq_name )
122171async def list_workspaces () -> v1_models .ListWorkspacesResponse :
123172 """List all workspaces."""
@@ -394,7 +443,9 @@ async def delete_workspace_custom_instructions(workspace_name: str):
394443 tags = ["Workspaces" , "Muxes" ],
395444 generate_unique_id_function = uniq_name ,
396445)
397- async def get_workspace_muxes (workspace_name : str ) -> List [v1_models .MuxRule ]:
446+ async def get_workspace_muxes (
447+ workspace_name : str ,
448+ ) -> List [v1_models .MuxRule ]:
398449 """Get the mux rules of a workspace.
399450
400451 The list is ordered in order of priority. That is, the first rule in the list
@@ -422,7 +473,10 @@ async def get_workspace_muxes(workspace_name: str) -> List[v1_models.MuxRule]:
422473 generate_unique_id_function = uniq_name ,
423474 status_code = 204 ,
424475)
425- async def set_workspace_muxes (workspace_name : str , request : List [v1_models .MuxRule ]):
476+ async def set_workspace_muxes (
477+ workspace_name : str ,
478+ request : List [v1_models .MuxRule ],
479+ ):
426480 """Set the mux rules of a workspace."""
427481 # TODO: This is a dummy implementation. In the future, we should have a proper
428482 # implementation that sets the mux rules in the database.
0 commit comments