Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/sagemaker/hyperpod/cli/cluster_stack_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
All other functions are private implementation details and should not be used directly.
"""

import boto3
from sagemaker.hyperpod.common.utils import create_boto3_client
import click
import logging
from typing import List, Dict, Any, Optional, Tuple, Callable
Expand Down Expand Up @@ -53,7 +53,7 @@ def _get_stack_resources(stack_name: str, region: str, logger: Optional[logging.
if logger:
logger.debug(f"Fetching resources for stack '{stack_name}' in region '{region}'")

cf_client = boto3.client('cloudformation', region_name=region)
cf_client = create_boto3_client('cloudformation', region_name=region)
try:
resources_response = cf_client.list_stack_resources(StackName=stack_name)
resources = resources_response.get('StackResourceSummaries', [])
Expand Down Expand Up @@ -208,7 +208,7 @@ def _handle_partial_deletion_failure(stack_name: str, region: str, original_reso
message_callback("✗ Stack deletion failed")

try:
cf_client = boto3.client('cloudformation', region_name=region)
cf_client = create_boto3_client('cloudformation', region_name=region)
current_resources_response = cf_client.list_stack_resources(StackName=stack_name)
current_resources = current_resources_response.get('StackResourceSummaries', [])

Expand Down Expand Up @@ -273,7 +273,7 @@ def _perform_stack_deletion(stack_name: str, region: str, retain_list: List[str]
if retain_list:
logger.debug(f"Retaining resources: {retain_list}")

cf_client = boto3.client('cloudformation', region_name=region)
cf_client = create_boto3_client('cloudformation', region_name=region)

delete_params = {'StackName': stack_name}
if retain_list:
Expand Down
10 changes: 7 additions & 3 deletions src/sagemaker/hyperpod/cli/commands/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
)
from sagemaker.hyperpod.common.utils import (
get_cluster_context as get_cluster_context_util,
_resolve_region,
)
from sagemaker.hyperpod.observability.utils import (
get_monitoring_config,
Expand Down Expand Up @@ -171,7 +172,8 @@ def list_cluster(
user_agent_extra=get_user_agent_extra_suffix()
)

session = boto3.Session(region_name=region) if region else boto3.Session()
region = _resolve_region(region)
session = boto3.Session(region_name=region)
if not validator.validate_aws_credential(session):
logger.error("Failed to list clusters capacity due to invalid AWS credentials.")
sys.exit(1)
Expand Down Expand Up @@ -581,7 +583,8 @@ def timeout_handler(signum, frame):
botocore_config = botocore.config.Config(
user_agent_extra=get_user_agent_extra_suffix()
)
session = boto3.Session(region_name=region) if region else boto3.Session()
region = _resolve_region(region)
session = boto3.Session(region_name=region)
if not validator.validate_aws_credential(session):
logger.error("Cannot connect to HyperPod cluster due to aws credentials error")
sys.exit(1)
Expand Down Expand Up @@ -708,7 +711,8 @@ def describe_cluster(cluster_name: str, debug: bool, region: str) -> None:
botocore_config = botocore.config.Config(
user_agent_extra=get_user_agent_extra_suffix()
)
session = boto3.Session(region_name=region) if region else boto3.Session()
region = _resolve_region(region)
session = boto3.Session(region_name=region)
sm_client = get_sagemaker_client(session, botocore_config)

# Get cluster details using SageMaker client
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/hyperpod/cli/commands/inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import click
import json
import boto3
from sagemaker.hyperpod.common.utils import create_boto3_client
from typing import Optional
from tabulate import tabulate

Expand Down Expand Up @@ -89,7 +89,7 @@ def custom_invoke(
except json.JSONDecodeError:
raise click.ClickException("--body must be valid JSON")

rt = boto3.client("sagemaker-runtime")
rt = create_boto3_client("sagemaker-runtime")

try:
endpoint = Endpoint.get(endpoint_name)
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/hyperpod/cli/service/get_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from typing import Optional
import boto3
from sagemaker.hyperpod.common.utils import create_boto3_client

from sagemaker.hyperpod.cli.clients.kubernetes_client import (
KubernetesClient,
Expand Down Expand Up @@ -129,7 +129,7 @@ def get_log_url(self, eks_cluster_name, region, node_name, pod_name, namespace,
return console_prefix + log_group_prefix + log_stream

def is_container_insights_addon_enabled(self, eks_cluster_name):
response = boto3.client("eks").list_addons(clusterName=eks_cluster_name, maxResults=50)
response = create_boto3_client("eks").list_addons(clusterName=eks_cluster_name, maxResults=50)
if AMAZON_ClOUDWATCH_OBSERVABILITY in response.get('addons', []):
return True
else:
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/hyperpod/common/cli_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,10 +417,10 @@ def _is_valid_jumpstart_model_id(model_id: str) -> bool:
Uses same SageMaker API that's already being called during creation.
"""
try:
import boto3
from botocore.exceptions import ClientError
from sagemaker.hyperpod.common.utils import create_boto3_client

