Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/integration-tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@
generate_stack_name,
get_architecture_supported_by_instance_type,
get_arn_partition,
get_flexible_instance_types,
get_instance_info,
get_metadata,
get_network_interfaces_count,
get_similar_instance_types,
get_vpc_snakecase_value,
random_alphanumeric,
to_pascal_case,
Expand Down Expand Up @@ -698,7 +698,7 @@ def inject_placement_group_settings(vpc_stack, instance, region, kwargs):


def inject_flexible_instance_types_settings(instance, region, kwargs):
    """Store the flexible instance types for ``instance`` in ``kwargs``.

    Sets ``kwargs["flexible_instance_types"]`` to the list produced by
    ``get_flexible_instance_types(instance, region)``.

    Parameters
    ----------
    instance:
        Instance type the flexible list is built around.
    region:
        Region passed through to ``get_flexible_instance_types``.
    kwargs:
        Mapping that is updated in place.
    """
    # A single lookup is enough: the previous inline computation via
    # get_similar_instance_types was immediately overwritten (dead store).
    # Copy the result so that a consumer mutating the list stored in kwargs
    # cannot alter a list object that the helper may hand out again.
    kwargs["flexible_instance_types"] = list(get_flexible_instance_types(instance, region))


def inject_additional_image_configs_settings(image_config, request):
Expand Down
102 changes: 102 additions & 0 deletions tests/integration-tests/framework/file_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file.
# This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied.
# See the License for the specific language governing permissions and limitations under the License.
"""Cross-process file-backed memoization decorator.

Drop-in replacement for ``functools.cache`` that persists results to a file
guarded by a :class:`filelock.FileLock`, so that callers running in separate
processes (e.g. pytest-xdist workers) share cached values instead of each
recomputing the same result.
"""

import functools
import os
import pickle
import tempfile

from filelock import FileLock


def file_cache(filename: str):
    """Decorator providing cross-process memoization backed by a file.

    Works like ``functools.cache`` but persists results across processes via a
    pickle file. All positional and keyword arguments must be hashable and
    return values must be picklable.

    Lookup order for each call is: the per-process in-memory dict, then the
    on-disk cache (read while holding the file lock), and finally the wrapped
    function itself. The wrapped function is invoked while the lock is held,
    so other callers using the same cache file wait for the result instead of
    computing it a second time.

    On-disk entries are namespaced by the wrapped function's module and
    qualified name, so several functions may safely share one cache file
    without returning each other's results for equal arguments.

    The returned wrapper exposes ``cache_clear()``, which empties the
    in-memory cache and deletes the cache file (including entries written by
    any other function sharing that file).

    Parameters
    ----------
    filename:
        Path to the cache file. If a relative path is given, it is resolved
        under :func:`tempfile.gettempdir` so the cache survives a single
        machine across pytest sessions and is shared by all workers.
    """
    cache_path = filename if os.path.isabs(filename) else os.path.join(tempfile.gettempdir(), filename)
    lock_path = cache_path + ".lock"

    def decorator(func):
        in_memory = {}
        # Identity prefix for on-disk keys; without it two functions decorated
        # with the same filename would collide on identical arguments.
        namespace = (
            getattr(func, "__module__", None),
            getattr(func, "__qualname__", type(func).__qualname__),
        )

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Sort keyword arguments so that call-site ordering does not
            # produce distinct keys for the same logical call.
            key = (args, tuple(sorted(kwargs.items())))
            if key in in_memory:
                return in_memory[key]

            disk_key = namespace + key
            with FileLock(lock_path):
                disk_cache = _load(cache_path)
                if disk_key in disk_cache:
                    result = disk_cache[disk_key]
                else:
                    # Computed under the lock on purpose: see docstring.
                    result = func(*args, **kwargs)
                    disk_cache[disk_key] = result
                    _dump(cache_path, disk_cache)

            in_memory[key] = result
            return result

        def cache_clear():
            """Drop the in-memory entries and delete the shared cache file."""
            in_memory.clear()
            with FileLock(lock_path):
                if os.path.exists(cache_path):
                    os.remove(cache_path)

        wrapper.cache_clear = cache_clear
        # ``functools.wraps`` already sets ``wrapper.__wrapped__`` to ``func``.
        return wrapper

    return decorator


def _load(path):
if not os.path.exists(path):
return {}
try:
with open(path, "rb") as f:
return pickle.load(f)
except (EOFError, pickle.UnpicklingError):
# Corrupted cache file — start fresh.
return {}


