Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions backend/account_v2/authentication_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from django.shortcuts import redirect
from logs_helper.log_service import LogService
from rest_framework import status
from rest_framework.exceptions import APIException
from rest_framework.request import Request
from rest_framework.response import Response
from tenant_account_v2.models import OrganizationMember as OrganizationMember
Expand Down Expand Up @@ -100,6 +101,11 @@ def authorization_callback(
return self.auth_service.handle_authorization_callback(
request=request, backend=backend
)
except APIException:
# Surface DRF exceptions (e.g., AmbiguousUserException → 409) so
# the actionable detail reaches the caller instead of being lost
# behind a generic /error redirect.
raise
except Exception:
logger.exception("Error while handling authorization callback")
Comment thread
jaseemjaskp marked this conversation as resolved.
return redirect("/error")
Expand Down Expand Up @@ -430,7 +436,8 @@ def add_user_role(
self.save_organization_user_role(
user_uid=user.user.id, role=current_roles[0]
)
return current_roles[0]
return current_roles[0]
return None
else:
return None

Expand All @@ -449,7 +456,8 @@ def remove_user_role(
user_uid=organization_member.user.id,
role=current_roles[0],
)
return current_roles[0]
return current_roles[0]
return None
else:
return None

Expand Down
8 changes: 8 additions & 0 deletions backend/account_v2/custom_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,11 @@ class Forbidden(APIException):
class UserAlreadyAssociatedException(APIException):
status_code = 400
default_detail = "User is already associated with one organization."


class AmbiguousUserException(APIException):
status_code = 409
default_detail = (
"Multiple user records match this lookup. "
"Contact your administrator to resolve the duplicate."
)
1 change: 0 additions & 1 deletion backend/account_v2/tests.py

This file was deleted.

Empty file.
122 changes: 122 additions & 0 deletions backend/account_v2/tests/test_user_filter_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""Regression tests for ``UserFilterRegistry``.

These tests guard the plumbing that identity-scoping plugins use to scope
User / OrganizationMember querysets. A future refactor that breaks the
registry contract (e.g., loses dedupe-on-register, swallows plugin
exceptions, mishandles unregister) would silently let cross-environment
identities leak — the very failure mode the registry exists to prevent.

No Django app registry / DB is required: the registry only depends on
``django.db.models.QuerySet`` as a static type hint, and the tests use
``unittest.mock`` to fake the queryset chain.
"""

from __future__ import annotations

import logging
import unittest
from unittest.mock import MagicMock

from account_v2.user_filter_registry import UserFilterRegistry


class UserFilterRegistryTests(unittest.TestCase):
def setUp(self) -> None:
# Class-level state — clear before every test to avoid bleed.
UserFilterRegistry.clear()
self.addCleanup(UserFilterRegistry.clear)

def test_apply_with_no_filters_is_identity(self) -> None:
qs = MagicMock(name="queryset")
self.assertIs(UserFilterRegistry.apply(qs, "user"), qs)

def test_register_appends_filter(self) -> None:
def fn(qs, kind): # noqa: ANN001 - inline test helper
return qs

UserFilterRegistry.register(fn)
self.assertIn(fn, UserFilterRegistry._filters)

def test_register_dedupes(self) -> None:
def fn(qs, kind): # noqa: ANN001
return qs

UserFilterRegistry.register(fn)
UserFilterRegistry.register(fn)
self.assertEqual(UserFilterRegistry._filters.count(fn), 1)

def test_unregister_removes_filter(self) -> None:
def fn(qs, kind): # noqa: ANN001
return qs

UserFilterRegistry.register(fn)
UserFilterRegistry.unregister(fn)
self.assertNotIn(fn, UserFilterRegistry._filters)

def test_unregister_unknown_is_noop(self) -> None:
def fn(qs, kind): # noqa: ANN001
return qs

# Must not raise.
UserFilterRegistry.unregister(fn)

def test_clear_empties_registry(self) -> None:
UserFilterRegistry.register(lambda qs, kind: qs)
UserFilterRegistry.register(lambda qs, kind: qs)
UserFilterRegistry.clear()
self.assertEqual(UserFilterRegistry._filters, [])

def test_apply_runs_filters_in_registration_order(self) -> None:
order: list[str] = []

def first(qs, kind): # noqa: ANN001
order.append("first")
return qs

def second(qs, kind): # noqa: ANN001
order.append("second")
return qs

UserFilterRegistry.register(first)
UserFilterRegistry.register(second)
UserFilterRegistry.apply(MagicMock(), "user")
self.assertEqual(order, ["first", "second"])

def test_apply_threads_filtered_queryset_through_chain(self) -> None:
qs0 = MagicMock(name="qs0")
qs1 = MagicMock(name="qs1")
qs2 = MagicMock(name="qs2")

def fn_a(qs, kind): # noqa: ANN001
self.assertIs(qs, qs0)
return qs1

def fn_b(qs, kind): # noqa: ANN001
self.assertIs(qs, qs1)
return qs2

UserFilterRegistry.register(fn_a)
UserFilterRegistry.register(fn_b)
self.assertIs(UserFilterRegistry.apply(qs0, "user"), qs2)

def test_apply_reraises_plugin_exceptions_with_attribution_log(self) -> None:
# Fail-closed semantics: a broken plugin must not silently let
# un-scoped users leak into a downstream query. The exception must
# propagate AND the offending fn must be identifiable in the log.
def broken(qs, kind): # noqa: ANN001
raise RuntimeError("simulated plugin bug")

UserFilterRegistry.register(broken)
with self.assertLogs("account_v2.user_filter_registry", level="ERROR") as cm:
with self.assertRaises(RuntimeError):
UserFilterRegistry.apply(MagicMock(), "user")
joined = "\n".join(cm.output)
self.assertIn("user_filter plugin raised", joined)
# Plugin attribution: the failing fn's repr should appear in the log
# so an operator knows which plugin to investigate.
self.assertIn("broken", joined)


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()
68 changes: 53 additions & 15 deletions backend/account_v2/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@

from django.db import IntegrityError

from account_v2.custom_exceptions import AmbiguousUserException
from account_v2.models import User
from account_v2.user_filter_registry import UserFilterRegistry

Logger = logging.getLogger(__name__)

# Cap on the number of matched row PKs included in ambiguity logs to keep
# a misconfigured filter from turning the error path into a full table scan.
AMBIGUITY_LOG_LIMIT = 50


class UserService:
def __init__(
Expand Down Expand Up @@ -41,28 +47,60 @@ def update_user(self, user: User, user_id: str) -> User:
return user

def get_user_by_email(self, email: str) -> User | None:
try:
user: User = User.objects.get(email=email, auth_provider="")
return user
except User.DoesNotExist:
return None
return _resolve_unique(
User.objects.filter(email=email, auth_provider=""),
"user",
("email", email),
)

def get_user_by_user_id(self, user_id: str) -> Any:
try:
return User.objects.get(user_id=user_id)
except User.DoesNotExist:
return None
return _resolve_unique(
User.objects.filter(user_id=user_id),
"user",
("user_id", user_id),
)

def get_user_by_id(self, id: str) -> Any:
"""Retrieve a user by their ID, taking into account the schema context.

Args:
id (str): The ID of the user.
"""Retrieve a user by primary key.

Returns:
Any: The user object if found, or None if not found.
PK lookups are always unique and bypass the filter registry so
identity-scoping filters cannot hide a row whose PK is already known
(e.g., the currently authenticated admin's own row).
"""
try:
return User.objects.get(id=id)
except User.DoesNotExist:
return None


def _resolve_unique(
qs: Any,
kind: str,
lookup: tuple[str, Any],
) -> User | None:
"""Apply the user filter registry and resolve to a single row.

Raises ``AmbiguousUserException`` if more than one row matches after
filters apply — that signals either duplicate User rows or a
misconfigured identity-scoping filter, and silently picking one would
propagate the wrong identity downstream.
"""
qs = UserFilterRegistry.apply(qs, kind)
rows = list(qs[:2])
if len(rows) > 1:
# Log the matched row PKs (internal IDs, not PII) instead of the
# raw lookup value so ambiguity remains diagnosable from logs
# without expanding PII retention. Cap at AMBIGUITY_LOG_LIMIT so a
# misconfigured filter matching thousands of rows doesn't turn the
# error path into a full table scan.
pks = list(qs.values_list("pk", flat=True)[:AMBIGUITY_LOG_LIMIT])
field, _ = lookup
Logger.error(
"Ambiguous User lookup by %s (matched ≥%d rows; first %d pks=%s)",
field,
len(pks),
len(pks),
pks,
)
raise AmbiguousUserException()
Comment thread
jaseemjaskp marked this conversation as resolved.
return rows[0] if rows else None
63 changes: 63 additions & 0 deletions backend/account_v2/user_filter_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Pluggable filters for User / OrganizationMember queries.

Identity plugins can register a callable that scopes querysets to a
subset of users — for example, limiting visibility to users whose
external identity belongs to the current environment. The service
layers in `account_v2.user.UserService` and
`tenant_account_v2.organization_member_service.OrganizationMemberService`
call ``UserFilterRegistry.apply`` on each user lookup so registered
filters take effect uniformly without core having to know which plugin
is loaded.

When no filters are registered the registry is a no-op, so OSS and
development setups are unaffected.
Comment thread
jaseemjaskp marked this conversation as resolved.
"""

import logging
from collections.abc import Callable
from typing import ClassVar, Literal

from django.db.models import QuerySet

logger = logging.getLogger(__name__)

# "user" filters operate on `account_v2.User` querysets and should
# reference `user_id`. "org_member" filters operate on
# `tenant_account_v2.OrganizationMember` querysets and should reference
# `user__user_id`.
FilterKind = Literal["user", "org_member"]
Comment thread
jaseemjaskp marked this conversation as resolved.

FilterFn = Callable[[QuerySet, FilterKind], QuerySet]


class UserFilterRegistry:
_filters: ClassVar[list[FilterFn]] = []
Comment thread
jaseemjaskp marked this conversation as resolved.

@classmethod
def register(cls, fn: FilterFn) -> None:
if fn not in cls._filters:
cls._filters.append(fn)
Comment thread
greptile-apps[bot] marked this conversation as resolved.

@classmethod
def unregister(cls, fn: FilterFn) -> None:
if fn in cls._filters:
cls._filters.remove(fn)

@classmethod
def clear(cls) -> None:
"""Remove all registered filters. Intended for tests only."""
cls._filters.clear()

@classmethod
def apply(cls, qs: QuerySet, kind: FilterKind) -> QuerySet:
for fn in cls._filters:
try:
qs = fn(qs, kind)
except Exception:
logger.exception(
"user_filter plugin raised; aborting lookup (fn=%r kind=%s)",
fn,
kind,
)
raise
return qs
Loading