Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion src/aleph/sdk/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from typing_extensions import deprecated

from aleph.sdk.conf import settings
from aleph.sdk.types import Account
from aleph.sdk.types import Account, Authorization, SecurityAggregateContent
from aleph.sdk.utils import extended_json_encoder

from ..query.filters import MessageFilter, PostFilter
Expand Down Expand Up @@ -295,6 +295,30 @@ def get_program_price(
"""
raise NotImplementedError("Did you mean to import `AlephHttpClient`?")

async def get_authorizations(self, address: str) -> list[Authorization]:
"""
Retrieves all authorizations for a specific address.
"""
# TODO: update this implementation to use `get_aggregate()` once
# https://github.com/aleph-im/aleph-sdk-python/pull/273 is merged.
# There's currently no way to detect a nonexistent aggregate in generic code just yet.
# fetch_aggregate() throws an implementation-specific ClientResponseError in case of 404.
import aiohttp

try:
security_aggregate_dict = await self.fetch_aggregate(
address=address, key="security"
)
except aiohttp.ClientResponseError as e:
if e.status == 404:
return []
raise

security_aggregate = SecurityAggregateContent.model_validate(
security_aggregate_dict
)
return security_aggregate.authorizations


class AuthenticatedAlephClient(AlephClient):
account: Account
Expand Down Expand Up @@ -617,3 +641,35 @@ async def storage_push(self, content: Mapping) -> str:
:param content: The dict-like content to upload
"""
raise NotImplementedError()

async def update_all_authorizations(self, authorizations: list[Authorization]):
"""
Updates all authorizations for the current account.
Danger! This will replace all authorizations for the account. Use with care.

:param authorizations: List of authorizations to set. These authorizations will replace the existing ones.
"""
security_aggregate = SecurityAggregateContent(authorizations=authorizations)
await self.create_aggregate(
key="security", content=security_aggregate.model_dump()
)

async def add_authorization(self, authorization: Authorization):
"""
Adds a specific authorization for the current account.
"""
authorizations = await self.get_authorizations(self.account.get_address())
authorizations.append(authorization)
await self.update_all_authorizations(authorizations)

async def revoke_all_authorizations(self, address: str):
"""
Revokes all authorizations for a specific address.
"""
authorizations = await self.get_authorizations(self.account.get_address())
filtered_authorizations = [
authorization
for authorization in authorizations
if authorization.address != address
]
await self.update_all_authorizations(filtered_authorizations)
71 changes: 69 additions & 2 deletions src/aleph/sdk/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
Union,
)

from aleph_message.models import ItemHash
from aleph_message.models import ItemHash, MessageType
from pydantic import (
BaseModel,
ConfigDict,
Expand All @@ -24,9 +24,11 @@
TypeAdapter,
field_validator,
)
from typing_extensions import runtime_checkable
from typing_extensions import Self, runtime_checkable

__all__ = (
"Authorization",
"AuthorizationBuilder",
"StorageEnum",
"Account",
"AccountFromPrivateKey",
Expand Down Expand Up @@ -406,3 +408,68 @@ class VmResources(BaseModel):
vcpus: PositiveInt
memory: PositiveInt
disk_mib: PositiveInt


class Authorization(BaseModel):
"""A single authorization entry for delegated access."""

address: str
chain: Optional[Chain] = None
channels: list[str] = []
types: list[MessageType] = []
post_types: list[str] = []
aggregate_keys: list[str] = []


class AuthorizationBuilder:
def __init__(self, address: str):
self._address: str = address
self._chain: Optional[Chain] = None
self._channels: list[str] = []
self._message_types: list[MessageType] = []
self._post_types: list[str] = []
self._aggregate_keys: list[str] = []

def chain(self, chain: Chain) -> Self:
self._chain = chain
return self

def channel(self, channel: str) -> Self:
self._channels.append(channel)
return self

def message_type(self, message_type: MessageType) -> Self:
self._message_types.append(message_type)
return self

def post_type(self, post_type: str) -> Self:
if MessageType.post not in self._message_types:
raise ValueError(
"Cannot set post_type without allowing POST message type first"
)
self._post_types.append(post_type)
return self

def aggregate_key(self, aggregate_key: str) -> Self:
if MessageType.aggregate not in self._message_types:
raise ValueError(
"Cannot set post_type without allowing AGGREGATE message type first"
)
self._aggregate_keys.append(aggregate_key)
return self

def build(self) -> Authorization:
return Authorization(
address=self._address,
chain=self._chain,
channels=self._channels,
types=self._message_types,
post_types=self._post_types,
aggregate_keys=self._aggregate_keys,
)


class SecurityAggregateContent(BaseModel):
"""Content schema for the 'security' aggregate."""

authorizations: list[Authorization] = []
Loading
Loading