Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies = [
"jinja2>=3.1",
"py-key-value-aio[disk]",
"pyjwt>=2.12.1",
"argon2-cffi>=25.1.0",
"base58>=2.1.1",
"posthog>=3.0",
]
Expand Down
4 changes: 2 additions & 2 deletions src/authsome/auth/flows/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class ApiKeyFlow(AuthFlow):
async def begin(
self,
provider: ProviderDefinition,
identity: str,
identity: str | None,
connection_name: str,
runtime_session: AuthSession,
scopes: list[str] | None = None,
Expand All @@ -39,7 +39,7 @@ async def begin(
async def resume(
self,
provider: ProviderDefinition,
identity: str,
identity: str | None,
connection_name: str,
runtime_session: AuthSession,
callback_data: dict[str, Any],
Expand Down
4 changes: 2 additions & 2 deletions src/authsome/auth/flows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class AuthFlow(ABC):
async def begin(
self,
provider: ProviderDefinition,
identity: str,
identity: str | None,
connection_name: str,
runtime_session: AuthSession,
scopes: list[str] | None = None,
Expand All @@ -62,7 +62,7 @@ async def begin(
async def resume(
self,
provider: ProviderDefinition,
identity: str,
identity: str | None,
connection_name: str,
runtime_session: AuthSession,
callback_data: dict[str, Any],
Expand Down
4 changes: 2 additions & 2 deletions src/authsome/auth/flows/dcr_pkce.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class DcrPkceFlow(AuthFlow):
async def begin(
self,
provider: ProviderDefinition,
identity: str,
identity: str | None,
connection_name: str,
runtime_session: AuthSession,
scopes: list[str] | None = None,
Expand Down Expand Up @@ -81,7 +81,7 @@ async def begin(
async def resume(
self,
provider: ProviderDefinition,
identity: str,
identity: str | None,
connection_name: str,
runtime_session: AuthSession,
callback_data: dict[str, Any],
Expand Down
4 changes: 2 additions & 2 deletions src/authsome/auth/flows/device_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class DeviceCodeFlow(AuthFlow):
async def begin(
self,
provider: ProviderDefinition,
identity: str,
identity: str | None,
connection_name: str,
runtime_session: AuthSession,
scopes: list[str] | None = None,
Expand Down Expand Up @@ -74,7 +74,7 @@ async def begin(
async def resume(
self,
provider: ProviderDefinition,
identity: str,
identity: str | None,
connection_name: str,
runtime_session: AuthSession,
callback_data: dict[str, Any],
Expand Down
4 changes: 2 additions & 2 deletions src/authsome/auth/flows/pkce.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class PkceFlow(AuthFlow):
async def begin(
self,
provider: ProviderDefinition,
identity: str,
identity: str | None,
connection_name: str,
runtime_session: AuthSession,
scopes: list[str] | None = None,
Expand Down Expand Up @@ -72,7 +72,7 @@ async def begin(
async def resume(
self,
provider: ProviderDefinition,
identity: str,
identity: str | None,
connection_name: str,
runtime_session: AuthSession,
callback_data: dict[str, Any],
Expand Down
6 changes: 3 additions & 3 deletions src/authsome/auth/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ConnectionRecord(BaseModel):

schema_version: int = 2
provider: str
identity: str
identity: str | None = None
principal_id: str | None = None
vault_id: str | None = None
connection_name: str
Expand Down Expand Up @@ -74,7 +74,7 @@ class ProviderMetadataRecord(BaseModel):
"""

schema_version: int = 2
identity: str
identity: str | None = None
principal_id: str | None = None
vault_id: str | None = None
provider: str
Expand All @@ -95,7 +95,7 @@ class ProviderStateRecord(BaseModel):

schema_version: int = 2
provider: str
identity: str
identity: str | None = None
principal_id: str | None = None
vault_id: str | None = None
last_refresh_at: datetime | None = None
Expand Down
4 changes: 2 additions & 2 deletions src/authsome/auth/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class AuthSession(BaseModel):

session_id: str
provider: str
identity: str
identity: str | None = None
principal_id: str | None = None
connection_name: str
flow_type: str
Expand Down Expand Up @@ -61,7 +61,7 @@ async def create(
self,
*,
provider: str,
identity: str,
identity: str | None,
principal_id: str | None,
connection_name: str,
flow_type: str,
Expand Down
1 change: 1 addition & 0 deletions src/authsome/identity/principal.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class PrincipalRecord(BaseModel):

principal_id: str
email: str
password_hash: str | None = None
created_at: datetime = Field(default_factory=utc_now)
updated_at: datetime = Field(default_factory=utc_now)

Expand Down
9 changes: 9 additions & 0 deletions src/authsome/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
from authsome.server.analytics import init_posthog, shutdown_posthog
from authsome.server.dependencies import (
create_app_store,
create_hosted_account_service,
create_identity_bootstrap_service,
create_identity_claim_registry,
create_ownership_resolver,
create_principal_vault_binding_registry,
create_vault,
create_vault_registry,
get_identity_registry_path,
Expand All @@ -34,6 +36,7 @@
from authsome.server.routes.identities import router as identities_router
from authsome.server.routes.providers import router as providers_router
from authsome.server.routes.proxy import router as proxy_router
from authsome.server.routes.ui import UiAuthRequiredError
from authsome.server.routes.ui import router as ui_router
from authsome.server.ui_sessions import UiSessionStore

Expand All @@ -51,6 +54,8 @@ async def lifespan(app: FastAPI):
app.state.identity_registry = IdentityRegistry(get_identity_registry_path(app.state.store.home))
app.state.vault_registry = create_vault_registry(app.state.store.home)
app.state.identity_claim_registry = create_identity_claim_registry(app.state.store.home)
app.state.principal_vault_binding_registry = create_principal_vault_binding_registry(app.state.store.home)
app.state.hosted_account_service = create_hosted_account_service(app.state.store.home)
app.state.server_base_url = get_server_base_url()
init_posthog()
app.state.identity_bootstrap = create_identity_bootstrap_service(
Expand Down Expand Up @@ -94,6 +99,10 @@ def authsome_error_handler(request: Request, exc: AuthsomeError) -> JSONResponse
def identity_registration_error_handler(request: Request, exc: IdentityRegistrationError) -> JSONResponse:
return JSONResponse(status_code=409, content={"error": "IdentityRegistrationError", "message": str(exc)})

@app.exception_handler(UiAuthRequiredError)
def ui_auth_required_handler(request: Request, exc: UiAuthRequiredError):
return exc.response

app.include_router(health_router)
app.include_router(identities_router)
app.include_router(auth_router)
Expand Down
96 changes: 83 additions & 13 deletions src/authsome/server/credential_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,11 @@ class AuthService:
def __init__(
self,
vault: Vault,
identity: str,
identity: str | None = None,
principal_id: str | None = None,
vault_id: str | None = None,
deployment_mode: str = "local",
) -> None:
if not identity:
raise ValueError("AuthService requires an explicit identity handle")
self._vault = vault
self._identity = identity
self._principal_id = principal_id
Expand All @@ -104,7 +102,13 @@ def vault(self) -> Vault:
return self._vault

@property
def identity(self) -> str:
def identity(self) -> str | None:
return self._identity

def require_identity(self) -> str:
"""Return the PoP-authenticated identity handle for identity-scoped routes."""
if self._identity is None:
raise ValueError("AuthService identity is required for this operation")
return self._identity

@property
Expand Down Expand Up @@ -308,7 +312,11 @@ async def get_connection(
key = build_store_key(vault=self._vault_id, provider=provider, record_type="connection", connection=connection)
record_json = await self._vault.get(key, collection=self._coll)
if not record_json:
raise ConnectionNotFoundError(provider=provider, connection=connection, identity=self._identity)
raise ConnectionNotFoundError(
provider=provider,
connection=connection,
identity=self._identity or self._principal_id or "hosted-ui",
)
record = self._load_connection_record(record_json, key)
if record is None:
raise AuthsomeError(
Expand Down Expand Up @@ -336,6 +344,58 @@ async def get_provider_client(self, provider: str) -> ProviderClientRecord | Non
"""
return await self._get_provider_client_credentials(provider)

async def update_provider_configuration(
self,
provider: str,
inputs: dict[str, str],
*,
vault_ids: list[str] | None = None,
) -> bool:
"""Replace stored provider credentials and revoke connections when they change."""
self._ensure_provider_client_mutation_allowed(provider)
definition = await self.get_provider(provider)
if definition.auth_type != AuthType.OAUTH2:
return False

existing = await self._get_provider_client_credentials(provider)
updated = ProviderClientRecord(provider=provider)
updated.client_id = inputs.get("client_id", existing.client_id if existing else None) or None

if "client_secret" in inputs:
secret_input = inputs["client_secret"].strip()
if secret_input:
updated.client_secret = secret_input
else:
updated.client_secret = existing.client_secret if existing else None
else:
updated.client_secret = existing.client_secret if existing else None

if definition.oauth and definition.oauth.base_url:
updated.base_url = inputs.get("base_url", existing.base_url if existing else None) or None
updated.api_url = inputs.get("api_url", existing.api_url if existing else None) or None
else:
updated.base_url = existing.base_url if existing else None
updated.api_url = existing.api_url if existing else None

updated.scopes = list(existing.scopes) if existing and existing.scopes is not None else None
updated.metadata = dict(existing.metadata) if existing else {}

changed = existing is None or any(
(
existing.client_id != updated.client_id,
existing.client_secret != updated.client_secret,
existing.base_url != updated.base_url,
existing.api_url != updated.api_url,
)
)
if not changed:
return False

if existing is not None:
await self.revoke(provider, vault_ids=vault_ids)
await self._save_provider_client_credentials(updated)
return True

async def set_default_connection(self, provider: str, connection: str) -> None:
"""Set the default connection for a provider."""
await self.get_connection(provider, connection)
Expand Down Expand Up @@ -371,41 +431,51 @@ async def get_required_inputs(
definition = await self.get_provider(provider)
flow_type = FlowType(session.flow_type)
client_record = await self._get_provider_client_credentials(provider)
provider_config_only = bool(session.payload.get("provider_config_only"))

flow_base_url = base_url or (client_record.base_url if client_record else None)
flow_client_id = client_record.client_id if client_record else None
persisted_scopes = client_record.scopes if client_record else None

fields: list[InputField] = []

if definition.oauth and definition.oauth.base_url and not flow_base_url:
if definition.oauth and definition.oauth.base_url and (provider_config_only or not flow_base_url):
fields.append(
InputField(
name="base_url",
label="Base URL",
secret=False,
default=definition.oauth.base_url,
default=flow_base_url or definition.oauth.base_url,
)
)
fields.append(
InputField(
name="api_url",
label="API Host URL",
secret=False,
default=definition.api_url or "",
default=(
client_record.api_url if client_record and client_record.api_url else definition.api_url or ""
),
)
)

if flow_type == FlowType.PKCE and not flow_client_id:
fields.append(InputField(name="client_id", label="Client ID", secret=False))
if flow_type == FlowType.PKCE and (provider_config_only or not flow_client_id):
fields.append(
InputField(
name="client_id",
label="Client ID",
secret=False,
default=flow_client_id or "",
)
)
fields.append(InputField(name="client_secret", label="Client Secret (Optional)", secret=True, default=""))
elif flow_type == FlowType.DEVICE_CODE and not flow_client_id:
elif flow_type == FlowType.DEVICE_CODE and (provider_config_only or not flow_client_id):
fields.append(
InputField(
name="client_id",
label="Client ID (leave blank for public device code flow)",
secret=False,
default="",
default=flow_client_id or "",
)
)
fields.append(InputField(name="client_secret", label="Client Secret (Optional)", secret=True, default=""))
Expand Down Expand Up @@ -795,7 +865,7 @@ def _disambiguate_export_name(
async def get_identity(self, name: str) -> str:
if name != self._identity:
raise IdentityNotFoundError(name)
return self._identity
return self.require_identity()

# ── Internal helpers ──────────────────────────────────────────────────

Expand Down
11 changes: 11 additions & 0 deletions src/authsome/server/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from authsome.paths import get_authsome_home as _get_authsome_home
from authsome.paths import get_server_home as _get_server_home
from authsome.paths import get_server_log_path as _get_server_log_path
from authsome.server.hosted_auth import HostedAccountService
from authsome.server.identity_bootstrap import (
HostedIdentityBootstrapService,
IdentityBootstrapService,
Expand Down Expand Up @@ -189,6 +190,16 @@ def create_principal_vault_binding_registry(home: Path | None = None) -> Princip
return PrincipalVaultBindingRegistry(get_principal_vault_binding_registry_path(home))


def create_hosted_account_service(home: Path | None = None) -> HostedAccountService:
resolved_home = home or get_authsome_home()
return HostedAccountService(
principals=create_principal_registry(resolved_home),
vaults=create_vault_registry(resolved_home),
bindings=create_principal_vault_binding_registry(resolved_home),
jwt_secret=load_ui_session_signing_secret(resolved_home),
)


def create_ownership_resolver(home: Path | None = None) -> OwnershipResolver:
resolved_home = home or get_authsome_home()
principals = create_principal_registry(resolved_home)
Expand Down
Loading
Loading