Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 49 additions & 17 deletions app/auth.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Any, Dict
import httpx
import jwt
from fastapi import Depends, HTTPException, WebSocket, status
from fastapi import Depends, WebSocket, status
from fastapi.security import OAuth2AuthorizationCodeBearer
from jwt import PyJWKClient
from loguru import logger

from app.error import AuthException, DispatcherException
from app.schemas.websockets import WSStatusMessage

from .config.settings import settings

# Keycloak OIDC info
Expand Down Expand Up @@ -37,9 +40,9 @@ def _decode_token(token: str):
)
return payload
except Exception:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
raise AuthException(
http_status=status.HTTP_401_UNAUTHORIZED,
message="Could not validate credentials. Please retry signing in.",
)


Expand All @@ -55,6 +58,7 @@ async def websocket_authenticate(websocket: WebSocket) -> str | None:
"""
logger.debug("Authenticating websocket")
token = websocket.query_params.get("token")

if not token:
logger.error("Token is missing from websocket authentication")
await websocket.close(code=1008, reason="Missing token")
Expand All @@ -63,9 +67,22 @@ async def websocket_authenticate(websocket: WebSocket) -> str | None:
try:
await websocket.accept()
return token
except DispatcherException as ae:
logger.error(f"Dispatcher exception detected: {ae.message}")
await websocket.send_json(
WSStatusMessage(type="error", message=ae.message).model_dump()
)
await websocket.close(code=1008, reason=ae.error_code)
return None
except Exception as e:
logger.error(f"Invalid token in websocket authentication: {e}")
await websocket.close(code=1008, reason="Invalid token")
logger.error(f"Unexpected error occurred during websocket authentication: {e}")
await websocket.send_json(
WSStatusMessage(
type="error",
message="Something went wrong during authentication. Please try again.",
).model_dump()
)
await websocket.close(code=1008, reason="INTERNAL_ERROR")
return None


Expand All @@ -81,15 +98,15 @@ async def exchange_token_for_provider(

:return: The token response (dict) on success.

:raise: Raises HTTPException with an appropriate status and message on error.
:raise: Raises AuthException with an appropriate status and message on error.
"""
token_url = f"{KEYCLOAK_BASE_URL}/protocol/openid-connect/token"

# Check if the necessary settings are in place
if not settings.keycloak_client_id or not settings.keycloak_client_secret:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Token exchange not configured on the server (missing client credentials).",
raise AuthException(
http_status=status.HTTP_500_INTERNAL_SERVER_ERROR,
message="Token exchange not configured on the server (missing client credentials).",
)

