Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/dstack/_internal/core/backends/aws/backend.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import botocore.exceptions

from dstack._internal.core.backends.aws.compute import AWSCompute
Expand All @@ -11,9 +13,12 @@ class AWSBackend(Backend):
TYPE = BackendType.AWS
COMPUTE_CLASS = AWSCompute

def __init__(self, config: AWSConfig, compute: Optional[AWSCompute] = None):
    """Initialize the AWS backend.

    Args:
        config: Backend configuration (credentials, regions, etc.).
        compute: Optional pre-built compute to use instead of constructing
            a fresh one — lets callers inject a compute that shares caches.
    """
    self.config = config
    # Prefer the injected compute; otherwise build one from the config.
    self._compute = compute if compute is not None else AWSCompute(self.config)
    self._check_credentials()

def compute(self) -> AWSCompute:
Expand Down
78 changes: 43 additions & 35 deletions src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import threading
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple

import boto3
Expand All @@ -19,6 +20,8 @@
)
from dstack._internal.core.backends.base.compute import (
Compute,
ComputeCache,
ComputeTTLCache,
ComputeWithAllOffersCached,
ComputeWithCreateInstanceSupport,
ComputeWithGatewaySupport,
Expand Down Expand Up @@ -94,6 +97,11 @@ def _ec2client_cache_methodkey(self, ec2_client, *args, **kwargs):
return hashkey(*args, **kwargs)


@dataclass
class AWSQuotasCache(ComputeTTLCache):
    # Cache bundle for AWS quota lookups. The inherited `lock` only guards
    # reads/writes of the cache itself; `execution_lock` additionally
    # serializes the quota-fetching call so that concurrent callers do not
    # request quotas in parallel and hit AWS rate limits (see its use around
    # `_get_regions_to_quotas`).
    execution_lock: threading.Lock = field(default_factory=threading.Lock)


class AWSCompute(
ComputeWithAllOffersCached,
ComputeWithCreateInstanceSupport,
Expand All @@ -106,7 +114,12 @@ class AWSCompute(
ComputeWithVolumeSupport,
Compute,
):
def __init__(self, config: AWSConfig):
def __init__(
self,
config: AWSConfig,
quotas_cache: Optional[AWSQuotasCache] = None,
zones_cache: Optional[ComputeCache] = None,
):
super().__init__()
self.config = config
if isinstance(config.creds, AWSAccessKeyCreds):
Expand All @@ -119,23 +132,18 @@ def __init__(self, config: AWSConfig):
# Caches to avoid redundant API calls when provisioning many instances
# get_offers is already cached but we still cache its sub-functions
# with more aggressive/longer caches.
self._offers_post_filter_cache_lock = threading.Lock()
self._offers_post_filter_cache = TTLCache(maxsize=10, ttl=180)
self._get_regions_to_quotas_cache_lock = threading.Lock()
self._get_regions_to_quotas_execution_lock = threading.Lock()
self._get_regions_to_quotas_cache = TTLCache(maxsize=10, ttl=300)
self._get_regions_to_zones_cache_lock = threading.Lock()
self._get_regions_to_zones_cache = Cache(maxsize=10)
self._get_vpc_id_subnet_id_or_error_cache_lock = threading.Lock()
self._get_vpc_id_subnet_id_or_error_cache = TTLCache(maxsize=100, ttl=600)
self._get_maximum_efa_interfaces_cache_lock = threading.Lock()
self._get_maximum_efa_interfaces_cache = Cache(maxsize=100)
self._get_subnets_availability_zones_cache_lock = threading.Lock()
self._get_subnets_availability_zones_cache = Cache(maxsize=100)
self._create_security_group_cache_lock = threading.Lock()
self._create_security_group_cache = TTLCache(maxsize=100, ttl=600)
self._get_image_id_and_username_cache_lock = threading.Lock()
self._get_image_id_and_username_cache = TTLCache(maxsize=100, ttl=600)
self._offers_post_filter_cache = ComputeTTLCache(cache=TTLCache(maxsize=10, ttl=180))
if quotas_cache is None:
quotas_cache = AWSQuotasCache(cache=TTLCache(maxsize=10, ttl=600))
self._regions_to_quotas_cache = quotas_cache
if zones_cache is None:
zones_cache = ComputeCache(cache=Cache(maxsize=10))
self._regions_to_zones_cache = zones_cache
self._vpc_id_subnet_id_cache = ComputeTTLCache(cache=TTLCache(maxsize=100, ttl=600))
self._maximum_efa_interfaces_cache = ComputeCache(cache=Cache(maxsize=100))
self._subnets_availability_zones_cache = ComputeCache(cache=Cache(maxsize=100))
self._security_group_cache = ComputeTTLCache(cache=TTLCache(maxsize=100, ttl=600))
self._image_id_and_username_cache = ComputeTTLCache(cache=TTLCache(maxsize=100, ttl=600))

def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
offers = get_catalog_offers(
Expand All @@ -144,7 +152,7 @@ def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability
extra_filter=_supported_instances,
)
regions = list(set(i.region for i in offers))
with self._get_regions_to_quotas_execution_lock:
with self._regions_to_quotas_cache.execution_lock:
# Cache lock does not prevent concurrent execution.
# We use a separate lock to avoid requesting quotas in parallel and hitting rate limits.
regions_to_quotas = self._get_regions_to_quotas(self.session, regions)
Expand Down Expand Up @@ -173,9 +181,9 @@ def _get_offers_cached_key(self, requirements: Requirements) -> int:
return hash(requirements.json())

@cachedmethod(
cache=lambda self: self._offers_post_filter_cache,
cache=lambda self: self._offers_post_filter_cache.cache,
key=_get_offers_cached_key,
lock=lambda self: self._offers_post_filter_cache_lock,
lock=lambda self: self._offers_post_filter_cache.lock,
)
def get_offers_post_filter(
self, requirements: Requirements
Expand Down Expand Up @@ -789,9 +797,9 @@ def _get_regions_to_quotas_key(
return hashkey(tuple(regions))

@cachedmethod(
cache=lambda self: self._get_regions_to_quotas_cache,
cache=lambda self: self._regions_to_quotas_cache.cache,
key=_get_regions_to_quotas_key,
lock=lambda self: self._get_regions_to_quotas_cache_lock,
lock=lambda self: self._regions_to_quotas_cache.lock,
)
def _get_regions_to_quotas(
self,
Expand All @@ -808,9 +816,9 @@ def _get_regions_to_zones_key(
return hashkey(tuple(regions))

@cachedmethod(
cache=lambda self: self._get_regions_to_zones_cache,
cache=lambda self: self._regions_to_zones_cache.cache,
key=_get_regions_to_zones_key,
lock=lambda self: self._get_regions_to_zones_cache_lock,
lock=lambda self: self._regions_to_zones_cache.lock,
)
def _get_regions_to_zones(
self,
Expand All @@ -832,9 +840,9 @@ def _get_vpc_id_subnet_id_or_error_cache_key(
)

@cachedmethod(
cache=lambda self: self._get_vpc_id_subnet_id_or_error_cache,
cache=lambda self: self._vpc_id_subnet_id_cache.cache,
key=_get_vpc_id_subnet_id_or_error_cache_key,
lock=lambda self: self._get_vpc_id_subnet_id_or_error_cache_lock,
lock=lambda self: self._vpc_id_subnet_id_cache.lock,
)
def _get_vpc_id_subnet_id_or_error(
self,
Expand All @@ -853,9 +861,9 @@ def _get_vpc_id_subnet_id_or_error(
)

@cachedmethod(
cache=lambda self: self._get_maximum_efa_interfaces_cache,
cache=lambda self: self._maximum_efa_interfaces_cache.cache,
key=_ec2client_cache_methodkey,
lock=lambda self: self._get_maximum_efa_interfaces_cache_lock,
lock=lambda self: self._maximum_efa_interfaces_cache.lock,
)
def _get_maximum_efa_interfaces(
self,
Expand All @@ -877,9 +885,9 @@ def _get_subnets_availability_zones_key(
return hashkey(region, tuple(subnet_ids))

@cachedmethod(
cache=lambda self: self._get_subnets_availability_zones_cache,
cache=lambda self: self._subnets_availability_zones_cache.cache,
key=_get_subnets_availability_zones_key,
lock=lambda self: self._get_subnets_availability_zones_cache_lock,
lock=lambda self: self._subnets_availability_zones_cache.lock,
)
def _get_subnets_availability_zones(
self,
Expand All @@ -893,9 +901,9 @@ def _get_subnets_availability_zones(
)

@cachedmethod(
cache=lambda self: self._create_security_group_cache,
cache=lambda self: self._security_group_cache.cache,
key=_ec2client_cache_methodkey,
lock=lambda self: self._create_security_group_cache_lock,
lock=lambda self: self._security_group_cache.lock,
)
def _create_security_group(
self,
Expand Down Expand Up @@ -923,9 +931,9 @@ def _get_image_id_and_username_cache_key(
)

@cachedmethod(
cache=lambda self: self._get_image_id_and_username_cache,
cache=lambda self: self._image_id_and_username_cache.cache,
key=_get_image_id_and_username_cache_key,
lock=lambda self: self._get_image_id_and_username_cache_lock,
lock=lambda self: self._image_id_and_username_cache.lock,
)
def _get_image_id_and_username(
self,
Expand Down
15 changes: 14 additions & 1 deletion src/dstack/_internal/core/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import threading
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator
from dataclasses import dataclass, field
from enum import Enum
from functools import lru_cache
from pathlib import Path
Expand All @@ -14,7 +15,7 @@
import git
import requests
import yaml
from cachetools import TTLCache, cachedmethod
from cachetools import Cache, TTLCache, cachedmethod
from gpuhunt import CPUArchitecture

from dstack._internal import settings
Expand Down Expand Up @@ -89,6 +90,18 @@ def to_cpu_architecture(self) -> CPUArchitecture:
assert False, self


@dataclass
class ComputeCache:
    # Pairs a cachetools `Cache` with the lock that guards access to it,
    # for passing to `cachetools.cachedmethod(cache=..., lock=...)`.
    # Bundling the two lets compute implementations share or inject a
    # cache-plus-lock as a single object.
    cache: Cache
    lock: threading.Lock = field(default_factory=threading.Lock)


@dataclass
class ComputeTTLCache:
    # Same cache-plus-lock bundle as ComputeCache, but for a time-bounded
    # `cachetools.TTLCache`, where entries expire after the cache's TTL.
    cache: TTLCache
    lock: threading.Lock = field(default_factory=threading.Lock)


class Compute(ABC):
"""
A base class for all compute implementations with minimal features.
Expand Down
18 changes: 8 additions & 10 deletions src/dstack/_internal/core/backends/gcp/compute.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import concurrent.futures
import json
import re
import threading
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
Expand All @@ -19,6 +18,7 @@
from dstack import version
from dstack._internal.core.backends.base.compute import (
Compute,
ComputeTTLCache,
ComputeWithAllOffersCached,
ComputeWithCreateInstanceSupport,
ComputeWithGatewaySupport,
Expand Down Expand Up @@ -127,11 +127,9 @@ def __init__(self, config: GCPConfig):
credentials=self.credentials
)
self.reservations_client = compute_v1.ReservationsClient(credentials=self.credentials)
self._usable_subnets_cache_lock = threading.Lock()
self._usable_subnets_cache = TTLCache(maxsize=1, ttl=120)
self._find_reservation_cache_lock = threading.Lock()
# smaller TTL, since we check the reservation's in_use_count, which can change often
self._find_reservation_cache = TTLCache(maxsize=8, ttl=20)
self._usable_subnets_cache = ComputeTTLCache(cache=TTLCache(maxsize=1, ttl=120))
# Smaller TTL since we check the reservation's in_use_count, which can change often
self._reservation_cache = ComputeTTLCache(cache=TTLCache(maxsize=8, ttl=20))

def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]:
regions = get_or_error(self.config.regions)
Expand Down Expand Up @@ -948,8 +946,8 @@ def _get_roce_subnets(
return nic_subnets

@cachedmethod(
cache=lambda self: self._usable_subnets_cache,
lock=lambda self: self._usable_subnets_cache_lock,
cache=lambda self: self._usable_subnets_cache.cache,
lock=lambda self: self._usable_subnets_cache.lock,
)
def _list_usable_subnets(self) -> list[compute_v1.UsableSubnetwork]:
# To avoid hitting the `ListUsable requests per minute` system limit, we fetch all subnets
Expand All @@ -969,8 +967,8 @@ def _get_vpc_subnet(self, region: str) -> Optional[str]:
)

@cachedmethod(
cache=lambda self: self._find_reservation_cache,
lock=lambda self: self._find_reservation_cache_lock,
cache=lambda self: self._reservation_cache.cache,
lock=lambda self: self._reservation_cache.lock,
)
def _find_reservation(self, configured_name: str) -> dict[str, compute_v1.Reservation]:
if match := RESERVATION_PATTERN.fullmatch(configured_name):
Expand Down
13 changes: 12 additions & 1 deletion src/dstack/_internal/server/services/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import heapq
import time
from collections.abc import Iterable, Iterator
from typing import Callable, Coroutine, Dict, List, Optional, Tuple
from uuid import UUID
Expand Down Expand Up @@ -361,7 +362,7 @@ def get_filtered_offers_with_backends(
yield (backend, offer)

logger.info("Requesting instance offers from backends: %s", [b.TYPE.value for b in backends])
tasks = [run_async(backend.compute().get_offers, requirements) for backend in backends]
tasks = [run_async(get_offers_tracked, backend, requirements) for backend in backends]
offers_by_backend = []
for backend, result in zip(backends, await asyncio.gather(*tasks, return_exceptions=True)):
if isinstance(result, BackendError):
Expand Down Expand Up @@ -391,3 +392,13 @@ def check_backend_type_available(backend_type: BackendType):
" Ensure that backend dependencies are installed."
f" Available backends: {[b.value for b in list_available_backend_types()]}."
)


def get_offers_tracked(
    backend: Backend, requirements: Requirements
) -> Iterator[InstanceOfferWithAvailability]:
    """Fetch offers from a backend and log how long the call took.

    Args:
        backend: The backend whose compute is queried for offers.
        requirements: Resource requirements the offers must satisfy.

    Returns:
        The result of ``backend.compute().get_offers(requirements)``.
        NOTE(review): the annotation says ``Iterator`` but ``get_offers``
        implementations typically return a list — confirm and adjust.
    """
    # perf_counter() is monotonic, so the measured duration cannot be
    # skewed by system clock adjustments the way time.time() can.
    start = time.perf_counter()
    offers = backend.compute().get_offers(requirements)
    duration = time.perf_counter() - start
    logger.debug("Got offers from %s in %.6fs", backend.TYPE.value, duration)
    return offers