def _dump(path, data):
# Atomic write: dump to a temp file in the same directory, then rename.
directory = os.path.dirname(path) or "."
os.makedirs(directory, exist_ok=True)
fd, tmp_path = tempfile.mkstemp(prefix=".file_cache_", dir=directory)
try:
with os.fdopen(fd, "wb") as f:
pickle.dump(data, f)
os.replace(tmp_path, path)
except Exception:
if os.path.exists(tmp_path):
os.remove(tmp_path)
raise
2 changes: 1 addition & 1 deletion tests/integration-tests/tests/common/capacity_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def resolve_instance_with_capacity(region, az_id, instance_type, os, minutes=50,
if instance_type not in DEFAULT_INSTANCE_TYPES:
return instance_type

candidates = [instance_type] + get_similar_instance_types(instance_type)
candidates = [instance_type] + get_similar_instance_types(instance_type, region)

ec2_client = boto3.client("ec2", region_name=region)
instance_platform = "Red Hat Enterprise Linux" if "rhel" in os else "Linux/UNIX"
Expand Down
21 changes: 16 additions & 5 deletions tests/integration-tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
import string
import subprocess
from datetime import datetime, timedelta
from functools import cache
from hashlib import sha1

import boto3
import requests
from assertpy import assert_that
from framework.file_cache import file_cache
from jinja2 import FileSystemLoader
from jinja2.sandbox import SandboxedEnvironment
from retrying import retry
Expand Down Expand Up @@ -1026,9 +1026,10 @@ def _get_gpu_spec(instance_type_data):
return frozenset((gpu.get("Manufacturer", ""), gpu.get("Count", 0)) for gpu in gpu_info.get("Gpus", []))


@file_cache("pcluster_similar_instance_types.cache")
def get_similar_instance_types(instance_type: str, region: str = None, max_items: int = None):
"""Return instance types compatible with ``instance_type`` in ``region``."""
ec2 = boto3.client("ec2", region_name=region)

# First, get the target instance details to use as filter criteria
target_response = ec2.describe_instance_types(InstanceTypes=[instance_type])

Expand All @@ -1046,6 +1047,7 @@ def get_similar_instance_types(instance_type: str, region: str = None, max_items
# Now query for similar instances using filters
paginator = ec2.get_paginator("describe_instance_types")
similar_instances = []
reached_max_items = False

for page in paginator.paginate(
Filters=[
Expand All @@ -1069,17 +1071,26 @@ def get_similar_instance_types(instance_type: str, region: str = None, max_items
):
similar_instances.append(instance["InstanceType"])
if max_items and len(similar_instances) >= max_items:
return similar_instances
reached_max_items = True
break
if reached_max_items:
break

logging.info(f"Retrieved instance types equivalent to {instance_type} in {region}: {similar_instances}")

return similar_instances


@cache
def get_flexible_instance_types(instance, region):
    """Return ``instance`` plus up to 5 similar instance types available in ``region``.

    The result is de-duplicated and ordered deterministically: ``instance``
    comes first, followed by the similar types in the order returned by
    ``get_similar_instance_types``.

    Because of ``@cache``, repeated calls with the same arguments return the
    same list object; callers should copy it before mutating.
    """
    similar = get_similar_instance_types(instance, region)[:5]
    # dict.fromkeys de-duplicates while preserving insertion order. Building
    # the list from a set would make the ordering depend on per-process string
    # hashing, so separate test workers could see different orderings.
    return list(dict.fromkeys([instance, *similar]))


def get_flexible_gpu_instance_types(instance, region):
    """Return a list of NVIDIA GPU instance types compatible with ``instance``'s architecture.

    The base GPU type is ``g4dn.2xlarge`` when the architecture reported for
    ``instance`` in ``region`` is ``x86_64`` and ``g5g.2xlarge`` otherwise.
    The returned list holds that base type first, followed by up to 5 similar
    types from ``get_similar_instance_types``, without duplicates.
    """
    architecture = get_architecture_supported_by_instance_type(instance, region)
    gpu_instance_type = "g4dn.2xlarge" if architecture == "x86_64" else "g5g.2xlarge"
    # Query without max_items and slice afterwards, so this lookup uses the
    # same arguments (and therefore the same cache entry) as other callers.
    # The earlier max_items=5 return was stale and shadowed this form.
    similar = get_similar_instance_types(gpu_instance_type, region)[:5]
    # Order-preserving de-duplication keeps the result deterministic; a set
    # would order the strings differently from one process to the next.
    return list(dict.fromkeys([gpu_instance_type, *similar]))


def verify_cluster_node_config_version_in_ddb(region, cluster_name, instance_id, expected_version):
Expand Down
Loading