Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions backend/apps/cas_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import html
import logging
from http import HTTPStatus
from typing import Optional
from urllib.parse import parse_qs

from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse

from services.cas_service import (
CasAuthenticationError,
build_login_url,
build_renew_url,
get_cas_config,
login_with_ticket,
renew_with_ticket,
revoke_from_logout_request,
)

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/user/cas", tags=["cas"])


@router.get("/config")
async def config():
return JSONResponse(
status_code=HTTPStatus.OK,
content={"message": "success", "data": get_cas_config()},
)


@router.get("/login")
async def login(redirect: str = Query("/", description="URL to return to after login")):

Check warning on line 33 in backend/apps/cas_app.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Use "Annotated" type hints for FastAPI dependency injection

See more on https://sonarcloud.io/project/issues?id=ModelEngine-Group_nexent&issues=AZ5uoE3Fzpp-eGzd4BWe&open=AZ5uoE3Fzpp-eGzd4BWe&pullRequest=3072
try:
return RedirectResponse(url=build_login_url(redirect), status_code=HTTPStatus.FOUND)
except CasAuthenticationError as exc:
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(exc))


@router.get("/callback")
async def callback(ticket: str = "", redirect: str = "/"):
try:
result = await login_with_ticket(ticket, redirect)
return JSONResponse(
status_code=HTTPStatus.OK,
content={"message": "CAS login successful", "data": result},
)
except CasAuthenticationError as exc:
raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=str(exc))
except Exception as exc:
logger.error(f"CAS callback failed: {exc}")

Check failure on line 51 in backend/apps/cas_app.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Use "logging.exception()" instead.

See more on https://sonarcloud.io/project/issues?id=ModelEngine-Group_nexent&issues=AZ5uoE3Fzpp-eGzd4BWf&open=AZ5uoE3Fzpp-eGzd4BWf&pullRequest=3072
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="CAS login failed")


@router.post("/callback")
async def callback_logout(request: Request, logout_request: Optional[str] = None):
return await _handle_logout_request(request, logout_request, endpoint="callback")


@router.get("/renew")
async def renew():
try:
return RedirectResponse(url=build_renew_url(), status_code=HTTPStatus.FOUND)
except CasAuthenticationError as exc:
return _renew_html(False, str(exc))


@router.get("/renew_callback")
async def renew_callback(ticket: str = ""):
if not ticket:
return _renew_html(False, "CAS session is not active")
try:
result = await renew_with_ticket(ticket)
return JSONResponse(
status_code=HTTPStatus.OK,
content={"message": "CAS renew successful", "data": result},
)
except Exception as exc:
logger.warning(f"CAS renew failed: {exc}")
return _renew_html(False, "CAS renew failed")


@router.post("/logout_callback")
async def logout_callback(
request: Request,
logout_request: Optional[str] = None,
):
return await _handle_logout_request(request, logout_request, endpoint="logout_callback")


async def _handle_logout_request(
request: Request,
logout_request: Optional[str] = None,
endpoint: str = "unknown",
):
logout_request = await _extract_logout_request(request, logout_request)
logger.info(
"CAS SLO %s received logoutRequest: present=%s length=%s",
endpoint,
bool(logout_request),
len(logout_request or ""),
)
result = revoke_from_logout_request(logout_request)
logger.info("CAS SLO %s revoke result: %s", endpoint, result)
return JSONResponse(
status_code=HTTPStatus.OK,
content={"message": "success", "data": result},
)


async def _extract_logout_request(request: Request, logout_request: Optional[str] = None) -> str:
if logout_request:
return logout_request

query_logout_request = request.query_params.get("logoutRequest") or request.query_params.get("logout_request")
if query_logout_request:
return query_logout_request

body = await request.body()
raw_body = body.decode("utf-8") if body else ""
if not raw_body:
return ""

parsed = parse_qs(raw_body)
return (parsed.get("logoutRequest") or parsed.get("logout_request") or [raw_body])[0]


def _renew_html(success: bool, reason: str = "") -> HTMLResponse:
status = "success" if success else "failed"
safe_reason = html.escape(reason)
return HTMLResponse(
status_code=HTTPStatus.OK,
content=f"""<!doctype html>
<html><body><script>
window.parent && window.parent.postMessage({{ type: "cas-renew-{status}", reason: "{safe_reason}" }}, window.location.origin);
</script></body></html>""",
Comment on lines +133 to +136
)
2 changes: 2 additions & 0 deletions backend/apps/config_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from apps.monitoring_app import router as monitoring_router
from apps.a2a_server_app import router as a2a_server_router
from apps.haotian_app import router as haotian_router
from apps.cas_app import router as cas_router
from consts.const import IS_SPEED_MODE
from services.prompt_template_service import sync_system_default_prompt_template

