Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/codegate/muxing/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import pydantic

from codegate.clients.clients import ClientType


class MuxMatcherType(str, Enum):
"""
Expand All @@ -11,6 +13,12 @@ class MuxMatcherType(str, Enum):

# Always match this prompt
catch_all = "catch_all"
# Match based on the filename. It will match if there is a filename
# in the request that matches the matcher either extension or full name (*.py or main.py)
filename_match = "filename_match"
# Match based on the request type. It will match if the request type
# matches the matcher (e.g. FIM or chat)
request_type_match = "request_type_match"


class MuxRule(pydantic.BaseModel):
Expand All @@ -25,3 +33,14 @@ class MuxRule(pydantic.BaseModel):
# The actual matcher to use. Note that
# this depends on the matcher type.
matcher: Optional[str] = None


class ThingToMatchMux(pydantic.BaseModel):
"""
Represents the fields we can use to match a mux rule.
"""

body: dict
url_request_path: str
is_fim_request: bool
client_type: ClientType
83 changes: 37 additions & 46 deletions src/codegate/muxing/router.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import json
from typing import Optional

import structlog
from fastapi import APIRouter, HTTPException, Request

from codegate.clients.clients import ClientType
from codegate.clients.detector import DetectClient
from codegate.extract_snippets.body_extractor import BodyCodeSnippetExtractorError
from codegate.extract_snippets.factory import BodyCodeExtractorFactory
from codegate.muxing import models as mux_models
from codegate.muxing import rulematcher
from codegate.muxing.adapter import BodyAdapter, ResponseAdapter
from codegate.providers.fim_analyzer import FIMAnalyzer
from codegate.providers.registry import ProviderRegistry
from codegate.workspaces.crud import WorkspaceCrud

Expand Down Expand Up @@ -39,40 +39,20 @@ def get_routes(self) -> APIRouter:
def _ensure_path_starts_with_slash(self, path: str) -> str:
return path if path.startswith("/") else f"/{path}"

def _extract_request_filenames(self, detected_client: ClientType, data: dict) -> set[str]:
async def _get_model_route(
self, thing_to_match: mux_models.ThingToMatchMux
) -> Optional[rulematcher.ModelRoute]:
"""
Extract filenames from the request data.
Get the model route for the given things_to_match.
"""
try:
body_extractor = BodyCodeExtractorFactory.create_snippet_extractor(detected_client)
return body_extractor.extract_unique_filenames(data)
except BodyCodeSnippetExtractorError as e:
logger.error(f"Error extracting filenames from request: {e}")
return set()

async def _get_model_routes(self, filenames: set[str]) -> list[rulematcher.ModelRoute]:
"""
Get the model routes for the given filenames.
"""
model_routes = []
mux_registry = await rulematcher.get_muxing_rules_registry()
try:
# Try to get a catch_all route
single_model_route = await mux_registry.get_match_for_active_workspace(
thing_to_match=None
)
model_routes.append(single_model_route)

# Get the model routes for each filename
for filename in filenames:
model_route = await mux_registry.get_match_for_active_workspace(
thing_to_match=filename
)
model_routes.append(model_route)
# Try to get a model route for the active workspace
model_route = await mux_registry.get_match_for_active_workspace(thing_to_match)
return model_route
except Exception as e:
logger.error(f"Error getting active workspace muxes: {e}")
raise HTTPException(str(e), status_code=404)
return model_routes

def _setup_routes(self):

Expand All @@ -88,34 +68,45 @@ async def route_to_dest_provider(
1. Get destination provider from DB and active workspace.
2. Map the request body to the destination provider format.
3. Run pipeline. Selecting the correct destination provider.
4. Transmit the response back to the client in the correct format.
4. Transmit the response back to the client in OpenAI format.
"""

body = await request.body()
data = json.loads(body)
is_fim_request = FIMAnalyzer.is_fim_request(rest_of_path, data)

# 1. Get destination provider from DB and active workspace.
thing_to_match = mux_models.ThingToMatchMux(
body=data,
url_request_path=rest_of_path,
is_fim_request=is_fim_request,
client_type=request.state.detected_client,
)
model_route = await self._get_model_route(thing_to_match)
if not model_route:
raise HTTPException(
"No matching rule found for the active workspace", status_code=404
)

filenames_in_data = self._extract_request_filenames(request.state.detected_client, data)
logger.info(f"Extracted filenames from request: {filenames_in_data}")

model_routes = await self._get_model_routes(filenames_in_data)
if not model_routes:
raise HTTPException("No rule found for the active workspace", status_code=404)

# We still need some logic here to handle the case where we have multiple model routes.
# For the moment since we match all only pick the first.
model_route = model_routes[0]
logger.info(
"Muxing request routed to destination provider",
model=model_route.model.name,
provider_type=model_route.endpoint.provider_type,
provider_name=model_route.endpoint.name,
)

# Parse the input data and map it to the destination provider format
# 2. Map the request body to the destination provider format.
rest_of_path = self._ensure_path_starts_with_slash(rest_of_path)
new_data = self._body_adapter.map_body_to_dest(model_route, data)

# 3. Run pipeline. Selecting the correct destination provider.
provider = self._provider_registry.get_provider(model_route.endpoint.provider_type)
api_key = model_route.auth_material.auth_blob

# Send the request to the destination provider. It will run the pipeline
response = await provider.process_request(
new_data, api_key, rest_of_path, request.state.detected_client
new_data, api_key, is_fim_request, request.state.detected_client
)
# Format the response to the client always using the OpenAI format

# 4. Transmit the response back to the client in OpenAI format.
return self._response_adapter.format_response_to_client(
response, model_route.endpoint.provider_type
)
86 changes: 78 additions & 8 deletions src/codegate/muxing/rulematcher.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
import copy
from abc import ABC, abstractmethod
from asyncio import Lock
from typing import List, Optional
from typing import Dict, List, Optional

