Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 82 additions & 22 deletions hyperquant/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,78 @@
DEFAULT_MAX_HTTP_BODY_OVERHEAD_BYTES = 1024 * 1024


class _BodySizeLimitMiddleware:
    """ASGI middleware that answers 413 when an HTTP request body is too large.

    Two checks are applied to ``http`` scopes (other scope types pass through
    untouched):

    1. Any ``Content-Length`` header declaring more than
       ``max_http_body_bytes`` is rejected before the wrapped app runs.
    2. Body bytes are counted as the wrapped app reads them. Once the running
       total exceeds the limit, the app is handed an empty final body message,
       whatever it tries to send is replaced by a single 413 JSON response,
       and the unread remainder of the request body is drained first.

    Every rejection is recorded via
    ``metrics.observe_error("http", "request_too_large")``.
    """

    def __init__(self, app, *, max_http_body_bytes: int, metrics: HyperQuantMetrics) -> None:
        """Store the wrapped ASGI app, the byte limit (coerced to ``int``) and the metrics sink."""
        self.app = app
        self.max_http_body_bytes = int(max_http_body_bytes)
        self.metrics = metrics

    async def _send_too_large(self, scope, receive, send) -> None:
        """Record the rejection metric and emit the 413 JSON response."""
        self.metrics.observe_error("http", "request_too_large")
        detail = f"request body exceeds max_http_body_bytes={self.max_http_body_bytes}"
        await JSONResponse(status_code=413, content={"detail": detail})(scope, receive, send)

    async def __call__(self, scope, receive, send):
        """Run the wrapped app with size-guarded ``receive``/``send`` callables."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Fast path: trust an oversized Content-Length and reject up front.
        # Unparseable values are ignored; the streaming check below still applies.
        for key, value in scope.get("headers", []):
            if key.lower() != b"content-length":
                continue
            try:
                declared = int(value.decode("latin1"))
            except ValueError:
                declared = None
            if declared is not None and declared > self.max_http_body_bytes:
                await self._send_too_large(scope, receive, send)
                return

        seen = 0
        too_large = False
        sent_too_large_response = False
        request_stream_ended = False

        async def drain_remaining_body() -> None:
            """Consume body messages until the final chunk or a non-body message."""
            nonlocal request_stream_ended
            while not request_stream_ended:
                message = await receive()
                if message.get("type") != "http.request":
                    # Not a body message (e.g. a disconnect): nothing left to drain.
                    request_stream_ended = True
                    return
                request_stream_ended = not message.get("more_body", False)

        async def guarded_receive():
            """Forward request messages while counting body bytes against the limit."""
            nonlocal seen, too_large, request_stream_ended
            message = await receive()
            if message.get("type") != "http.request":
                # Pass non-body messages through unchanged, even after the limit
                # tripped, so the app can still observe a client disconnect.
                return message
            # Always record stream state. Previously a chunk consumed after the
            # limit tripped was discarded without updating this flag, so a later
            # drain could wait forever for a final chunk that was already read.
            request_stream_ended = not message.get("more_body", False)
            if too_large:
                return {"type": "http.request", "body": b"", "more_body": False}
            seen += len(message.get("body", b""))
            if seen > self.max_http_body_bytes:
                too_large = True
                # Hide the oversized payload from the app: present an empty, final body.
                return {"type": "http.request", "body": b"", "more_body": False}
            return message

        async def guarded_send(message):
            """Forward app output, or swap it for one 413 once the limit tripped."""
            nonlocal sent_too_large_response
            if not too_large:
                await send(message)
                return
            if sent_too_large_response:
                # The 413 already went out; drop whatever else the app emits.
                return
            # NOTE(review): if the app had already sent ``http.response.start``
            # before the limit tripped, the 413 below would start a second
            # response - confirm that cannot happen for these routes.
            sent_too_large_response = True
            await drain_remaining_body()
            await self._send_too_large(scope, receive, send)

        await self.app(scope, guarded_receive, guarded_send)
        if too_large and not sent_too_large_response:
            # The app returned without sending anything; still answer with 413.
            await drain_remaining_body()
            await self._send_too_large(scope, receive, send)

def _pydantic_model_to_dict(model) -> dict:
if hasattr(model, "model_dump"):
return model.model_dump()
Expand Down Expand Up @@ -148,22 +220,10 @@ async def run_bound(fn):
app.state.max_request_bytes = max_request_bytes
app.state.max_http_body_bytes = max_http_body_bytes
app.state.max_concurrency = resolved_max_concurrency
app.add_middleware(_BodySizeLimitMiddleware, max_http_body_bytes=max_http_body_bytes, metrics=metrics)

@app.middleware("http")
async def enforce_content_length(request, call_next):
    """Reject requests whose declared Content-Length exceeds ``max_http_body_bytes``.

    A missing or non-integer header is not rejected here; the request is
    handed to ``call_next`` unchanged.

    NOTE(review): ``_BodySizeLimitMiddleware`` is registered just above via
    ``app.add_middleware`` and performs the same header check - confirm
    whether this hook is still required.
    """
    header_value = request.headers.get("content-length")
    try:
        declared = int(header_value) if header_value is not None else None
    except ValueError:
        declared = None

    over_limit = declared is not None and declared > max_http_body_bytes
    if not over_limit:
        return await call_next(request)

    metrics.observe_error("http", "request_too_large")
    payload = {"detail": f"request body exceeds max_http_body_bytes={max_http_body_bytes}"}
    return JSONResponse(status_code=413, content=payload)
def internal_server_error(_exc: Exception) -> HTTPException:
    """Build a generic HTTP 500 error; ``_exc`` is deliberately not echoed in the detail."""
    generic_failure = HTTPException(status_code=500, detail="internal server error")
    return generic_failure

@app.get("/healthz", response_model=HealthResponse)
async def healthz() -> HealthResponse:
Expand Down Expand Up @@ -205,7 +265,7 @@ def do_compress():
raise HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc: # pragma: no cover - FastAPI behavior tested through endpoint
metrics.observe_error("compress", "internal_error")
raise HTTPException(status_code=400, detail=str(exc)) from exc
raise internal_server_error(exc) from exc
metrics.observe_compress(stats, latency_seconds=time.perf_counter() - started)
return CodebookCompressResponse(
envelope_b64=envelope.to_base64(),
Expand All @@ -227,7 +287,7 @@ def do_decompress():
raise HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc: # pragma: no cover
metrics.observe_error("decompress", "internal_error")
raise HTTPException(status_code=400, detail=str(exc)) from exc
raise internal_server_error(exc) from exc
metrics.observe_decompress(latency_seconds=time.perf_counter() - started)
return DecompressResponse(array_b64=ndarray_to_b64(restored))

Expand All @@ -247,7 +307,7 @@ def do_vector_compress():
raise HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc: # pragma: no cover
metrics.observe_error("vector_compress", "internal_error")
raise HTTPException(status_code=400, detail=str(exc)) from exc
raise internal_server_error(exc) from exc
metrics.observe_vector_compress(stats, latency_seconds=time.perf_counter() - started)
return VectorCompressResponse(
envelope_b64=envelope.to_base64(),
Expand All @@ -270,7 +330,7 @@ def do_vector_decompress():
raise HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc: # pragma: no cover
metrics.observe_error("vector_decompress", "internal_error")
raise HTTPException(status_code=400, detail=str(exc)) from exc
raise internal_server_error(exc) from exc
metrics.observe_vector_decompress(latency_seconds=time.perf_counter() - started)
return DecompressResponse(array_b64=ndarray_to_b64(restored))

Expand Down Expand Up @@ -316,7 +376,7 @@ def do_resident_plan():
raise HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc: # pragma: no cover
metrics.observe_error("resident_plan", "internal_error")
raise HTTPException(status_code=400, detail=str(exc)) from exc
raise internal_server_error(exc) from exc
metrics.observe_resident_plan(plan, latency_seconds=time.perf_counter() - started)
return ResidentPlanResponse(plan=plan.to_dict())

Expand Down Expand Up @@ -370,7 +430,7 @@ def do_context_compress():
raise HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc: # pragma: no cover
metrics.observe_error("context_compress", "internal_error")
raise HTTPException(status_code=400, detail=str(exc)) from exc
raise internal_server_error(exc) from exc
metrics.observe_context_compress(stats, latency_seconds=time.perf_counter() - started)
return ContextCompressResponse(
envelope_b64=envelope.to_base64(),
Expand All @@ -395,7 +455,7 @@ def do_context_decompress():
raise HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc: # pragma: no cover
metrics.observe_error("context_decompress", "internal_error")
raise HTTPException(status_code=400, detail=str(exc)) from exc
raise internal_server_error(exc) from exc
metrics.observe_context_decompress(latency_seconds=time.perf_counter() - started)
return DecompressResponse(array_b64=ndarray_to_b64(restored))

Expand Down
103 changes: 64 additions & 39 deletions hyperquant/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,31 @@

from pydantic import BaseModel, Field

from ..defaults import (
CONTEXT_ENABLE_INT8_FALLBACK_DEFAULT,
CONTEXT_ENABLE_PAGE_REF_DEFAULT,
CONTEXT_INT8_MAX_ABS_THRESHOLD_DEFAULT,
CONTEXT_INT8_REL_RMS_THRESHOLD_DEFAULT,
CONTEXT_LOW_RANK_ERROR_THRESHOLD_DEFAULT,
CONTEXT_PAGE_REF_REL_RMS_THRESHOLD_DEFAULT,
CONTEXT_PAGE_SIZE_DEFAULT,
CONTEXT_PREFIX_KEEP_VECTORS_DEFAULT,
CONTEXT_RANK_DEFAULT,
CONTEXT_REF_ROUND_DECIMALS_DEFAULT,
CONTEXT_SUFFIX_KEEP_VECTORS_DEFAULT,
CONTEXT_TRY_INT8_FOR_PROTECTED_DEFAULT,
RESIDENT_ACTIVE_WINDOW_TOKENS_DEFAULT,
RESIDENT_ALLOW_VECTOR_FOR_PROTECTED_DEFAULT,
RESIDENT_CONCURRENT_SESSIONS_DEFAULT,
RESIDENT_HOT_PAGES_DEFAULT,
RESIDENT_RUNTIME_VALUE_BYTES_DEFAULT,
VECTOR_BITS_DEFAULT,
VECTOR_GROUP_SIZE_DEFAULT,
VECTOR_PREFER_NATIVE_FWHT_DEFAULT,
VECTOR_RESIDUAL_TOPK_DEFAULT,
VECTOR_ROTATION_SEED_DEFAULT,
)


class CodebookCompressRequest(BaseModel):
array_b64: str = Field(..., description="Base64-encoded .npy payload.")
Expand All @@ -30,11 +55,11 @@ class DecompressRequest(BaseModel):

class VectorCompressRequest(BaseModel):
    """Pydantic model for a vector-compression request.

    Tunable fields take their defaults from the ``VECTOR_*_DEFAULT`` constants
    imported from ``..defaults``. The earlier literal-default declarations of
    the same fields were overridden by these and have been removed.
    """

    array_b64: str = Field(..., description="Base64-encoded .npy payload.")
    bits: int = Field(default=VECTOR_BITS_DEFAULT, ge=2, le=4)
    group_size: int = Field(default=VECTOR_GROUP_SIZE_DEFAULT, gt=0)
    rotation_seed: int = Field(default=VECTOR_ROTATION_SEED_DEFAULT)
    residual_topk: int = Field(default=VECTOR_RESIDUAL_TOPK_DEFAULT, ge=0)
    prefer_native_fwht: bool = VECTOR_PREFER_NATIVE_FWHT_DEFAULT


class ContextGuaranteeModel(BaseModel):
Expand All @@ -47,18 +72,18 @@ class ContextGuaranteeModel(BaseModel):
class ContextCompressRequest(BaseModel):
    """Pydantic model for a context-compression request.

    Tunable fields take their defaults from the ``CONTEXT_*_DEFAULT`` constants
    imported from ``..defaults``. The earlier literal-default declarations of
    the same fields were overridden by these and have been removed.
    """

    array_b64: str = Field(..., description="Base64-encoded .npy payload.")
    protected_vector_indices: List[int] = Field(default_factory=list)
    page_size: int = Field(default=CONTEXT_PAGE_SIZE_DEFAULT, gt=0)
    rank: int = Field(default=CONTEXT_RANK_DEFAULT, gt=0)
    prefix_keep_vectors: int = Field(default=CONTEXT_PREFIX_KEEP_VECTORS_DEFAULT, ge=0)
    suffix_keep_vectors: int = Field(default=CONTEXT_SUFFIX_KEEP_VECTORS_DEFAULT, ge=0)
    low_rank_error_threshold: float = Field(default=CONTEXT_LOW_RANK_ERROR_THRESHOLD_DEFAULT, ge=0.0)
    ref_round_decimals: int = Field(default=CONTEXT_REF_ROUND_DECIMALS_DEFAULT, ge=0)
    enable_page_ref: bool = CONTEXT_ENABLE_PAGE_REF_DEFAULT
    page_ref_rel_rms_threshold: float = Field(default=CONTEXT_PAGE_REF_REL_RMS_THRESHOLD_DEFAULT, ge=0.0)
    enable_int8_fallback: bool = CONTEXT_ENABLE_INT8_FALLBACK_DEFAULT
    try_int8_for_protected: bool = CONTEXT_TRY_INT8_FOR_PROTECTED_DEFAULT
    int8_rel_rms_threshold: float = Field(default=CONTEXT_INT8_REL_RMS_THRESHOLD_DEFAULT, ge=0.0)
    int8_max_abs_threshold: float = Field(default=CONTEXT_INT8_MAX_ABS_THRESHOLD_DEFAULT, ge=0.0)
    fail_closed: bool = True
    guarantee: ContextGuaranteeModel | None = None

Expand Down Expand Up @@ -120,29 +145,29 @@ class ContextCompressionStatsModel(BaseModel):

class ResidentPlanRequest(BaseModel):
    """Pydantic model for a resident-plan request.

    Tunable fields take their defaults from the ``RESIDENT_*``, ``CONTEXT_*``
    and ``VECTOR_*`` ``_DEFAULT`` constants imported from ``..defaults``. The
    earlier literal-default declarations of the same fields were overridden by
    these and have been removed.
    """

    array_b64: str = Field(..., description="Base64-encoded .npy payload.")
    concurrent_sessions: int = Field(default=RESIDENT_CONCURRENT_SESSIONS_DEFAULT, gt=0)
    active_window_tokens: int = Field(default=RESIDENT_ACTIVE_WINDOW_TOKENS_DEFAULT, gt=0)
    runtime_value_bytes: int = Field(default=RESIDENT_RUNTIME_VALUE_BYTES_DEFAULT, gt=0)
    budget_bytes: int | None = Field(default=None, gt=0)
    page_size: int = Field(default=CONTEXT_PAGE_SIZE_DEFAULT, gt=0)
    rank: int = Field(default=CONTEXT_RANK_DEFAULT, gt=0)
    bits: int = Field(default=VECTOR_BITS_DEFAULT, ge=2, le=4)
    group_size: int = Field(default=VECTOR_GROUP_SIZE_DEFAULT, gt=0)
    hot_pages: int = Field(default=RESIDENT_HOT_PAGES_DEFAULT, gt=0)
    rotation_seed: int = Field(default=VECTOR_ROTATION_SEED_DEFAULT)
    residual_topk: int = Field(default=VECTOR_RESIDUAL_TOPK_DEFAULT, ge=0)
    prefix_keep_vectors: int = Field(default=CONTEXT_PREFIX_KEEP_VECTORS_DEFAULT, ge=0)
    suffix_keep_vectors: int = Field(default=CONTEXT_SUFFIX_KEEP_VECTORS_DEFAULT, ge=0)
    low_rank_error_threshold: float = Field(default=CONTEXT_LOW_RANK_ERROR_THRESHOLD_DEFAULT, ge=0.0)
    ref_round_decimals: int = Field(default=CONTEXT_REF_ROUND_DECIMALS_DEFAULT, ge=0)
    enable_page_ref: bool = CONTEXT_ENABLE_PAGE_REF_DEFAULT
    page_ref_rel_rms_threshold: float = Field(default=CONTEXT_PAGE_REF_REL_RMS_THRESHOLD_DEFAULT, ge=0.0)
    enable_int8_fallback: bool = CONTEXT_ENABLE_INT8_FALLBACK_DEFAULT
    try_int8_for_protected: bool = CONTEXT_TRY_INT8_FOR_PROTECTED_DEFAULT
    int8_rel_rms_threshold: float = Field(default=CONTEXT_INT8_REL_RMS_THRESHOLD_DEFAULT, ge=0.0)
    int8_max_abs_threshold: float = Field(default=CONTEXT_INT8_MAX_ABS_THRESHOLD_DEFAULT, ge=0.0)
    prefer_native_fwht: bool = VECTOR_PREFER_NATIVE_FWHT_DEFAULT
    allow_vector_for_protected: bool = RESIDENT_ALLOW_VECTOR_FOR_PROTECTED_DEFAULT


class ResidentPlanResponse(BaseModel):
Expand Down
Loading
Loading