Expand Down Expand Up @@ -73,6 +74,7 @@ async def sync_default_prompt_template_on_startup():
app.include_router(user_management_router)

app.include_router(oauth_router)
app.include_router(cas_router)

app.include_router(summary_router)
app.include_router(prompt_router)
Expand Down
25 changes: 23 additions & 2 deletions backend/apps/user_management_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
ValidationError,
)
from consts.error_code import ErrorCode
from services.cas_service import build_logout_url, CasAuthenticationError
from services.user_management_service import get_authorized_client, validate_token, \
check_auth_service_health, signup_user_with_invitation, signin_user, refresh_user_token, \
get_session_by_authorization, get_user_info, create_token, list_tokens_by_user, delete_token, \
update_password
from services.user_service import delete_user_and_cleanup
from utils.auth_utils import get_current_user_id
from utils.auth_utils import get_current_user_id, extract_session_id_from_authorization


load_dotenv()
Expand Down Expand Up @@ -144,7 +145,18 @@ async def logout(request: Request):
authorization = request.headers.get("Authorization")
try:
# Make logout idempotent: if no token or token expired, still return success
session_id = None
cas_logout_url = ""
if authorization:
session_id = extract_session_id_from_authorization(authorization)
if session_id:
from database.cas_session_db import revoke_cas_session_by_session_id

revoke_cas_session_by_session_id(session_id, actor="user")
try:
cas_logout_url = build_logout_url()
except CasAuthenticationError as cas_err:
logging.warning(f"CAS logout URL is unavailable: {str(cas_err)}")
client = get_authorized_client(authorization)
try:
client.auth.sign_out()
Expand All @@ -153,7 +165,12 @@ async def logout(request: Request):
logging.warning(
f"Sign out encountered an error but will be ignored: {str(signout_err)}")
return JSONResponse(status_code=HTTPStatus.OK,
content={"message": "Logout successful"})
content={
"message": "Logout successful",
"data": {
"cas_logout_url": cas_logout_url
}
})

except Exception as e:
logging.error(f"User logout failed: {str(e)}")
Expand Down Expand Up @@ -208,6 +225,10 @@ async def get_user_information(request: Request):
if not user_info:
raise UnauthorizedError("User information not found")

user_info["user"]["auth_provider"] = (
"cas" if extract_session_id_from_authorization(authorization) else "local"
)

return JSONResponse(status_code=HTTPStatus.OK,
content={"message": "Success",
"data": user_info})
Expand Down
25 changes: 25 additions & 0 deletions backend/consts/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,31 @@ class VectorDatabaseType(str, Enum):
OAUTH_CA_BUNDLE = os.getenv("OAUTH_CA_BUNDLE", "")


# CAS SSO Configuration
CAS_ENABLED = os.getenv("CAS_ENABLED", "false").lower() in ("true", "1", "yes", "on")
CAS_SERVER_URL = os.getenv("CAS_SERVER_URL", "").rstrip("/")
CAS_VALIDATE_PATH = os.getenv("CAS_VALIDATE_PATH", "/p3/serviceValidate")
CAS_CALLBACK_BASE_URL = os.getenv("CAS_CALLBACK_BASE_URL", OAUTH_CALLBACK_BASE_URL).rstrip("/")
# CAS login mode:
# - disabled: disable CAS login entry and automatic CAS redirects.
# - button: show CAS as an optional login entry.
# - force: automatically redirect unauthenticated users to CAS login.
CAS_LOGIN_MODE = os.getenv("CAS_LOGIN_MODE", "disabled").lower()
CAS_USER_ATTRIBUTE = os.getenv("CAS_USER_ATTRIBUTE", "")
CAS_EMAIL_ATTRIBUTE = os.getenv("CAS_EMAIL_ATTRIBUTE", "email")
CAS_ROLE_ATTRIBUTE = os.getenv("CAS_ROLE_ATTRIBUTE", "role")
CAS_TENANT_ATTRIBUTE = os.getenv("CAS_TENANT_ATTRIBUTE", "tenant_id")
CAS_ROLE_MAP_JSON = os.getenv("CAS_ROLE_MAP_JSON", "")
CAS_SESSION_MAX_AGE_SECONDS = int(os.getenv("CAS_SESSION_MAX_AGE_SECONDS", "3600") or 3600)
LOCAL_SESSION_MAX_AGE_SECONDS = int(os.getenv("LOCAL_SESSION_MAX_AGE_SECONDS", "3600") or 3600)
CAS_RENEW_BEFORE_SECONDS = int(os.getenv("CAS_RENEW_BEFORE_SECONDS", "300") or 300)
CAS_RENEW_TIMEOUT_SECONDS = int(os.getenv("CAS_RENEW_TIMEOUT_SECONDS", "10") or 10)
CAS_SYNTHETIC_EMAIL_DOMAIN = os.getenv("CAS_SYNTHETIC_EMAIL_DOMAIN", "cas.local")
CAS_LOGOUT_URL = os.getenv("CAS_LOGOUT_URL", "")
CAS_SSL_VERIFY = os.getenv("CAS_SSL_VERIFY", "true").lower() == "true"
CAS_CA_BUNDLE = os.getenv("CAS_CA_BUNDLE", "")