sagemaker_client = boto3.client('sagemaker')
sagemaker_client = create_boto3_client('sagemaker')

# Use same API call that's failing in the current code
sagemaker_client.describe_hub_content(
Expand Down
38 changes: 31 additions & 7 deletions src/sagemaker/hyperpod/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def get_region_from_eks_arn(arn: str) -> str:


def get_jumpstart_model_instance_types(model_id, region) -> List[str]:
client = boto3.client("sagemaker", region_name=region)
client = create_boto3_client("sagemaker", region_name=region)

response = client.describe_hub_content(
HubName="SageMakerPublicHub", HubContentType="Model", HubContentName=model_id
Expand All @@ -145,7 +145,7 @@ def get_jumpstart_model_instance_types(model_id, region) -> List[str]:
def get_cluster_instance_types(cluster, region) -> set:
instance_types = set({})

sagemaker_client = boto3.client("sagemaker", region_name=region)
sagemaker_client = create_boto3_client("sagemaker", region_name=region)
response = sagemaker_client.describe_cluster(ClusterName=cluster)

for instance_group in response["InstanceGroups"]:
Expand Down Expand Up @@ -278,7 +278,7 @@ def set_cluster_context(
logger = logging.getLogger(__name__)
logger = setup_logging(logger)

client = boto3.client("sagemaker", region_name=region)
client = create_boto3_client("sagemaker", region_name=region)

if not is_eks_orchestrator(client, cluster_name):
raise ValueError(f"Cluster '{cluster_name}' is not EKS-orchestrated. HyperPod CLI only supports EKS-orchestrated clusters.")
Expand Down Expand Up @@ -309,7 +309,7 @@ def get_cluster_context():
def list_clusters(
region: Optional[str] = None,
):
client = boto3.client("sagemaker", region_name=region)
client = create_boto3_client("sagemaker", region_name=region)
clusters = client.list_clusters()

eks_clusters = []
Expand All @@ -330,7 +330,7 @@ def get_current_cluster():
region = get_region_from_eks_arn(current_context)

hyperpod_clusters = list_clusters(region)["Eks"]
client = boto3.client("sagemaker", region_name=region)
client = create_boto3_client("sagemaker", region_name=region)

for cluster_name in hyperpod_clusters:
if not is_eks_orchestrator(client, cluster_name):
Expand All @@ -356,18 +356,42 @@ def get_current_region():
except:
return get_aws_default_region()

def _resolve_region(region_name: Optional[str] = None) -> Optional[str]:
    """Pick the AWS region to use, trying each source in turn.

    Lookup order (first non-empty value wins):
        1. ``region_name`` as supplied by the caller (e.g. the --region flag)
        2. The ``AWS_REGION`` environment variable
        3. The region of a default boto3 session (standard boto3 chain,
           e.g. AWS_DEFAULT_REGION / ~/.aws/config)
        4. The region parsed from the current cluster context (last resort)

    Args:
        region_name (Optional[str]): Explicit region, if any.

    Returns:
        Optional[str]: The resolved region, or None when no source yields one
        or the cluster-context lookup raises.
    """
    # ``or`` evaluates lazily, so the boto3 session is only created when
    # neither the explicit argument nor AWS_REGION provides a value.
    candidate = (
        region_name
        or os.environ.get('AWS_REGION')
        or boto3.session.Session().region_name
    )
    if candidate:
        return candidate

    # Last resort: derive the region from the current cluster context ARN.
    # Any failure here is treated as "no region available".
    try:
        context_arn = get_cluster_context()
        return get_region_from_eks_arn(context_arn)
    except Exception:
        return None

def create_boto3_client(service_name: str, region_name: Optional[str] = None, **kwargs):
    """Create a boto3 client with smart region handling.

    Args:
        service_name (str): AWS service name (e.g., 'sagemaker', 'eks')
        region_name (Optional[str]): AWS region. If None, resolved via
            AWS_REGION env var, boto3 defaults, or cluster context.
        **kwargs: Additional boto3 client parameters

    Returns:
        boto3 client instance
    """
    # Region selection is delegated to _resolve_region so every client goes
    # through the same fallback chain. The earlier inline
    # `region_name or boto3.session.Session().region_name` return is dropped:
    # it came first and made this call unreachable.
    return boto3.client(service_name, region_name=_resolve_region(region_name), **kwargs)

def region_to_az_ids(region_code: str):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import absolute_import

import time
import boto3
from sagemaker.hyperpod.common.utils import create_boto3_client
import itables
import pandas
import logging
Expand All @@ -32,8 +32,8 @@ class ModelDataLoader:
MAX_RESULTS_PER_CALL = 100

def __init__(self, region: str, hub_name: str = "SageMakerPublicHub"):
config = Config(region_name=region, retries={"max_attempts": 10, "mode": "adaptive"})
self.client = boto3.client("sagemaker", config=config)
config = Config(retries={"max_attempts": 10, "mode": "adaptive"})
self.client = create_boto3_client("sagemaker", region_name=region, config=config)
self.hub_name = hub_name
self.all_data = []
self.next_token = None
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/hyperpod/observability/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import re
from typing import Optional

import boto3
import yaml

from sagemaker.hyperpod.common.utils import create_boto3_client
from sagemaker.hyperpod.observability.constants import AMAZON_HYPERPOD_OBSERVABILITY, GRAFANA_DASHBOARD_UID
from sagemaker.hyperpod.observability.MonitoringConfig import MonitoringConfig
# TODO: move the functions below into an SDK util module instead of importing them from the CLI
from sagemaker.hyperpod.cli.utils import get_eks_cluster_name, get_hyperpod_cluster_region

def is_observability_addon_enabled(eks_cluster_name):
response = boto3.client("eks").list_addons(clusterName=eks_cluster_name, maxResults=50)
response = create_boto3_client("eks").list_addons(clusterName=eks_cluster_name, maxResults=50)
if AMAZON_HYPERPOD_OBSERVABILITY in response.get('addons', []):
return True
else:
Expand Down Expand Up @@ -41,7 +41,7 @@ def get_monitoring_config() -> Optional[MonitoringConfig]:
eks_cluster_name = get_eks_cluster_name()
if not is_observability_addon_enabled(eks_cluster_name):
return None
response = boto3.client("eks").describe_addon(clusterName=eks_cluster_name, addonName=AMAZON_HYPERPOD_OBSERVABILITY)
response = create_boto3_client("eks").describe_addon(clusterName=eks_cluster_name, addonName=AMAZON_HYPERPOD_OBSERVABILITY)
config_values = yaml.safe_load(response['addon']['configurationValues'])

try:
Expand Down
3 changes: 2 additions & 1 deletion src/sagemaker/hyperpod/space/hyperpod_space.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import yaml
import boto3
from sagemaker.hyperpod.common.utils import create_boto3_client
from typing import List, Optional, ClassVar, Dict, Set, Any
from pydantic import BaseModel, Field, ConfigDict, model_validator
from kubernetes import client, config
Expand Down Expand Up @@ -429,7 +430,7 @@ def list(cls, namespace: Optional[str] = None) -> List["HPSpace"]:
namespace = get_default_namespace()

# Get caller identity
sts_client = boto3.client('sts')
sts_client = create_boto3_client('sts')
caller_identity = sts_client.get_caller_identity()
caller_arn = caller_identity['Arn']

Expand Down
10 changes: 5 additions & 5 deletions test/unit_tests/cli/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,15 +387,15 @@ def test_custom_create_missing_required_args():


@patch("sagemaker.hyperpod.cli.commands.inference.Endpoint.get")
@patch("sagemaker.hyperpod.cli.commands.inference.boto3")
def test_custom_invoke_success(mock_boto3, mock_endpoint_get):
@patch("sagemaker.hyperpod.cli.commands.inference.create_boto3_client")
def test_custom_invoke_success(mock_create_client, mock_endpoint_get):
mock_endpoint = Mock()
mock_endpoint.endpoint_status = "InService"
mock_endpoint_get.return_value = mock_endpoint

mock_body = Mock()
mock_body.read.return_value.decode.return_value = '{"ok": true}'
mock_boto3.client.return_value.invoke_endpoint.return_value = {"Body": mock_body}
mock_create_client.return_value.invoke_endpoint.return_value = {"Body": mock_body}

runner = CliRunner()
result = runner.invoke(
Expand All @@ -406,8 +406,8 @@ def test_custom_invoke_success(mock_boto3, mock_endpoint_get):
assert '"ok": true' in result.output


@patch("sagemaker.hyperpod.cli.commands.inference.boto3")
def test_custom_invoke_invalid_json(mock_boto3):
@patch("sagemaker.hyperpod.cli.commands.inference.create_boto3_client")
def test_custom_invoke_invalid_json(mock_create_client):
runner = CliRunner()
result = runner.invoke(custom_invoke, ["--endpoint-name", "ep", "--body", "bad"])
assert result.exit_code != 0
Expand Down
39 changes: 38 additions & 1 deletion test/unit_tests/common/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
get_cluster_context,
parse_client_kubernetes_version,
is_kubernetes_version_compatible,
_resolve_region,
)
from kubernetes.client.exceptions import ApiException
from pydantic import ValidationError
Expand Down Expand Up @@ -442,4 +443,40 @@ def test_get_cluster_context_success(self, mock_list_contexts):
result = get_cluster_context()

self.assertEqual(result, "arn:aws:eks:us-west-2:123456789012:cluster/my-cluster")
mock_list_contexts.assert_called_once()
mock_list_contexts.assert_called_once()


class TestResolveRegion(unittest.TestCase):
    """Test the _resolve_region function.

    Each test pins one step of the fallback chain: explicit argument,
    AWS_REGION env var, boto3 session default, cluster context, then None.
    unittest assertion methods are used instead of bare ``assert`` so the
    checks still run under ``python -O`` and match the rest of this file.
    """

    def test_explicit_region_takes_precedence(self):
        """An explicit region wins even when AWS_REGION is set."""
        with patch.dict('os.environ', {'AWS_REGION': 'us-east-1'}):
            self.assertEqual(_resolve_region('eu-west-1'), 'eu-west-1')

    @patch.dict('os.environ', {'AWS_REGION': 'us-west-2'}, clear=False)
    @patch('sagemaker.hyperpod.common.utils.boto3.session.Session')
    def test_aws_region_env_var(self, mock_session):
        """AWS_REGION is used without ever creating a boto3 session."""
        self.assertEqual(_resolve_region(), 'us-west-2')
        mock_session.assert_not_called()

    @patch.dict('os.environ', {}, clear=True)
    @patch('sagemaker.hyperpod.common.utils.boto3.session.Session')
    def test_boto3_default_region_fallback(self, mock_session):
        """With no env var, the boto3 session's region is returned."""
        mock_session.return_value.region_name = 'ap-southeast-1'
        self.assertEqual(_resolve_region(), 'ap-southeast-1')

    @patch('sagemaker.hyperpod.common.utils.get_cluster_context')
    @patch.dict('os.environ', {}, clear=True)
    @patch('sagemaker.hyperpod.common.utils.boto3.session.Session')
    def test_cluster_context_fallback(self, mock_session, mock_context):
        """With no env var or boto3 region, the cluster context ARN is parsed."""
        mock_session.return_value.region_name = None
        mock_context.return_value = 'arn:aws:eks:us-west-2:123456789012:cluster/my-cluster'
        self.assertEqual(_resolve_region(), 'us-west-2')

    @patch('sagemaker.hyperpod.common.utils.get_cluster_context')
    @patch.dict('os.environ', {}, clear=True)
    @patch('sagemaker.hyperpod.common.utils.boto3.session.Session')
    def test_returns_none_when_nothing_configured(self, mock_session, mock_context):
        """If every source fails, None is returned rather than raising."""
        mock_session.return_value.region_name = None
        mock_context.side_effect = Exception("no context")
        self.assertIsNone(_resolve_region())
Loading