Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "e4c05d7591a8"
Expand Down
2 changes: 1 addition & 1 deletion src/codegate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from codegate.config import Config, ConfigurationError
from codegate.db.connection import (
init_db_sync,
init_session_if_not_exists,
init_instance,
init_session_if_not_exists,
)
from codegate.pipeline.factory import PipelineFactory
from codegate.pipeline.sensitive_data.manager import SensitiveDataManager
Expand Down
2 changes: 1 addition & 1 deletion src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ async def init_instance(self) -> None:
await self._execute_with_no_return(sql, instance.model_dump())
except IntegrityError as e:
logger.debug(f"Exception type: {type(e)}")
raise AlreadyExistsError(f"Instance already initialized.")
raise AlreadyExistsError("Instance already initialized.")


class DbReader(DbCodeGate):
Expand Down
47 changes: 37 additions & 10 deletions src/codegate/muxing/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from codegate.muxing.adapter import BodyAdapter, ResponseAdapter
from codegate.providers.fim_analyzer import FIMAnalyzer
from codegate.providers.registry import ProviderRegistry
from codegate.workspaces.crud import WorkspaceCrud
from codegate.workspaces.crud import WorkspaceCrud, WorkspaceDoesNotExistError

logger = structlog.get_logger("codegate")

Expand Down Expand Up @@ -40,23 +40,47 @@ def _ensure_path_starts_with_slash(self, path: str) -> str:
return path if path.startswith("/") else f"/{path}"

async def _get_model_route(
self, thing_to_match: mux_models.ThingToMatchMux
self, thing_to_match: mux_models.ThingToMatchMux, workspace_name: Optional[str] = None
) -> Optional[rulematcher.ModelRoute]:
"""
Get the model route for the given things_to_match.

If workspace_name is provided and exists, use that workspace.
Otherwise, use the active workspace.
"""
mux_registry = await rulematcher.get_muxing_rules_registry()
try:
# Try to get a model route for the active workspace
model_route = await mux_registry.get_match_for_active_workspace(thing_to_match)
return model_route
mux_registry = await rulematcher.get_muxing_rules_registry()
relevant_workspace = await self._get_relevant_workspace_name(
mux_registry, workspace_name
)
return await mux_registry.get_match_for_workspace(relevant_workspace, thing_to_match)
except rulematcher.MuxMatchingError as e:
logger.exception(f"Error matching rule and getting model route: {e}")
raise HTTPException(detail=str(e), status_code=404)
except Exception as e:
logger.exception(f"Error getting active workspace muxes: {e}")
logger.exception(f"Error getting workspace muxes: {e}")
raise HTTPException(detail=str(e), status_code=404)

async def _get_relevant_workspace_name(
self, mreg: rulematcher.MuxingRulesinWorkspaces, workspace_name: Optional[str]
) -> str:
if not workspace_name:
# No workspace specified, use active workspace
return mreg.get_active_workspace()

try:
# Verify the requested workspace exists
# TODO: We should have an in-memory cache of the workspaces
await self._ws_crud.get_workspace_by_name(workspace_name)
logger.debug(f"Using workspace from X-CodeGate-Workspace header: {workspace_name}")
return workspace_name
except WorkspaceDoesNotExistError:
# Workspace doesn't exist, fall back to active workspace
logger.warning(
f"Workspace {workspace_name} does not exist, falling back to active workspace"
)
return mreg.get_active_workspace()

def _setup_routes(self):

@self.router.post(f"/{self.route_name}/{{rest_of_path:path}}")
Expand All @@ -68,7 +92,7 @@ async def route_to_dest_provider(
"""
Route the request to the correct destination provider.

1. Get destination provider from DB and active workspace.
1. Get destination provider from DB and workspace (from header or active).
2. Map the request body to the destination provider format.
3. Run pipeline. Selecting the correct destination provider.
4. Transmit the response back to the client in OpenAI format.
Expand All @@ -78,14 +102,17 @@ async def route_to_dest_provider(
data = json.loads(body)
is_fim_request = FIMAnalyzer.is_fim_request(rest_of_path, data)

# 1. Get destination provider from DB and active workspace.
# Check if X-CodeGate-Workspace header is present
workspace_header = request.headers.get("X-CodeGate-Workspace")

# 1. Get destination provider from DB and workspace (from header or active).
thing_to_match = mux_models.ThingToMatchMux(
body=data,
url_request_path=rest_of_path,
is_fim_request=is_fim_request,
client_type=request.state.detected_client,
)
model_route = await self._get_model_route(thing_to_match)
model_route = await self._get_model_route(thing_to_match, workspace_header)
if not model_route:
raise HTTPException(
detail="No matching rule found for the active workspace", status_code=404
Expand Down
16 changes: 10 additions & 6 deletions src/codegate/muxing/rulematcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class MuxMatchingError(Exception):
pass


async def get_muxing_rules_registry():
async def get_muxing_rules_registry() -> "MuxingRulesinWorkspaces":
"""Returns a singleton instance of the muxing rules registry."""

global _muxrules_sgtn
Expand Down Expand Up @@ -199,23 +199,27 @@ async def set_active_workspace(self, workspace_name: str) -> None:
"""Set the active workspace."""
self._active_workspace = workspace_name

def get_active_workspace(self) -> str:
"""Get the active workspace."""
return self._active_workspace

async def get_registries(self) -> List[str]:
"""Get the list of workspaces."""
async with self._lock:
return list(self._ws_rules.keys())

async def get_match_for_active_workspace(
self, thing_to_match: mux_models.ThingToMatchMux
async def get_match_for_workspace(
self, workspace_name: str, thing_to_match: mux_models.ThingToMatchMux
) -> Optional[ModelRoute]:
"""Get the first match for the given thing_to_match."""
"""Get the first match for the given thing_to_match in the specified workspace."""

# We iterate over all the rules and return the first match
# Since we already do a deepcopy in __getitem__, we don't need to lock here
try:
rules = await self.get_ws_rules(self._active_workspace)
rules = await self.get_ws_rules(workspace_name)
for rule in rules:
if rule.match(thing_to_match):
return rule.destination()
return None
except KeyError:
raise RuntimeError("No rules found for the active workspace")
raise RuntimeError(f"No rules found for workspace {workspace_name}")
Loading