payload = {
Expand All @@ -105,9 +122,12 @@ async def exchange_token_for_provider(
resp = await client.post(token_url, data=payload)
except httpx.RequestError as exc:
logger.error(f"Token exchange network error for provider={provider}: {exc}")
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="Failed to contact the identity provider for token exchange.",
raise AuthException(
http_status=status.HTTP_502_BAD_GATEWAY,
message=(
f"Could not authenticate with {provider}. Please contact APEx support or reach out "
"through the <a href='https://forum.apex.esa.int/'>APEx User Forum</a>."
),
)

# Parse response
Expand All @@ -117,9 +137,12 @@ async def exchange_token_for_provider(
logger.error(
f"Token exchange invalid JSON response (status={resp.status_code})"
)
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail="Invalid response from identity provider during token exchange.",
raise AuthException(
http_status=status.HTTP_502_BAD_GATEWAY,
message=(
f"Could not authenticate with {provider}. Please contact APEx support or reach out "
"through the <a href='https://forum.apex.esa.int/'>APEx User Forum</a>."
),
)

if resp.status_code != 200:
Expand All @@ -136,7 +159,16 @@ async def exchange_token_for_provider(
else status.HTTP_502_BAD_GATEWAY
)

raise HTTPException(client_status, detail=body)
raise AuthException(
http_status=client_status,
message=(
f"Please link your account with {provider} in your "
"<a href='https://{settings.keycloak_host}/realms/{settings.keycloak_realm}/"
"account'>Account Dashboard</a>"
if body.get("error", "") == "not_linked"
else f"Could not authenticate with {provider}: {err}"
),
)

# Successful exchange, return token response (access_token, expires_in, etc.)
return body
2 changes: 1 addition & 1 deletion app/database/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_db():
yield db
db.commit()
except Exception:
logger.exception("An error occurred during database retrieval")
logger.error("An error occurred during database retrieval")
db.rollback()
raise
finally:
Expand Down
68 changes: 68 additions & 0 deletions app/error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Any, Dict, Optional
from fastapi import status
from pydantic import BaseModel


class ErrorResponse(BaseModel):
status: str = "error"
error_code: str
message: str
details: Optional[Dict[str, Any]] = None
request_id: Optional[str] = None


class DispatcherException(Exception):
"""
Base domain exception for the APEx Dispatch API.
"""

http_status: int = status.HTTP_400_BAD_REQUEST
error_code: str = "APEX_ERROR"
message: str = "An error occurred."
details: Optional[Dict[str, Any]] = None

def __init__(
self,
message: Optional[str] = None,
error_code: Optional[str] = None,
http_status: Optional[int] = None,
details: Optional[Dict[str, Any]] = None,
):
if message:
self.message = message
if error_code:
self.error_code = error_code
if http_status:
self.http_status = http_status
if details:
self.details = details

def __str__(self):
return f"{self.error_code}: {self.message}"


class AuthException(DispatcherException):
def __init__(
self,
http_status: Optional[int] = status.HTTP_401_UNAUTHORIZED,
message: Optional[str] = "Authentication failed.",
):
super().__init__(message, "AUTHENTICATION_FAILED", http_status)


class JobNotFoundException(DispatcherException):
http_status: int = status.HTTP_404_NOT_FOUND
error_code: str = "JOB_NOT_FOUND"
message: str = "The requested job was not found."


class TaskNotFoundException(DispatcherException):
http_status: int = status.HTTP_404_NOT_FOUND
error_code: str = "TASK_NOT_FOUND"
message: str = "The requested task was not found."


class InternalException(DispatcherException):
http_status: int = status.HTTP_500_INTERNAL_SERVER_ERROR
error_code: str = "INTERNAL_ERROR"
message: str = "An internal server error occurred."
2 changes: 2 additions & 0 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from fastapi.middleware.cors import CORSMiddleware

from app.middleware.correlation_id import add_correlation_id
from app.middleware.error_handling import register_exception_handlers
from app.platforms.dispatcher import load_processing_platforms
from app.services.tiles.base import load_grids
from app.config.logger import setup_logging
Expand All @@ -28,6 +29,7 @@
)

app.middleware("http")(add_correlation_id)
register_exception_handlers(app)

# include routers
app.include_router(tiles.router)
Expand Down
74 changes: 74 additions & 0 deletions app/middleware/error_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import Any
from fastapi import Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.error import DispatcherException, ErrorResponse
from app.middleware.correlation_id import correlation_id_ctx
from loguru import logger


def get_dispatcher_error_response(
exc: DispatcherException, request_id: str
) -> ErrorResponse:
return ErrorResponse(
error_code=exc.error_code,
message=exc.message,
details=exc.details,
request_id=request_id,
)


async def dispatch_exception_handler(request: Request, exc: DispatcherException):

content = get_dispatcher_error_response(exc, correlation_id_ctx.get())
logger.exception(f"DispatcherException raised: {exc.message}")
return JSONResponse(status_code=exc.http_status, content=content.dict())


async def generic_exception_handler(request: Request, exc: Exception):

# DO NOT expose internal exceptions to the client
content = ErrorResponse(
error_code="INTERNAL_SERVER_ERROR",
message="An unexpected error occurred.",
details=None,
request_id=correlation_id_ctx.get(),
)

logger.exception(f"GenericException raised: {exc}")
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=content.dict()
)


def _parse_validation_error(err: Any):
if "ctx" in err:
del err["ctx"]
return err


async def validation_exception_handler(request: Request, exc: RequestValidationError):

logger.error(f"Request validation error: {exc.__class__.__name__}: {exc}")
content = ErrorResponse(
error_code="VALIDATION_ERROR",
message="Request validation failed.",
details={"errors": [_parse_validation_error(error) for error in exc.errors()]},
request_id=correlation_id_ctx.get(),
)

logger.error(content.dict())

return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, content=content.dict()
)