import structlog

from codegate.clients.clients import ClientType
from codegate.db import models as db_models
from codegate.extract_snippets.body_extractor import BodyCodeSnippetExtractorError
from codegate.extract_snippets.factory import BodyCodeExtractorFactory
from codegate.muxing import models as mux_models

logger = structlog.get_logger("codegate")

_muxrules_sgtn = None

Expand Down Expand Up @@ -40,11 +48,12 @@ def __init__(
class MuxingRuleMatcher(ABC):
"""Base class for matching muxing rules."""

def __init__(self, route: ModelRoute):
def __init__(self, route: ModelRoute, matcher_blob: str):
self._route = route
self._matcher_blob = matcher_blob

@abstractmethod
def match(self, thing_to_match) -> bool:
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
"""Return True if the rule matches the thing_to_match."""
pass

Expand All @@ -61,23 +70,82 @@ class MuxingMatcherFactory:
def create(mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatcher:
"""Create a muxing matcher for the given endpoint and model."""

factory = {
"catch_all": CatchAllMuxingRuleMatcher,
factory: Dict[mux_models.MuxMatcherType, MuxingRuleMatcher] = {
mux_models.MuxMatcherType.catch_all: CatchAllMuxingRuleMatcher,
mux_models.MuxMatcherType.filename_match: FileMuxingRuleMatcher,
mux_models.MuxMatcherType.request_type_match: RequestTypeMuxingRuleMatcher,
}

try:
return factory[mux_rule.matcher_type](route)
# Initialize the MuxingRuleMatcher
return factory[mux_rule.matcher_type](route, mux_rule.matcher_blob)
except KeyError:
raise ValueError(f"Unknown matcher type: {mux_rule.matcher_type}")


class CatchAllMuxingRuleMatcher(MuxingRuleMatcher):
"""A catch all muxing rule matcher."""

def match(self, thing_to_match) -> bool:
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
logger.info("Catch all rule matched")
return True


class FileMuxingRuleMatcher(MuxingRuleMatcher):
"""A file muxing rule matcher."""

def _extract_request_filenames(self, detected_client: ClientType, data: dict) -> set[str]:
"""
Extract filenames from the request data.
"""
try:
body_extractor = BodyCodeExtractorFactory.create_snippet_extractor(detected_client)
return body_extractor.extract_unique_filenames(data)
except BodyCodeSnippetExtractorError as e:
logger.error(f"Error extracting filenames from request: {e}")
return set()

def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
"""
Retun True if there is a filename in the request that matches the matcher_blob.
The matcher_blob is either an extension (e.g. .py) or a filename (e.g. main.py).
"""
# If there is no matcher_blob, we don't match
if not self._matcher_blob:
return False
filenames_to_match = self._extract_request_filenames(
thing_to_match.client_type, thing_to_match.body
)
is_filename_match = any(self._matcher_blob in filename for filename in filenames_to_match)
if is_filename_match:
logger.info(
"Filename rule matched", filenames=filenames_to_match, matcher=self._matcher_blob
)
return is_filename_match


class RequestTypeMuxingRuleMatcher(MuxingRuleMatcher):
"""A catch all muxing rule matcher."""

def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
"""
Return True if the request type matches the matcher_blob.
The matcher_blob is either "fim" or "chat".
"""
# If there is no matcher_blob, we don't match
if not self._matcher_blob:
return False
incoming_request_type = "fim" if thing_to_match.is_fim_request else "chat"
is_request_type_match = self._matcher_blob == incoming_request_type
if is_request_type_match:
logger.info(
"Request type rule matched",
matcher=self._matcher_blob,
request_type=incoming_request_type,
)
return is_request_type_match


class MuxingRulesinWorkspaces:
"""A thread safe dictionary to store the muxing rules in workspaces."""

Expand Down Expand Up @@ -111,7 +179,9 @@ async def get_registries(self) -> List[str]:
async with self._lock:
return list(self._ws_rules.keys())

async def get_match_for_active_workspace(self, thing_to_match) -> Optional[ModelRoute]:
async def get_match_for_active_workspace(
self, thing_to_match: mux_models.ThingToMatchMux
) -> Optional[ModelRoute]:
"""Get the first match for the given thing_to_match."""

# We iterate over all the rules and return the first match
Expand Down
2 changes: 1 addition & 1 deletion src/codegate/pipeline/system_prompt/codegate.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional

from codegate.clients.clients import ClientType
from litellm import ChatCompletionRequest, ChatCompletionSystemMessage

from codegate.clients.clients import ClientType
from codegate.pipeline.base import (
PipelineContext,
PipelineResult,
Expand Down
7 changes: 4 additions & 3 deletions src/codegate/providers/anthropic/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer
from codegate.providers.anthropic.completion_handler import AnthropicCompletion
from codegate.providers.base import BaseProvider, ModelFetchError
from codegate.providers.fim_analyzer import FIMAnalyzer
from codegate.providers.litellmshim import anthropic_stream_generator


Expand Down Expand Up @@ -57,10 +58,9 @@ async def process_request(
self,
data: dict,
api_key: str,
request_url_path: str,
is_fim_request: bool,
client_type: ClientType,
):
is_fim_request = self._is_fim_request(request_url_path, data)
try:
stream = await self.complete(data, api_key, is_fim_request, client_type)
except Exception as e:
Expand Down Expand Up @@ -98,10 +98,11 @@ async def create_message(

body = await request.body()
data = json.loads(body)
is_fim_request = FIMAnalyzer.is_fim_request(request.url.path, data)

return await self.process_request(
data,
x_api_key,
request.url.path,
is_fim_request,
request.state.detected_client,
)
Loading