Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 28 additions & 188 deletions src/api/_util/resourcelimit.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
from ...models.organization import Organization
from ...models.project import Project
from ...models.resources import (
BranchAllocationPublic,
BranchProvisioning,
EntityType,
OrganizationLimitDefault,
ProvisioningLog,
Expand All @@ -33,108 +31,38 @@

async def delete_branch_provisioning(session: SessionDep, branch_id: Identifier, *, commit: bool = True):
await session.execute(delete(ResourceUsageMinute).where(col(ResourceUsageMinute.branch_id) == branch_id))
await session.execute(delete(BranchProvisioning).where(col(BranchProvisioning.branch_id) == branch_id))
await session.execute(delete(ProvisioningLog).where(col(ProvisioningLog.branch_id) == branch_id))

if commit:
await session.commit()


async def get_current_branch_allocations(session: SessionDep, branch: Branch) -> BranchAllocationPublic:
result = await session.execute(select(BranchProvisioning).where(BranchProvisioning.branch_id == branch.id))
allocations = list(result.scalars().all())

return BranchAllocationPublic(
branch_id=branch.id,
milli_vcpu=_select_allocation(ResourceType.milli_vcpu, allocations),
ram=_select_allocation(ResourceType.ram, allocations),
iops=_select_allocation(ResourceType.iops, allocations),
database_size=_select_allocation(ResourceType.database_size, allocations),
storage_size=_select_allocation(ResourceType.storage_size, allocations),
)


async def audit_new_branch_resource_provisioning(
session: SessionDep,
branch: Branch,
resource_type: ResourceType,
amount: int,
action: str,
reason: str | None = None,
):
timestamp = datetime.now(UTC)
new_log = ProvisioningLog(
branch_id=branch.id,
resource=resource_type,
amount=amount,
action=action,
reason=reason,
ts=timestamp,
)
await session.merge(new_log)


async def create_or_update_branch_provisioning(
async def apply_branch_resource_allocation(
session: SessionDep,
branch: Branch,
resource_requests: ResourceLimitsPublic,
*,
commit: bool = True,
) -> None:
field_map = {
ResourceType.milli_vcpu: "milli_vcpu",
ResourceType.ram: "memory",
ResourceType.iops: "iops",
ResourceType.database_size: "database_size",
ResourceType.storage_size: "storage_size",
}
requests = resource_limits_to_dict(resource_requests)
for resource_type, amount in requests.items():
if amount is None:
continue
setattr(branch, field_map[resource_type], int(amount))

# Create or update allocation
result = await session.execute(
select(BranchProvisioning).where(
BranchProvisioning.branch_id == branch.id, BranchProvisioning.resource == resource_type
)
)
allocation = result.scalars().first()

new_allocation = allocation is None
if allocation is None:
allocation = BranchProvisioning(
branch_id=branch.id,
resource=resource_type,
amount=amount,
updated_at=datetime.now(UTC),
)
else:
allocation.amount = int(amount or 0) # else won't happen since it's checked above
allocation.updated_at = datetime.now(UTC)
await session.merge(allocation)

# Create audit log entry
await audit_new_branch_resource_provisioning(
session, branch, resource_type, amount, "create" if new_allocation else "update"
)

await session.merge(branch)
if commit:
await session.commit()
await session.refresh(branch)


async def clone_branch_provisioning(session: SessionDep, source: Branch, target: Branch):
result = await session.execute(select(BranchProvisioning).where(BranchProvisioning.branch_id == source.id))
provisions = result.scalars().all()

with session.no_autoflush:
for provision in provisions:
await session.merge(
BranchProvisioning(
branch_id=target.id,
resource=provision.resource,
amount=provision.amount,
updated_at=datetime.now(UTC),
)
)
await session.commit()
await session.refresh(target)


def dict_to_resource_limits(value: Mapping[ResourceType, int | None]) -> ResourceLimitsPublic:
return ResourceLimitsPublic(
milli_vcpu=value.get(ResourceType.milli_vcpu),
Expand Down Expand Up @@ -195,15 +123,6 @@ async def get_project_resource_usage(
return _map_resource_usages(list(usages))


async def check_resource_limits(
session: SessionDep, branch: Branch, provisioning_request: ResourceLimitsPublic
) -> tuple[list[ResourceType], ResourceLimitsPublic]:
project = await branch.awaitable_attrs.project
project_id = branch.project_id
organization_id = project.organization_id
return await check_available_resources_limits(session, organization_id, project_id, provisioning_request)


async def check_available_resources_limits(
session: SessionDep,
organization_id: Identifier,
Expand Down Expand Up @@ -257,24 +176,6 @@ def format_limit_violation_details(
return "; ".join(details)


# FIXME: @Chris This call should return the limits on the branch which is only meaningful for resizing, however in this
# case, the calculation is wrong, since it includes it's own allocations. Fixing this requires a change on the
# frontend,hence it's pushed for now.
async def get_effective_branch_limits(session: SessionDep, branch: Branch) -> ResourceLimitsPublic:
organization_id = (await branch.awaitable_attrs.project).organization_id
return await get_remaining_project_resources(session, organization_id, branch.project_id)


async def get_effective_branch_creation_limits(session: SessionDep, project: Project) -> ResourceLimitsPublic:
return await get_remaining_project_resources(session, project.organization_id, project.id)


async def get_effective_project_creation_limits(
session: SessionDep, organization: Organization
) -> ResourceLimitsPublic:
return await get_remaining_organization_resources(session, organization.id)


async def get_remaining_organization_resources(
session: SessionDep, organization_id: Identifier, *, exclude_branch_ids: Sequence[Identifier] | None = None
) -> ResourceLimitsPublic:
Expand Down Expand Up @@ -429,17 +330,17 @@ async def get_current_organization_allocations(
*,
exclude_branch_ids: Sequence[Identifier] | None = None,
) -> dict[ResourceType, int]:
result = await session.execute(
select(BranchProvisioning).join(Branch).join(Project).where(Project.organization_id == organization_id)
)
rows = list(result.scalars().all())
statement = _allocations().join(Project).where(Project.organization_id == organization_id)
if exclude_branch_ids:
excluded = set(exclude_branch_ids)
rows = [row for row in rows if row.branch_id not in excluded]

grouped = _group_by_resource_type(rows)
branch_statuses = await _collect_branch_statuses(session, rows)
return _aggregate_group_by_resource_type(grouped, branch_statuses)
statement = statement.where(col(Branch.id).notin_(exclude_branch_ids))
row = (await session.exec(statement)).one()
return {
ResourceType.milli_vcpu: row.milli_vcpu,
ResourceType.ram: row.ram,
ResourceType.iops: row.iops,
ResourceType.database_size: row.database_size,
ResourceType.storage_size: row.storage_size,
}


async def get_current_project_allocations(
Expand All @@ -448,87 +349,26 @@ async def get_current_project_allocations(
*,
exclude_branch_ids: Sequence[Identifier] | None = None,
) -> dict[ResourceType, int]:
result = await session.execute(select(BranchProvisioning).join(Branch).where(Branch.project_id == project_id))
rows = list(result.scalars().all())
statement = _allocations().where(Branch.project_id == project_id)
if exclude_branch_ids:
excluded = set(exclude_branch_ids)
rows = [row for row in rows if row.branch_id not in excluded]

grouped = _group_by_resource_type(rows)
branch_statuses = await _collect_branch_statuses(session, rows)
return _aggregate_group_by_resource_type(grouped, branch_statuses)


def _aggregate_group_by_resource_type(
grouped: dict[ResourceType, list[BranchProvisioning]], branch_statuses: dict[Identifier, BranchServiceStatus]
) -> dict[ResourceType, int]:
statement = statement.where(col(Branch.id).notin_(exclude_branch_ids))
row = (await session.exec(statement)).one()
return {
resource_type: sum(
allocation.amount
for allocation in allocations
if (allocation.branch_id is not None)
and (
branch_statuses.get(allocation.branch_id)
not in {BranchServiceStatus.STOPPED, BranchServiceStatus.DELETING}
)
)
for resource_type, allocations in grouped.items()
ResourceType.milli_vcpu: row.milli_vcpu,
ResourceType.ram: row.ram,
ResourceType.iops: row.iops,
ResourceType.database_size: row.database_size,
ResourceType.storage_size: row.storage_size,
}


async def _collect_branch_statuses(
_session: SessionDep, rows: list[BranchProvisioning]
) -> dict[Identifier, BranchServiceStatus]:
branch_ids = {row.branch_id for row in rows if row.branch_id is not None}
if not branch_ids:
return {}

from ..organization.project import branch as branch_module

statuses: dict[Identifier, BranchServiceStatus] = {}
for branch_id in branch_ids:
statuses[branch_id] = await branch_module.refresh_branch_status(branch_id)
return statuses


def _group_by_resource_type(allocations: list[BranchProvisioning]) -> dict[ResourceType, list[BranchProvisioning]]:
result: dict[ResourceType, list[BranchProvisioning]] = {}
for allocation in allocations:
result.setdefault(allocation.resource, []).append(allocation)
return result


def _map_resource_allocation(provisioning_list: list[BranchProvisioning]) -> dict[ResourceType, int]:
result: dict[ResourceType, int] = {}
for resource_type in ResourceType:
result[resource_type] = _select_resource_allocation_or_zero(resource_type, provisioning_list)
return result


def _select_resource_allocation_or_zero(resource_type: ResourceType, allocations: list[BranchProvisioning]):
value: int | None = None
for allocation in allocations:
if allocation.resource == resource_type:
if value is not None:
raise ValueError(f"Multiple allocations entries for resource type {resource_type.name}")
value = allocation.amount
return value if value is not None else 0


def _map_resource_usages(usages: list[ResourceUsageMinute]) -> dict[ResourceType, int]:
result: dict[ResourceType, int] = {}
for usage in usages:
result[usage.resource] = result.get(usage.resource, 0) + usage.amount
return result


def _select_allocation(resource_type: ResourceType, allocations: list[BranchProvisioning]):
for allocation in allocations:
if allocation.resource == resource_type:
return allocation.amount
return None


def _optional_min(a: int | None, b: int | None) -> int | None:
values = [x for x in (a, b) if x is not None]
return min(values) if values else None
Expand Down
42 changes: 11 additions & 31 deletions src/api/organization/project/branch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,14 @@
DatabaseInformation,
PgbouncerConfig,
)
from .....models.resources import BranchAllocationPublic, ResourceLimitsPublic, ResourceType
from .....models.resources import ResourceLimitsPublic, ResourceType
from ...._util import Conflict, Forbidden, NotFound, Unauthenticated, url_path_for
from ...._util.backups import copy_branch_backup_schedules, delete_branch_backups, ensure_branch_pitr_schedule
from ...._util.resourcelimit import (
apply_branch_resource_allocation,
check_available_resources_limits,
create_or_update_branch_provisioning,
delete_branch_provisioning,
format_limit_violation_details,
get_current_branch_allocations,
)
from ...._util.role import clone_user_role_assignment
from ....auth import security
Expand Down Expand Up @@ -433,25 +432,6 @@ def _normalize_size_from_source(value: int | None) -> int | None:
return storage_backend_bytes_to_db_bytes(value)


def _base_deployment_resources(
source: Branch,
source_limits: BranchAllocationPublic | None,
) -> _DeploymentResourceValues:
def _value_from_limits(attribute: str, fallback: Any) -> Any:
if source_limits is None:
return fallback
limit_value = getattr(source_limits, attribute)
return fallback if limit_value is None else limit_value

return {
"database_size": _normalize_size_from_source(_value_from_limits("database_size", source.database_size)),
"storage_size": _normalize_size_from_source(_value_from_limits("storage_size", source.storage_size)),
"milli_vcpu": _value_from_limits("milli_vcpu", source.milli_vcpu),
"memory_bytes": _value_from_limits("ram", source.memory),
"iops": _value_from_limits("iops", source.iops),
}


def _apply_overrides_to_resources(
base_values: _DeploymentResourceValues,
*,
Expand Down Expand Up @@ -531,10 +511,15 @@ def _validate_deployment_requirements(
def _deployment_parameters_from_source(
source: Branch,
*,
source_limits: BranchAllocationPublic | None = None,
overrides: BranchSourceDeploymentParameters | None = None,
) -> DeploymentParameters:
resource_values = _base_deployment_resources(source, source_limits)
resource_values: _DeploymentResourceValues = {
"database_size": _normalize_size_from_source(source.database_size),
"storage_size": _normalize_size_from_source(source.storage_size),
"milli_vcpu": source.milli_vcpu,
"memory_bytes": source.memory,
"iops": source.iops,
}
enable_file_storage = source.enable_file_storage

resource_values, enable_file_storage = _apply_overrides_to_resources(
Expand Down Expand Up @@ -1402,12 +1387,7 @@ async def create( # noqa: C901
# Clone sizing is validated against the source volume size, not source branch allocation.
overrides_for_clone = source_overrides.model_copy(update={"database_size": None})

source_limits = await get_current_branch_allocations(session, source)
clone_parameters = _deployment_parameters_from_source(
source,
source_limits=source_limits,
overrides=overrides_for_clone,
)
clone_parameters = _deployment_parameters_from_source(source, overrides=overrides_for_clone)
if is_restore:
if backup_entry is None:
raise AssertionError("backup_entry required for restore branch creation")
Expand Down Expand Up @@ -1489,7 +1469,7 @@ async def create( # noqa: C901
await ensure_branch_pitr_schedule(session, entity)

# Configure allocations
await create_or_update_branch_provisioning(session, entity, resource_requests)
await apply_branch_resource_allocation(session, entity, resource_requests)

entity_url = url_path_for(
request,
Expand Down
4 changes: 2 additions & 2 deletions src/api/organization/project/branch/resize_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .....models.branch import Branch, BranchServiceStatus
from .....models.resources import ResourceLimitsPublic
from .....worker import app
from ...._util.resourcelimit import create_or_update_branch_provisioning
from ...._util.resourcelimit import apply_branch_resource_allocation

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -51,7 +51,7 @@ async def _apply_succeeded_fields(session: object, branch: Branch, succeeded: di
continue
setattr(branch, branch_attr, value)
limits_kwarg = _FIELD_TO_RESOURCE_LIMITS_KWARG[field]
await create_or_update_branch_provisioning(
await apply_branch_resource_allocation(
session, # type: ignore[arg-type]
branch,
ResourceLimitsPublic(**{limits_kwarg: value}),
Expand Down
Loading
Loading