Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 9dac0af

Browse files
Added a class which performs semantic routing
Related to: #1055 For the current implementation of muxing we only need to match a single Persona at a time. For example: 1. mux1 -> persona Architect -> openai o1 2. mux2 -> catch all -> openai gpt4o In the above case we would only need to know if the request matches the persona `Architect`. It's not needed to match any extra personas even if they exist in DB. This PR introduces what's necessary to do the above without actually wiring in muxing rules. The PR: - Creates the persona table in DB - Adds methods to write and read to the new persona table - Implements a function to check if a query matches to the specified persona To check more about the personas and the queries please check the unit tests
1 parent 809c24a commit 9dac0af

File tree

6 files changed

+910
-2
lines changed

6 files changed

+910
-2
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""add persona table
2+
3+
Revision ID: 02b710eda156
4+
Revises: 5e5cd2288147
5+
Create Date: 2025-03-03 10:08:16.206617+00:00
6+
7+
"""
8+
9+
from typing import Sequence, Union
10+
11+
from alembic import op
12+
13+
# revision identifiers, used by Alembic.
14+
revision: str = "02b710eda156"
15+
down_revision: Union[str, None] = "5e5cd2288147"
16+
branch_labels: Union[str, Sequence[str], None] = None
17+
depends_on: Union[str, Sequence[str], None] = None
18+
19+
20+
def upgrade() -> None:
21+
# Begin transaction
22+
op.execute("BEGIN TRANSACTION;")
23+
24+
op.execute(
25+
"""
26+
CREATE TABLE IF NOT EXISTS personas (
27+
id TEXT PRIMARY KEY, -- UUID stored as TEXT
28+
name TEXT NOT NULL UNIQUE,
29+
description TEXT NOT NULL,
30+
description_embedding BLOB NOT NULL
31+
);
32+
"""
33+
)
34+
35+
# Finish transaction
36+
op.execute("COMMIT;")
37+
38+
39+
def downgrade() -> None:
40+
# Begin transaction
41+
op.execute("BEGIN TRANSACTION;")
42+
43+
op.execute(
44+
"""
45+
DROP TABLE personas;
46+
"""
47+
)
48+
49+
# Finish transaction
50+
op.execute("COMMIT;")

src/codegate/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class Config:
5757
force_certs: bool = False
5858

5959
max_fim_hash_lifetime: int = 60 * 5 # Time in seconds. Default is 5 minutes.
60+
persona_threshold = 0.75 # Min value is 0 (max similarity), max value is 2 (orthogonal)
6061

6162
# Provider URLs with defaults
6263
provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())

src/codegate/db/connection.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import asyncio
22
import json
3+
import sqlite3
34
import uuid
45
from pathlib import Path
56
from typing import Dict, List, Optional, Type
67

8+
import numpy as np
9+
import sqlite_vec_sl_tmp
710
import structlog
811
from alembic import command as alembic_command
912
from alembic.config import Config as AlembicConfig
@@ -22,6 +25,9 @@
2225
IntermediatePromptWithOutputUsageAlerts,
2326
MuxRule,
2427
Output,
28+
Persona,
29+
PersonaDistance,
30+
PersonaEmbedding,
2531
Prompt,
2632
ProviderAuthMaterial,
2733
ProviderEndpoint,
@@ -65,7 +71,7 @@ def __new__(cls, *args, **kwargs):
6571
# It should only be used for testing
6672
if "_no_singleton" in kwargs and kwargs["_no_singleton"]:
6773
kwargs.pop("_no_singleton")
68-
return super().__new__(cls, *args, **kwargs)
74+
return super().__new__(cls)
6975

7076
if cls._instance is None:
7177
cls._instance = super().__new__(cls)
@@ -92,6 +98,22 @@ def __init__(self, sqlite_path: Optional[str] = None, **kwargs):
9298
}
9399
self._async_db_engine = create_async_engine(**engine_dict)
94100

101+
def _get_vec_db_connection(self):
102+
"""
103+
Vector database connection is a separate connection to the SQLite database. aiosqlite
104+
does not support loading extensions, so we need to use the sqlite3 module to load the
105+
vector extension.
106+
"""
107+
try:
108+
conn = sqlite3.connect(self._db_path)
109+
conn.enable_load_extension(True)
110+
sqlite_vec_sl_tmp.load(conn)
111+
conn.enable_load_extension(False)
112+
return conn
113+
except Exception:
114+
logger.exception("Failed to initialize vector database connection")
115+
raise
116+
95117
def does_db_exist(self):
96118
return self._db_path.is_file()
97119

@@ -523,6 +545,30 @@ async def add_mux(self, mux: MuxRule) -> MuxRule:
523545
added_mux = await self._execute_update_pydantic_model(mux, sql, should_raise=True)
524546
return added_mux
525547

548+
async def add_persona(self, persona: PersonaEmbedding) -> None:
549+
"""Add a new Persona to the DB.
550+
551+
This handles validation and insertion of a new persona.
552+
553+
It may raise a AlreadyExistsError if the persona already exists.
554+
"""
555+
sql = text(
556+
"""
557+
INSERT INTO personas (id, name, description, description_embedding)
558+
VALUES (:id, :name, :description, :description_embedding)
559+
"""
560+
)
561+
562+
try:
563+
# For Pydantic we conver the numpy array to a string when serializing.
564+
# We need to convert it back to a numpy array before inserting it into the DB.
565+
persona_dict = persona.model_dump()
566+
persona_dict["description_embedding"] = persona.description_embedding
567+
await self._execute_with_no_return(sql, persona_dict)
568+
except IntegrityError as e:
569+
logger.debug(f"Exception type: {type(e)}")
570+
raise AlreadyExistsError(f"Persona '{persona.name}' already exists.")
571+
526572

527573
class DbReader(DbCodeGate):
528574
def __init__(self, sqlite_path: Optional[str] = None, *args, **kwargs):
@@ -569,6 +615,18 @@ async def _exec_select_conditions_to_pydantic(
569615
raise e
570616
return None
571617

618+
async def _exec_vec_db_query(
619+
self, sql_command: str, conditions: dict
620+
) -> Optional[CursorResult]:
621+
"""
622+
Execute a query on the vector database. This is a separate connection to the SQLite
623+
database that has the vector extension loaded.
624+
"""
625+
conn = self._get_vec_db_connection()
626+
cursor = conn.cursor()
627+
cursor.execute(sql_command, conditions)
628+
return cursor
629+
572630
async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithOutputsRow]:
573631
sql = text(
574632
"""
@@ -893,6 +951,49 @@ async def get_muxes_by_workspace(self, workspace_id: str) -> List[MuxRule]:
893951
)
894952
return muxes
895953

