Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions src/middlewares/requestDataValidation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from fastapi import Depends, HTTPException
from fastapi.requests import Request
from validations.validation import ChatCompletionRequest
from pydantic import ValidationError
from typing import Dict, Any, List,Union

def get_human_readable_error(exc: Union[ValidationError, List[Dict[str, Any]]]) -> Dict[str, Any]:
"""
Convert validation errors to human-readable format.
Handles both Pydantic ValidationError and raw error lists.
"""
# Get errors list from either source
errors_list = exc.errors() if isinstance(exc, ValidationError) else exc

errors: List[Dict[str, Any]] = []

for error in errors_list:
loc = error.get('loc', ())
field = '.'.join(str(l) for l in loc if l != '__root__') or 'root'
msg = error.get('msg', 'Invalid value')
errors.append({
'message': msg,
'type': error.get('type', 'validation_error')
})

return {
"error": {
"message": "Validation failed",
"details": errors,
"suggestion": "Please check your input values"
}
}

async def validate_request_data(request: Request):
try:
# Validate request body against Pydantic model
body = await request.json()
validated_data = ChatCompletionRequest(**body) # Validate the body with the Pydantic model
return validated_data
except ValidationError as ve:
# If validation error occurs, format the errors in a human-readable way
error_response = get_human_readable_error(ve.errors())
raise HTTPException(status_code=400, detail=error_response)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid request data: {e}")
3 changes: 2 additions & 1 deletion src/routes/v2/modelRouter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from config import Config
from src.services.commonServices.queueService.queueService import queue_obj
from src.middlewares.ratelimitMiddleware import rate_limit
from ...middlewares.requestDataValidation import validate_request_data
from globals import *


Expand All @@ -20,7 +21,7 @@ async def auth_and_rate_limit(request: Request):
await rate_limit(request,key_path='body.bridge_id' , points=100)
await rate_limit(request,key_path='body.thread_id', points=20)

@router.post('/chat/completion', dependencies=[Depends(auth_and_rate_limit)])
@router.post('/chat/completion', dependencies=[Depends(auth_and_rate_limit),Depends(validate_request_data)])
async def chat_completion(request: Request, db_config: dict = Depends(add_configuration_data_to_body)):
request.state.is_playground = False
request.state.version = 2
Expand Down
72 changes: 69 additions & 3 deletions validations/validation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from pydantic import BaseModel,Field,constr,conint,confloat,validator, ConfigDict
from typing import Optional, Dict,List,Any
from pydantic import BaseModel,Field,constr,conint,confloat,validator, ConfigDict,model_validator
from typing import Optional, Dict,List,Any,Literal

from typing_extensions import Annotated
from src.configs.model_configuration import model_config_document



# class ModelConfig(BaseModel):
# model: str
# creativity_level: Optional[float] = Field(None, ge=0, le=2)
Expand Down Expand Up @@ -49,4 +51,68 @@ class Bridge_update(BaseModel):
name: Optional[str] = None
apikey_object_id: Optional[str] = None
functionData: Optional[object]


class ChatCompletionRequest(BaseModel):
user: str = Field(..., description="User identifier")
bridge_id: str = Field(..., description="Bridge identifier")
variables: Dict[str, Any] = Field(default_factory=dict, description="Template variables")
model: Optional[str] = Field(None, description="Model name (required if configuration is provided)")
response_type: Optional[Literal["text", "json_object"]] = Field(None, description="Response format")
configuration: Optional[Dict[str, Any]] = Field(None, description="Model-specific configuration")

@model_validator(mode='before')
def validate_configuration(cls, data: Dict[str, Any]) -> Dict[str, Any]:
configuration = data.get("configuration")
model = data.get("model")

# Skip validation if no configuration or no model
if configuration is None or model is None:
return data

# Get model config (replace with your actual implementation)
model_config = model_config_document.get(model)
if not model_config:
raise ValueError(f"Model '{model}' not found in configuration document")

config_schema = model_config.get("configuration", {})

for field_name, field_schema in config_schema.items():
if field_name in configuration:
cls._validate_field(
field_name=field_name,
field_schema=field_schema,
user_value=configuration[field_name]
)

return data

@classmethod
def _validate_field(cls, field_name: str, field_schema: Dict[str, Any], user_value: Any):
"""Validate a single configuration field against its schema"""
field_type = field_schema.get("field")

if field_type == "slider":
if not isinstance(user_value, (int, float)):
raise ValueError(f"{field_name} must be a number")
if "min" in field_schema and user_value < field_schema["min"]:
raise ValueError(f"{field_name} must be ≥ {field_schema['min']}")
if "max" in field_schema and user_value > field_schema["max"]:
raise ValueError(f"{field_name} must be ≤ {field_schema['max']}")

elif field_type == "boolean":
if not isinstance(user_value, bool):
raise ValueError(f"{field_name} must be a boolean (True/False)")

elif field_type == "dropdown":
options = field_schema.get("options", [])
if user_value not in options:
raise ValueError(f"{field_name} must be one of: {options}")

elif field_type == "number":
if not isinstance(user_value, (int, float)):
raise ValueError(f"{field_name} must be a number")

elif field_type == "select":
options = [opt["type"] for opt in field_schema.get("options", [])]
if user_value not in options:
raise ValueError(f"{field_name} must be one of: {options}")