def register_exception_handlers(app):
"""
Call this in main.py after creating the FastAPI() instance.
"""

app.add_exception_handler(DispatcherException, dispatch_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(Exception, generic_exception_handler)
69 changes: 52 additions & 17 deletions app/routers/jobs_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from loguru import logger

from app.database.db import SessionLocal, get_db
from app.error import DispatcherException, ErrorResponse, InternalException
from app.middleware.error_handling import get_dispatcher_error_response
from app.schemas.jobs_status import JobsFilter, JobsStatusResponse
from app.schemas.websockets import WSStatusMessage
from app.services.processing import get_processing_jobs_by_user_id
Expand All @@ -22,6 +24,19 @@
"/jobs_status",
tags=["Upscale Tasks", "Unit Jobs"],
summary="Get a list of all upscaling tasks & processing jobs for the authenticated user",
responses={
InternalException.http_status: {
"description": "Internal server error",
"model": ErrorResponse,
"content": {
"application/json": {
"example": get_dispatcher_error_response(
InternalException(), "request-id"
)
}
},
},
},
)
async def get_jobs_status(
db: Session = Depends(get_db),
Expand All @@ -34,21 +49,29 @@ async def get_jobs_status(
"""
Return combined list of upscaling tasks and processing jobs for the authenticated user.
"""
logger.debug("Fetching jobs list")
upscaling_tasks = (
await get_upscaling_tasks_by_user_id(token, db)
if JobsFilter.upscaling in filter
else []
)
processing_jobs = (
await get_processing_jobs_by_user_id(token, db)
if JobsFilter.processing in filter
else []
)
return JobsStatusResponse(
upscaling_tasks=upscaling_tasks,
processing_jobs=processing_jobs,
)
try:
logger.debug("Fetching jobs list")
upscaling_tasks = (
await get_upscaling_tasks_by_user_id(token, db)
if JobsFilter.upscaling in filter
else []
)
processing_jobs = (
await get_processing_jobs_by_user_id(token, db)
if JobsFilter.processing in filter
else []
)
return JobsStatusResponse(
upscaling_tasks=upscaling_tasks,
processing_jobs=processing_jobs,
)
except DispatcherException as de:
raise de
except Exception as e:
logger.error(f"Error retrieving job status: {e}")
raise InternalException(
message="An error occurred while retrieving the job status."
)


@router.websocket(
Expand Down Expand Up @@ -91,8 +114,20 @@ async def ws_jobs_status(

except WebSocketDisconnect:
logger.info("WebSocket disconnected")
except DispatcherException as ae:
logger.error(f"Dispatcher exception detected: {ae.message}")
await websocket.send_json(
WSStatusMessage(type="error", message=ae.message).model_dump()
)
await websocket.close(code=1011, reason=ae.error_code)
except Exception as e:
logger.exception(f"Error in jobs_status_ws: {e}")
await websocket.close(code=1011, reason="Error in job status websocket: {e}")
logger.error(f"Unexpected error occurred during websocket : {e}")
await websocket.send_json(
WSStatusMessage(
type="error",
message="An error occurred while monitoring the job status.",
).model_dump()
)
await websocket.close(code=1011, reason="INTERNAL_ERROR")
finally:
db.close()
Loading