Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions api/src/middleware/request_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import base64
import hashlib
import hmac
import json
import logging
import re
from contextvars import ContextVar
from typing import Optional

import requests
from google.auth import jwt
Expand Down Expand Up @@ -55,6 +61,68 @@ def decode_jwt(self, token: str):
logging.error("Error decoding JWT: %s", e)
return None

def decode_user_context_jwt(self, token: str):
    """Decode and verify the custom user-context JWT sent by the web app.

    The token is a compact JWT signed with HS256 (HMAC-SHA256) using the
    shared secret read from the ``S2S_JWT_SECRET`` configuration value.

    Args:
        token: Raw header value, optionally prefixed with ``"Bearer "``.

    Returns:
        The decoded payload dict (always containing ``"uid"``) when the
        signature is valid and the token is not expired; ``None`` otherwise.

    This method never raises: every failure is logged and reported as
    ``None`` so the caller can keep whatever identity it already resolved
    from the IAP / Authorization headers.
    """
    import time  # local import: only needed for the expiry check below

    try:
        secret = get_config("S2S_JWT_SECRET")
        if not secret or len(secret) < 32:
            # Misconfiguration: do not fail the request, just skip user-context.
            logging.error("S2S_JWT_SECRET is missing or too short; cannot verify x-mdb-user-context token.")
            return None

        # Strip the auth scheme only when it really is a leading prefix;
        # a blanket str.replace would also rewrite the token body.
        bearer_prefix = "Bearer "
        if token.startswith(bearer_prefix):
            token = token[len(bearer_prefix):]

        parts = token.split(".")
        if len(parts) != 3:
            return None

        header_b64, payload_b64, signature_b64 = parts
        signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
        expected_sig = hmac.new(secret.encode("utf-8"), signing_input, hashlib.sha256).digest()

        # JWT uses URL-safe base64 without padding, so restore the padding
        # that base64.urlsafe_b64decode requires.
        def b64url_decode(value: str) -> bytes:
            padding = "=" * (-len(value) % 4)
            return base64.urlsafe_b64decode(value + padding)

        actual_sig = b64url_decode(signature_b64)
        # Constant-time comparison to avoid leaking signature bytes via timing.
        if not hmac.compare_digest(expected_sig, actual_sig):
            logging.warning("Invalid signature for x-mdb-user-context token")
            return None

        # Only parse the payload once the signature has been authenticated.
        payload = json.loads(b64url_decode(payload_b64).decode("utf-8"))
        # Minimal shape we care about: { uid, email?, isGuest? }
        if not isinstance(payload, dict) or "uid" not in payload:
            return None

        # Enforce expiry when the issuer supplied one, so a captured token
        # cannot be replayed indefinitely. Tokens without "exp" are still
        # accepted to stay compatible with existing issuers.
        expires_at = payload.get("exp")
        if expires_at is not None:
            # bool is a subclass of int, so exclude it explicitly.
            if isinstance(expires_at, bool) or not isinstance(expires_at, (int, float)):
                logging.warning("Invalid exp claim for x-mdb-user-context token")
                return None
            if time.time() >= expires_at:
                logging.warning("Expired x-mdb-user-context token")
                return None

        return payload
    except Exception as e:  # pragma: no cover - defensive
        logging.error("Error decoding user-context JWT: %s", e)
        return None

@staticmethod
def extract_user_id(raw_user_id: Optional[str]) -> Optional[str]:
"""
Extracts the user ID from the raw user ID string.
- If there is a colon, return the substring after the last colon.
- If there is no colon, return the original raw_user_id.
- If raw_user_id is None, return None.
"""
if raw_user_id is None:
return None

match = re.search(r":([^:]+)$", raw_user_id)
if match:
return match.group(1)
return raw_user_id

def _extract_from_headers(self, headers: dict, scope: Scope) -> None:
self.host = headers.get("host")
self.protocol = headers.get("x-forwarded-proto") if headers.get("x-forwarded-proto") else scope.get("scheme")
Expand Down Expand Up @@ -87,13 +155,29 @@ def _extract_from_headers(self, headers: dict, scope: Scope) -> None:
# auth header is used for local development
self.user_id = headers.get("x-goog-authenticated-user-id")
self.user_email = headers.get("x-goog-authenticated-user-email")
self.is_guest = False
self.google_public_keys = None
if not self.iap_jwt_assertion and headers.get("authorization"):
self.iap_jwt_assertion = self.decode_jwt(headers.get("authorization"))
if self.iap_jwt_assertion:
self.user_id = self.iap_jwt_assertion.get("user_id")
self.user_email = self.iap_jwt_assertion.get("email")

# Optional user-context header set by the web app for server-to-server calls.
# Name is aligned with the frontend's USER_CONTEXT_HEADER.
user_context_header = headers.get("x-mdb-user-context") or headers.get("md-user-context")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[question] where is md-user-context used?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not currently in use. It will be used for server-to-server calls that are not from the UI.

if user_context_header:
user_context = self.decode_user_context_jwt(user_context_header)
if user_context:
# Prefer values from the verified user-context token when present.
self.user_id = user_context.get("uid", self.user_id)
self.user_email = user_context.get("email", self.user_email)
self.is_guest = bool(user_context.get("isGuest"))
# if the user_id is in the format "accounts.google.com:1234567890",
# extract just the numeric ID part for consistency with legacy IAP user_id format
if self.user_id:
self.user_id = RequestContext.extract_user_id(self.user_id)

def __repr__(self) -> str:
# Omitting sensitive data like email and jwt assertion
safe_properties = dict(
Expand Down
15 changes: 15 additions & 0 deletions api/tests/unittest/middleware/test_request_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest
from unittest.mock import MagicMock

import pytest
from starlette.datastructures import Headers

from middleware.request_context import RequestContext, get_request_context, _request_context
Expand Down Expand Up @@ -35,6 +36,7 @@ def test_init_extract_headers(self):
"client_host": "client",
"client_user_agent": "user-agent",
"google_public_keys": None,
"is_guest": False,
"headers": Headers(scope=scope_instance),
"host": "localhost",
"iap_jwt_assertion": "jwt",
Expand All @@ -54,3 +56,16 @@ def test_get_request_context(self):
request_context = RequestContext(MagicMock())
_request_context.set(request_context)
self.assertEqual(request_context, get_request_context())


@pytest.mark.parametrize(
    ("raw_user_id", "expected"),
    [
        (None, None),
        ("plainuserid", "plainuserid"),
        ("accounts.google.com:1234567890", "1234567890"),
        ("prefix:middle:finalpart", "finalpart"),
    ],
)
def test_extract_user_id_parametrized(raw_user_id, expected):
    """extract_user_id returns the text after the last colon, else the input unchanged."""
    extracted = RequestContext.extract_user_id(raw_user_id)
    assert extracted == expected
9 changes: 9 additions & 0 deletions infra/feed-api/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ resource "google_cloud_run_v2_service" "mobility-feed-api" {
name = "PROJECT_ID"
value = data.google_project.project.project_id
}
env {
name = "S2S_JWT_SECRET"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[question] is the secret manually created?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this secret is set directly in the environment.

value_source {
secret_key_ref {
secret = "${upper(var.environment)}_S2S_JWT_SECRET"
version = "latest"
}
}
}
resources {
limits = {
cpu = "1"
Expand Down
Loading