954+
async def get_persona_by_name(self, persona_name: str) -> Optional[Persona]:
955+
"""
956+
Get a persona by name.
957+
"""
958+
sql = text(
959+
"""
960+
SELECT
961+
id, name, description
962+
FROM personas
963+
WHERE name = :name
964+
"""
965+
)
966+
conditions = {"name": persona_name}
967+
personas = await self._exec_select_conditions_to_pydantic(
968+
Persona, sql, conditions, should_raise=True
969+
)
970+
return personas[0] if personas else None
971+
972+
async def get_distance_to_persona(
973+
self, persona_id: str, query_embedding: np.ndarray
974+
) -> PersonaDistance:
975+
"""
976+
Get the distance between a persona and a query embedding.
977+
"""
978+
sql = """
979+
SELECT
980+
id,
981+
name,
982+
description,
983+
vec_distance_cosine(description_embedding, :query_embedding) as distance
984+
FROM personas
985+
WHERE id = :id
986+
"""
987+
conditions = {"id": persona_id, "query_embedding": query_embedding}
988+
persona_distance_cursor = await self._exec_vec_db_query(sql, conditions)
989+
persona_distance_raw = persona_distance_cursor.fetchone()
990+
return PersonaDistance(
991+
id=persona_distance_raw[0],
992+
name=persona_distance_raw[1],
993+
description=persona_distance_raw[2],
994+
distance=persona_distance_raw[3],
995+
)
996+
896997

897998
def init_db_sync(db_path: Optional[str] = None):
898999
"""DB will be initialized in the constructor in case it doesn't exist."""

src/codegate/db/models.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from enum import Enum
33
from typing import Annotated, Any, Dict, List, Optional
44

5-
from pydantic import BaseModel, StringConstraints
5+
import numpy as np
6+
from pydantic import BaseModel, BeforeValidator, ConfigDict, PlainSerializer, StringConstraints
67

78