# ===== To be migrated to frontend configuration =====
# Email Configuration
IMAP_SERVER = os.getenv('IMAP_SERVER')
Expand Down
134 changes: 134 additions & 0 deletions backend/database/cas_session_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
Database operations for CAS-backed web sessions.
"""

from datetime import datetime
from typing import Any, Dict, Optional

from database.client import as_dict, get_db_session
from database.db_models import UserCasSession

CAS_SESSION_ACTIVE = "active"
CAS_SESSION_REVOKED = "revoked"


def create_cas_session(
*,
session_id: str,
user_id: str,
cas_user_id: str,
expires_at: datetime,
cas_session_index: Optional[str] = None,
) -> Dict[str, Any]:
with get_db_session() as session:
record = UserCasSession(
session_id=session_id,
user_id=user_id,
cas_user_id=cas_user_id,
cas_session_index=cas_session_index,
status=CAS_SESSION_ACTIVE,
expires_at=expires_at,
created_by=user_id,
updated_by=user_id,
)
session.add(record)
session.flush()
return as_dict(record)


def get_cas_session_by_session_id(session_id: str) -> Optional[Dict[str, Any]]:
if not session_id:
return None
with get_db_session() as session:
result = (
session.query(UserCasSession)
.filter(
UserCasSession.session_id == session_id,
UserCasSession.delete_flag == "N",
)
.first()
)
return as_dict(result) if result else None


def is_cas_session_active(session_id: str) -> bool:
if not session_id:
return False
with get_db_session() as session:
result = (
session.query(UserCasSession)
.filter(
UserCasSession.session_id == session_id,
UserCasSession.status == CAS_SESSION_ACTIVE,
UserCasSession.expires_at > datetime.now(),
UserCasSession.delete_flag == "N",
)
.first()
)
return result is not None


def revoke_cas_session_by_session_id(session_id: str, actor: str = "cas") -> int:
if not session_id:
return 0
with get_db_session() as session:
result = (
session.query(UserCasSession)
.filter(
UserCasSession.session_id == session_id,
UserCasSession.status == CAS_SESSION_ACTIVE,
UserCasSession.delete_flag == "N",
)
.update(
{
"status": CAS_SESSION_REVOKED,
"revoked_at": datetime.now(),
"updated_by": actor,
}
)
)
return result


def revoke_cas_sessions_by_user_id(cas_user_id: str, actor: str = "cas") -> int:
if not cas_user_id:
return 0
with get_db_session() as session:
result = (
session.query(UserCasSession)
.filter(
UserCasSession.cas_user_id == cas_user_id,
UserCasSession.status == CAS_SESSION_ACTIVE,
UserCasSession.delete_flag == "N",
)
.update(
{
"status": CAS_SESSION_REVOKED,
"revoked_at": datetime.now(),
"updated_by": actor,
}
)
)
return result


def revoke_cas_session_by_index(cas_session_index: str, actor: str = "cas") -> int:
if not cas_session_index:
return 0
with get_db_session() as session:
result = (
session.query(UserCasSession)
.filter(
UserCasSession.cas_session_index == cas_session_index,
UserCasSession.status == CAS_SESSION_ACTIVE,
UserCasSession.delete_flag == "N",
)
.update(
{
"status": CAS_SESSION_REVOKED,
"revoked_at": datetime.now(),
"updated_by": actor,
}
)
)
return result
Loading