Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 3f0c914

Browse files
committed
endpoint to get full workspace config + free
1 parent da69ec0 commit 3f0c914

File tree

4 files changed

+250
-17
lines changed

4 files changed

+250
-17
lines changed

src/codegate/api/v1.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import requests
55
import structlog
6-
from fastapi import APIRouter, Depends, HTTPException, Response
6+
from fastapi import APIRouter, Depends, HTTPException, Query, Response
77
from fastapi.responses import StreamingResponse
88
from fastapi.routing import APIRoute
99
from pydantic import BaseModel, ValidationError
@@ -12,7 +12,7 @@
1212
from codegate import __version__
1313
from codegate.api import v1_models, v1_processing
1414
from codegate.db.connection import AlreadyExistsError, DbReader
15-
from codegate.db.models import AlertSeverity, WorkspaceWithModel
15+
from codegate.db.models import AlertSeverity
1616
from codegate.providers import crud as provendcrud
1717
from codegate.workspaces import crud
1818

@@ -209,13 +209,32 @@ async def delete_provider_endpoint(
209209

210210

211211
@v1.get("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name)
212-
async def list_workspaces() -> v1_models.ListWorkspacesResponse:
213-
"""List all workspaces."""
214-
wslist = await wscrud.get_workspaces()
212+
async def list_workspaces(
213+
provider_id: Optional[UUID] = Query(None),
214+
) -> v1_models.ListWorkspacesResponse:
215+
"""
216+
List all workspaces.
215217
216-
resp = v1_models.ListWorkspacesResponse.from_db_workspaces_with_sessioninfo(wslist)
218+
Args:
219+
provider_id (Optional[UUID]): Filter workspaces by provider ID. If provided,
220+
will return workspaces where models from the specified provider (e.g., OpenAI,
221+
Anthropic) have been used in workspace muxing rules. Note that you must
222+
refer to a provider by ID, not by name.
217223
218-
return resp
224+
Returns:
225+
ListWorkspacesResponse: A response object containing the list of workspaces.
226+
"""
227+
try:
228+
if provider_id:
229+
wslist = await wscrud.workspaces_by_provider(provider_id)
230+
resp = v1_models.ListWorkspacesResponse.from_db_workspaces(wslist)
231+
return resp
232+
else:
233+
wslist = await wscrud.get_workspaces()
234+
resp = v1_models.ListWorkspacesResponse.from_db_workspaces_with_sessioninfo(wslist)
235+
return resp
236+
except Exception as e:
237+
raise HTTPException(status_code=500, detail=str(e))
219238

220239

221240
@v1.get("/workspaces/active", tags=["Workspaces"], generate_unique_id_function=uniq_name)
@@ -584,17 +603,28 @@ async def set_workspace_muxes(
584603

585604

586605
@v1.get(
587-
"/workspaces/{provider_id}",
606+
"/workspaces/{workspace_name}",
588607
tags=["Workspaces"],
589608
generate_unique_id_function=uniq_name,
590609
)
591-
async def list_workspaces_by_provider(
592-
provider_id: UUID,
593-
) -> List[WorkspaceWithModel]:
610+
async def get_workspace_by_name(
611+
workspace_name: str,
612+
) -> v1_models.FullWorkspace:
594613
"""List workspaces by provider ID."""
595614
try:
596-
return await wscrud.workspaces_by_provider(provider_id)
615+
ws = await wscrud.get_workspace_by_name(workspace_name)
616+
muxes = await wscrud.get_muxes(workspace_name)
617+
618+
return v1_models.FullWorkspace(
619+
name=ws.name,
620+
config=v1_models.WorkspaceConfig(
621+
custom_instructions=ws.custom_instructions or "",
622+
muxing_rules=muxes,
623+
),
624+
)
597625

626+
except crud.WorkspaceDoesNotExistError:
627+
raise HTTPException(status_code=404, detail="Workspace does not exist")
598628
except Exception as e:
599629
raise HTTPException(status_code=500, detail=str(e))
600630

src/codegate/db/connection.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
ProviderModel,
3636
Session,
3737
WorkspaceRow,
38-
WorkspaceWithModel,
3938
WorkspaceWithSessionInfo,
4039
)
4140
from codegate.db.token_usage import TokenUsageParser
@@ -820,11 +819,13 @@ async def get_workspace_by_name(self, name: str) -> Optional[WorkspaceRow]:
820819
)
821820
return workspaces[0] if workspaces else None
822821

823-
async def get_workspaces_by_provider(self, provider_id: str) -> List[WorkspaceWithModel]:
822+
async def get_workspaces_by_provider(self, provider_id: str) -> List[WorkspaceRow]:
824823
sql = text(
825824
"""
826825
SELECT
827-
w.id, w.name, m.provider_model_name
826+
w.id,
827+
w.name,
828+
w.custom_instructions
828829
FROM workspaces w
829830
JOIN muxes m ON w.id = m.workspace_id
830831
WHERE m.provider_endpoint_id = :provider_id
@@ -833,7 +834,7 @@ async def get_workspaces_by_provider(self, provider_id: str) -> List[WorkspaceWi
833834
)
834835
conditions = {"provider_id": provider_id}
835836
workspaces = await self._exec_select_conditions_to_pydantic(
836-
WorkspaceWithModel, sql, conditions, should_raise=True
837+
WorkspaceRow, sql, conditions, should_raise=True
837838
)
838839
return workspaces
839840

src/codegate/workspaces/crud.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,9 @@ async def get_workspace_by_name(self, workspace_name: str) -> db_models.Workspac
281281
raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")
282282
return workspace
283283

284-
async def workspaces_by_provider(self, provider_id: uuid) -> List[db_models.WorkspaceWithModel]:
284+
async def workspaces_by_provider(
285+
self, provider_id: uuid
286+
) -> List[db_models.WorkspaceWithSessionInfo]:
285287
"""Get the workspaces by provider."""
286288

287289
workspaces = await self._db_reader.get_workspaces_by_provider(str(provider_id))

tests/api/test_v1_workspaces.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,184 @@ def mock_pipeline_factory():
7070
return mock_factory
7171

7272

73+
@pytest.mark.asyncio
74+
async def test_get_workspaces(
75+
mock_pipeline_factory, mock_workspace_crud, mock_provider_crud
76+
) -> None:
77+
with (
78+
patch("codegate.api.v1.wscrud", mock_workspace_crud),
79+
patch("codegate.api.v1.pcrud", mock_provider_crud),
80+
patch(
81+
"codegate.providers.openai.provider.OpenAIProvider.models",
82+
return_value=["foo-bar-001", "foo-bar-002"],
83+
),
84+
):
85+
"""Test getting all workspaces."""
86+
app = init_app(mock_pipeline_factory)
87+
88+
async with AsyncClient(
89+
transport=httpx.ASGITransport(app=app), base_url="http://test"
90+
) as ac:
91+
# Create a provider for muxing rules
92+
provider_payload = {
93+
"name": "test-provider",
94+
"description": "",
95+
"auth_type": "none",
96+
"provider_type": "openai",
97+
"endpoint": "https://api.openai.com",
98+
"api_key": "sk-proj-foo-bar-123-xzy",
99+
}
100+
response = await ac.post("/api/v1/provider-endpoints", json=provider_payload)
101+
assert response.status_code == 201
102+
provider = response.json()
103+
104+
# Create first workspace
105+
name_1 = str(uuid())
106+
workspace_1 = {
107+
"name": name_1,
108+
"config": {
109+
"custom_instructions": "Respond in haiku format",
110+
"muxing_rules": [
111+
{
112+
"provider_id": provider["id"],
113+
"model": "foo-bar-001",
114+
"matcher": "*.py",
115+
"matcher_type": "filename_match",
116+
}
117+
],
118+
},
119+
}
120+
response = await ac.post("/api/v1/workspaces", json=workspace_1)
121+
assert response.status_code == 201
122+
123+
# Create second workspace
124+
name_2 = str(uuid())
125+
workspace_2 = {
126+
"name": name_2,
127+
"config": {
128+
"custom_instructions": "Respond in prose",
129+
"muxing_rules": [
130+
{
131+
"provider_id": provider["id"],
132+
"model": "foo-bar-002",
133+
"matcher": "*.js",
134+
"matcher_type": "filename_match",
135+
}
136+
],
137+
},
138+
}
139+
response = await ac.post("/api/v1/workspaces", json=workspace_2)
140+
assert response.status_code == 201
141+
142+
response = await ac.get("/api/v1/workspaces")
143+
assert response.status_code == 200
144+
workspaces = response.json()["workspaces"]
145+
146+
# Verify response structure
147+
assert isinstance(workspaces, list)
148+
assert len(workspaces) >= 2
149+
150+
workspace_names = [w["name"] for w in workspaces]
151+
assert name_1 in workspace_names
152+
assert name_2 in workspace_names
153+
assert len([n for n in workspace_names if n in [name_1, name_2]]) == 2
154+
155+
156+
@pytest.mark.asyncio
157+
async def test_get_workspaces_filter_by_provider(
158+
mock_pipeline_factory, mock_workspace_crud, mock_provider_crud
159+
) -> None:
160+
with (
161+
patch("codegate.api.v1.wscrud", mock_workspace_crud),
162+
patch("codegate.api.v1.pcrud", mock_provider_crud),
163+
patch(
164+
"codegate.providers.openai.provider.OpenAIProvider.models",
165+
return_value=["foo-bar-001", "foo-bar-002"],
166+
),
167+
):
168+
"""Test filtering workspaces by provider ID."""
169+
app = init_app(mock_pipeline_factory)
170+
171+
async with AsyncClient(
172+
transport=httpx.ASGITransport(app=app), base_url="http://test"
173+
) as ac:
174+
# Create first provider
175+
provider_payload_1 = {
176+
"name": "provider-1",
177+
"description": "",
178+
"auth_type": "none",
179+
"provider_type": "openai",
180+
"endpoint": "https://api.openai.com",
181+
"api_key": "sk-proj-foo-bar-123-xyz",
182+
}
183+
response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_1)
184+
assert response.status_code == 201
185+
provider_1 = response.json()
186+
187+
# Create second provider
188+
provider_payload_2 = {
189+
"name": "provider-2",
190+
"description": "",
191+
"auth_type": "none",
192+
"provider_type": "openai",
193+
"endpoint": "https://api.openai.com",
194+
"api_key": "sk-proj-foo-bar-456-xyz",
195+
}
196+
response = await ac.post("/api/v1/provider-endpoints", json=provider_payload_2)
197+
assert response.status_code == 201
198+
provider_2 = response.json()
199+
200+
# Create workspace using provider 1
201+
workspace_1 = {
202+
"name": str(uuid()),
203+
"config": {
204+
"custom_instructions": "Instructions 1",
205+
"muxing_rules": [
206+
{
207+
"provider_id": provider_1["id"],
208+
"model": "foo-bar-001",
209+
"matcher": "*.py",
210+
"matcher_type": "filename_match",
211+
}
212+
],
213+
},
214+
}
215+
response = await ac.post("/api/v1/workspaces", json=workspace_1)
216+
assert response.status_code == 201
217+
218+
# Create workspace using provider 2
219+
workspace_2 = {
220+
"name": str(uuid()),
221+
"config": {
222+
"custom_instructions": "Instructions 2",
223+
"muxing_rules": [
224+
{
225+
"provider_id": provider_2["id"],
226+
"model": "foo-bar-002",
227+
"matcher": "*.js",
228+
"matcher_type": "filename_match",
229+
}
230+
],
231+
},
232+
}
233+
response = await ac.post("/api/v1/workspaces", json=workspace_2)
234+
assert response.status_code == 201
235+
236+
# Test filtering by provider 1
237+
response = await ac.get(f"/api/v1/workspaces?provider_id={provider_1['id']}")
238+
assert response.status_code == 200
239+
workspaces = response.json()["workspaces"]
240+
assert len(workspaces) == 1
241+
assert workspaces[0]["name"] == workspace_1["name"]
242+
243+
# Test filtering by provider 2
244+
response = await ac.get(f"/api/v1/workspaces?provider_id={provider_2['id']}")
245+
assert response.status_code == 200
246+
workspaces = response.json()["workspaces"]
247+
assert len(workspaces) == 1
248+
assert workspaces[0]["name"] == workspace_2["name"]
249+
250+
73251
@pytest.mark.asyncio
74252
async def test_create_update_workspace_happy_path(
75253
mock_pipeline_factory, mock_workspace_crud, mock_provider_crud
@@ -146,6 +324,10 @@ async def test_create_update_workspace_happy_path(
146324

147325
response = await ac.post("/api/v1/workspaces", json=payload_create)
148326
assert response.status_code == 201
327+
328+
# Verify created workspace
329+
response = await ac.get(f"/api/v1/workspaces/{name_1}")
330+
assert response.status_code == 200
149331
response_body = response.json()
150332

151333
assert response_body["name"] == name_1
@@ -184,6 +366,10 @@ async def test_create_update_workspace_happy_path(
184366

185367
response = await ac.put(f"/api/v1/workspaces/{name_1}", json=payload_update)
186368
assert response.status_code == 201
369+
370+
# Verify updated workspace
371+
response = await ac.get(f"/api/v1/workspaces/{name_2}")
372+
assert response.status_code == 200
187373
response_body = response.json()
188374

189375
assert response_body["name"] == name_2
@@ -222,8 +408,15 @@ async def test_create_update_workspace_name_only(
222408
response = await ac.post("/api/v1/workspaces", json=payload_create)
223409
assert response.status_code == 201
224410
response_body = response.json()
411+
assert response_body["name"] == name_1
225412

413+
# Verify created workspace
414+
response = await ac.get(f"/api/v1/workspaces/{name_1}")
415+
assert response.status_code == 200
416+
response_body = response.json()
226417
assert response_body["name"] == name_1
418+
assert response_body["config"]["custom_instructions"] == ""
419+
assert response_body["config"]["muxing_rules"] == []
227420

228421
name_2: str = str(uuid())
229422

@@ -234,8 +427,15 @@ async def test_create_update_workspace_name_only(
234427
response = await ac.put(f"/api/v1/workspaces/{name_1}", json=payload_update)
235428
assert response.status_code == 201
236429
response_body = response.json()
430+
assert response_body["name"] == name_2
237431

432+
# Verify updated workspace
433+
response = await ac.get(f"/api/v1/workspaces/{name_2}")
434+
assert response.status_code == 200
435+
response_body = response.json()
238436
assert response_body["name"] == name_2
437+
assert response_body["config"]["custom_instructions"] == ""
438+
assert response_body["config"]["muxing_rules"] == []
239439

240440

241441
@pytest.mark.asyncio

0 commit comments

Comments
 (0)