89
class AlertSeverity(str, Enum):
@@ -240,3 +241,39 @@ class MuxRule(BaseModel):
240241
priority: int
241242
created_at: Optional[datetime.datetime] = None
242243
updated_at: Optional[datetime.datetime] = None
244+
245+
246+
# Pydantic doesn't support numpy arrays out of the box. Defining a custom type
247+
# Reference: https://github.com/pydantic/pydantic/issues/7017
248+
def nd_array_custom_before_validator(x):
249+
# custome before validation logic
250+
return x
251+
252+
253+
def nd_array_custom_serializer(x):
254+
# custome serialization logic
255+
return str(x)
256+
257+
258+
NdArray = Annotated[
259+
np.ndarray,
260+
BeforeValidator(nd_array_custom_before_validator),
261+
PlainSerializer(nd_array_custom_serializer, return_type=str),
262+
]
263+
264+
265+
class Persona(BaseModel):
266+
id: str
267+
name: str
268+
description: str
269+
270+
271+
class PersonaEmbedding(Persona):
272+
description_embedding: NdArray # sqlite-vec will handle numpy arrays directly
273+
274+
# Part of the workaround to allow numpy arrays in pydantic models
275+
model_config = ConfigDict(arbitrary_types_allowed=True)
276+
277+
278+
class PersonaDistance(Persona):
279+
distance: float
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import unicodedata
2+
import uuid
3+
4+
import numpy as np
5+
import regex as re
6+
import structlog
7+
8+
from codegate.config import Config
9+
from codegate.db import models as db_models
10+
from codegate.db.connection import DbReader, DbRecorder
11+
from codegate.inference.inference_engine import LlamaCppInferenceEngine
12+
13+
logger = structlog.get_logger("codegate")
14+
15+
16+
class PersonaDoesNotExistError(Exception):
17+
pass
18+
19+
20+
class SemanticRouter:
21+
22+
def __init__(self):
23+
self._inference_engine = LlamaCppInferenceEngine()
24+
conf = Config.get_config()
25+
self._embeddings_model = f"{conf.model_base_path}/{conf.embedding_model}"
26+
self._n_gpu = conf.chat_model_n_gpu_layers
27+
self._persona_threshold = conf.persona_threshold
28+
self._db_recorder = DbRecorder()
29+
self._db_reader = DbReader()
30+
31+
def _clean_text_for_embedding(self, text: str) -> str:
32+
"""
33+
Clean the text for embedding. This function should be used to preprocess the text
34+
before embedding.
35+
36+
Performs the following operations:
37+
1. Replaces newlines and carriage returns with spaces
38+
2. Removes extra whitespace
39+
3. Converts to lowercase
40+
4. Removes URLs and email addresses
41+
5. Removes code block markers and other markdown syntax
42+
6. Normalizes Unicode characters
43+
7. Handles special characters and punctuation
44+
8. Normalizes numbers
45+
"""
46+
if not text:
47+
return ""
48+
49+
# Replace newlines and carriage returns with spaces
50+
text = text.replace("\n", " ").replace("\r", " ")
51+
52+
# Normalize Unicode characters (e.g., convert accented characters to ASCII equivalents)
53+
text = unicodedata.normalize("NFKD", text)
54+
text = "".join([c for c in text if not unicodedata.combining(c)])
55+
56+
# Remove URLs
57+
text = re.sub(r"https?://\S+|www\.\S+", " ", text)
58+
59+
# Remove email addresses
60+
text = re.sub(r"\S+@\S+", " ", text)
61+
62+
# Remove code block markers and other markdown/code syntax
63+
text = re.sub(r"```[\s\S]*?```", " ", text) # Code blocks
64+
text = re.sub(r"`[^`]*`", " ", text) # Inline code
65+
66+
# Remove HTML/XML tags
67+
text = re.sub(r"<[^>]+>", " ", text)
68+
69+
# Normalize numbers (replace with placeholder)
70+
text = re.sub(r"\b\d+\.\d+\b", " NUM ", text) # Decimal numbers
71+
text = re.sub(r"\b\d+\b", " NUM ", text) # Integer numbers
72+
73+
# Replace punctuation with spaces (keeping apostrophes for contractions)
74+
text = re.sub(r"[^\w\s\']", " ", text)
75+
76+
# Normalize whitespace (replace multiple spaces with a single space)
77+
text = re.sub(r"\s+", " ", text)
78+
79+
# Convert to lowercase and strip
80+
text = text.strip()
81+
82+
return text
83+
84+
async def _embed_text(self, text: str) -> np.ndarray:
85+
"""
86+
Helper function to embed text using the inference engine.
87+
"""
88+
cleaned_text = self._clean_text_for_embedding(text)
89+
# .embed returns a list of embeddings
90+
embed_list = await self._inference_engine.embed(
91+
self._embeddings_model, [cleaned_text], n_gpu_layers=self._n_gpu
92+
)
93+
# Use only the first entry in the list and make sure we have the appropriate type
94+
logger.debug("Text embedded in semantic routing", text=cleaned_text[:100])
95+
return np.array(embed_list[0], dtype=np.float32)
96+
97+
async def add_persona(self, persona_name: str, persona_desc: str) -> None:
98+
"""
99+
Add a new persona to the database. The persona description is embedded
100+
and stored in the database.
101+
"""
102+
emb_persona_desc = await self._embed_text(persona_desc)
103+
new_persona = db_models.PersonaEmbedding(
104+
id=str(uuid.uuid4()),
105+
name=persona_name,
106+
description=persona_desc,
107+
description_embedding=emb_persona_desc,
108+
)
109+
await self._db_recorder.add_persona(new_persona)
110+
logger.info(f"Added persona {persona_name} to the database.")
111+
112+
async def check_persona_match(self, persona_name: str, query: str) -> bool:
113+
"""
114+
Check if the query matches the persona description. A vector similarity
115+
search is performed between the query and the persona description.
116+
0 means the vectors are identical, 2 means they are orthogonal.
117+
See
118+
[sqlite docs](https://alexgarcia.xyz/sqlite-vec/api-reference.html#vec_distance_cosine)
119+
"""
120+
persona = await self._db_reader.get_persona_by_name(persona_name)
121+
if not persona:
122+
raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.")
123+
124+
emb_query = await self._embed_text(query)
125+
persona_distance = await self._db_reader.get_distance_to_persona(persona.id, emb_query)
126+
logger.info(f"Persona distance to {persona_name}", distance=persona_distance.distance)
127+
if persona_distance.distance < self._persona_threshold:
128+
return True
129+
return False

0 commit comments

Comments
 (0)