Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sagemaker-core/src/sagemaker/core/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,9 @@ def sts_regional_endpoint(region):
Returns:
str: AWS STS regional endpoint
"""
from sagemaker.core.region_validation import validate_region

validate_region(region)
endpoint_data = _botocore_resolver().construct_endpoint("sts", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)}
Expand Down Expand Up @@ -906,6 +909,9 @@ def aws_partition(region):
Returns:
str: partition corresponding to the region name passed in. Ex: "aws-cn"
"""
from sagemaker.core.region_validation import validate_region

validate_region(region)
endpoint_data = _botocore_resolver().construct_endpoint("sts", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)}
Expand Down
4 changes: 4 additions & 0 deletions sagemaker-core/src/sagemaker/core/helper/session_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,10 @@ def _initialize(
"Must setup local AWS configuration with a region supported by SageMaker."
)

from sagemaker.core.region_validation import validate_region

validate_region(self._region_name)

# Make use of user_agent_extra field of the botocore_config object
# to append SageMaker Python SDK specific user_agent suffix
# to the current User-Agent header value from boto3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sagemaker.core.inference_config import ServerlessInferenceConfig
from sagemaker.core.training_compiler.config import TrainingCompilerConfig
from sagemaker.core.common_utils import _botocore_resolver
from sagemaker.core.region_validation import validate_region
from sagemaker.core.workflow import is_pipeline_variable
from sagemaker.core.image_retriever.image_retriever_utils import (
_config_for_framework_and_scope,
Expand Down Expand Up @@ -161,6 +162,7 @@ def retrieve_hugging_face_uri(
)
version_config = version_config.get(py_version) or version_config
registry = _registry_from_region(region, version_config["registries"])
validate_region(region)
Copy link
Copy Markdown
Collaborator

@mujtaba1747 mujtaba1747 May 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is `validate_region` scattered all over the place? Ideally it should be checked only once, right before a request is sent — somewhere in the SageMaker client or in sagemaker-core?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no single chokepoint where all region-to-URL paths converge. Many of these URLs never go through a SageMaker client at all — telemetry uses raw requests.get(), Studio/Console URLs are returned to users as browser links, ECR image URIs are just strings passed to Docker/SageMaker APIs, and STS endpoints are passed as endpoint_url to boto3 (which doesn't validate it points to AWS). Validating once in Session.init() would only cover paths that obtain region from the session, but not paths where region is extracted from untrusted ARN strings (e.g., _parse_job_arn()) or passed directly as a function parameter (e.g., image_uris.retrieve(region=...)). The validation is placed at URL construction sites because that's where the region value actually becomes dangerous, and it's the only approach that covers all paths. This is consistent with how CVE-2026-22611 was fixed in other AWS SDKs — validate at endpoint construction, not at a single entry point.

Copy link
Copy Markdown
Collaborator

@mujtaba1747 mujtaba1747 May 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add region validation to the Pydantic middleware, or create a new middleware for sanitizing user inputs?

A middleware could ensure the region validation happens every time without us having to add it explicitly for every call?

Or a new datatype AwsRegion that can be instantiated like a string but throws an exception if the regex fails? So all uses of region: Optional[str] would be replaced with region: Optional[AwsRegion]

Devs might skip or forget to add the region check in the future, as this is an ever-changing codebase.

endpoint_data = _botocore_resolver().construct_endpoint("ecr", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)}
Expand Down Expand Up @@ -359,6 +361,7 @@ def retrieve_pytorch_uri(
py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework)
version_config = version_config.get(py_version) or version_config
registry = _registry_from_region(region, version_config["registries"])
validate_region(region)
endpoint_data = _botocore_resolver().construct_endpoint("ecr", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)}
Expand Down Expand Up @@ -561,6 +564,7 @@ def retrieve(
py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework)
version_config = version_config.get(py_version) or version_config
registry = _registry_from_region(region, version_config["registries"])
validate_region(region)
endpoint_data = _botocore_resolver().construct_endpoint("ecr", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)}
Expand Down Expand Up @@ -623,6 +627,7 @@ def retrieve_base_python_image_uri(region: str, py_version: str = "310") -> str:

framework = "sagemaker-base-python"
version = "1.0"
validate_region(region)
endpoint_data = _botocore_resolver().construct_endpoint("ecr", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,9 @@ def _retrieve_latest_pytorch_training_uri(region: str):
version_config = config[image_scope]["versions"][latest_version]
py_version = _validate_py_version_and_set_if_needed(None, version_config, None)

from sagemaker.core.region_validation import validate_region

validate_region(region)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add validate region to _botocore_resolver() instead?

endpoint_data = _botocore_resolver().construct_endpoint("ecr", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)}
Expand Down
3 changes: 3 additions & 0 deletions sagemaker-core/src/sagemaker/core/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sagemaker.core.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER
from sagemaker.core.jumpstart.enums import JumpStartModelType
from sagemaker.core.jumpstart.utils import is_jumpstart_model_input
from sagemaker.core.region_validation import validate_region
from sagemaker.core.spark import defaults
from sagemaker.core.jumpstart import artifacts
from sagemaker.core.workflow import is_pipeline_variable
Expand Down Expand Up @@ -213,6 +214,7 @@ def retrieve(
py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework)
version_config = version_config.get(py_version) or version_config
registry = _registry_from_region(region, version_config["registries"])
validate_region(region)
endpoint_data = utils._botocore_resolver().construct_endpoint("ecr", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)}
Expand Down Expand Up @@ -749,6 +751,7 @@ def get_base_python_image_uri(region, py_version="310") -> str:

framework = "sagemaker-base-python"
version = "1.0"
validate_region(region)
endpoint_data = utils._botocore_resolver().construct_endpoint("ecr", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def __init__(
one is created using the default AWS configuration chain.
Default: ``None``
"""
from sagemaker.core.region_validation import validate_region

if isinstance(region, str):
self.region = region
else:
Expand All @@ -55,6 +57,7 @@ def __init__(
" configuration."
)

validate_region(self.region)
self._sagemaker_client = boto3.client("sagemaker", region_name=self.region)
# Used to store domain and user profile info retrieved from Studio environment.
self._domain_id = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def __init__(self, region: Optional[str] = None):
region (str): The name of the region e.g. us-east-1. If not specified,
one is created using the default AWS configuration chain.
"""
from sagemaker.core.region_validation import validate_region

if region:
self.region = region
else:
Expand All @@ -49,6 +51,8 @@ def __init__(self, region: Optional[str] = None):
"as an input argument or setup the local AWS config."
)

validate_region(self.region)

self._domain_id = None
self._user_profile_name = None
self._valid_domain_and_user = False
Expand Down
90 changes: 90 additions & 0 deletions sagemaker-core/src/sagemaker/core/region_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Region validation utilities to prevent SSRF via malicious region strings.

This module provides validation for AWS region parameters before they are
interpolated into endpoint URLs. Without validation, a crafted region value
(e.g., ``x@attacker.com:443/#``) could redirect SDK API calls — including
SigV4-signed requests — to non-AWS hosts.

See: CVE-2026-22611 (AWS SDK for .NET, same vulnerability class).
"""
from __future__ import absolute_import

import re
from urllib.parse import urlparse

# Regex for valid AWS region names (e.g., us-east-1, eu-west-2, cn-north-1, us-gov-west-1).
# Uses \A and \Z anchors to prevent the newline-injection bypass that $ allows.
# The trailing number is spelled [0-9] rather than \d: in a str pattern \d matches
# every Unicode decimal digit (e.g. Arabic-Indic "\u0661"), which would let
# non-ASCII look-alike region strings through the check.
# NOTE(review): the fixed two-letter prefix rejects any region whose first label is
# longer (e.g. an "eusc-..." style name) -- confirm whether such partitions must be
# supported before widening the pattern.
_VALID_REGION_PATTERN = re.compile(r"\A[a-z]{2}(-[a-z]+)+-[0-9]+\Z")

# Trusted AWS domain suffixes for endpoint URL validation (defense-in-depth).
# Every entry starts with "." so that a suffix match always falls on a label
# boundary (e.g. "evilamazonaws.com" does not match ".amazonaws.com").
_AWS_DOMAINS = (
    ".amazonaws.com",
    ".amazonaws.com.cn",
    ".api.aws",
    ".sagemaker.aws",
)


class InvalidRegionError(ValueError):
    """Signal that a region string (or an endpoint built from it) was rejected.

    It subclasses ``ValueError`` so existing ``except ValueError`` handlers
    keep working. The check exists to stop SSRF: a crafted region value such
    as ``x@attacker.com:443/#`` could otherwise steer SDK API calls to a
    non-AWS host.
    """


def validate_region(region: str) -> str:
    """Check that ``region`` looks like an AWS region name and hand it back.

    Args:
        region: Candidate region string, e.g. ``"us-east-1"``.

    Returns:
        The same ``region`` object, untouched, when it passes the check.

    Raises:
        InvalidRegionError: If ``region`` is not a ``str`` or does not match
            the module's region pattern.
    """
    # The isinstance test comes first so the regex is never handed None,
    # bytes, or any other non-string object.
    well_formed = isinstance(region, str) and _VALID_REGION_PATTERN.match(region)
    if well_formed:
        return region

    raise InvalidRegionError(
        f"Invalid AWS region: {region!r}. "
        "Region must match pattern like 'us-east-1', 'eu-west-2', 'cn-north-1'."
    )


# Characters allowed in an endpoint hostname. ``urlparse`` reports the host in
# lower case, so upper-case letters do not need to be listed.
_ENDPOINT_HOSTNAME_PATTERN = re.compile(r"\A[a-z0-9._-]+\Z")


def validate_endpoint_url(url: str) -> str:
    """Validate that a constructed endpoint URL resolves to an AWS host.

    This is a defense-in-depth check that catches URL manipulation even if
    the region regex is somehow bypassed.

    Args:
        url: The constructed endpoint URL.

    Returns:
        The validated URL (unchanged).

    Raises:
        InvalidRegionError: If ``url`` is not a string, cannot be parsed, has a
            hostname containing characters outside ``[a-z0-9._-]``, or has a
            hostname that does not end with a trusted AWS domain.
    """
    # Reject non-str input up front so callers always get the documented
    # exception type (bytes input would otherwise fail with a TypeError below).
    if not isinstance(url, str):
        raise InvalidRegionError(f"Invalid endpoint URL: {url!r}")

    try:
        hostname = urlparse(url).hostname or ""
    except ValueError as err:
        # urlparse raises a bare ValueError for some malformed authorities
        # (e.g. an unbalanced IPv6 bracket); surface it as our own error type.
        raise InvalidRegionError(f"Invalid endpoint URL: {url!r}") from err

    # URL parsers disagree about where the host ends when it contains characters
    # such as a backslash: "attacker.com\.amazonaws.com" is one host to urlparse
    # but may be read as host "attacker.com" by the client that sends the
    # request. Insisting on plain hostname characters removes that ambiguity
    # before the suffix check is trusted.
    plain_hostname = _ENDPOINT_HOSTNAME_PATTERN.match(hostname)
    if not plain_hostname or not hostname.endswith(_AWS_DOMAINS):
        raise InvalidRegionError(
            f"Constructed endpoint resolves to non-AWS host: {hostname!r}"
        )
    return url
14 changes: 7 additions & 7 deletions sagemaker-core/tests/unit/interactive_apps/test_profiler_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

TEST_DOMAIN = "testdomain"
TEST_USER_PROFILE = "testuser"
TEST_REGION = "testregion"
TEST_REGION = "us-west-2"
TEST_NOTEBOOK_METADATA = json.dumps({"DomainId": TEST_DOMAIN, "UserProfileName": TEST_USER_PROFILE})
TEST_TRAINING_JOB = "testjob"

Expand Down Expand Up @@ -120,16 +120,16 @@ def test_detail_profiler_init_with_default_region():
"""
# happy case
with patch(
"sagemaker.core.helper.session_helper.Session.boto_region_name", new_callable=PropertyMock
) as region_mock:
region_mock.return_value = TEST_REGION
"sagemaker.core.interactive_apps.detail_profiler_app.Session"
) as session_mock:
session_mock.return_value.boto_region_name = TEST_REGION
detail_profiler_app = DetailProfilerApp()
assert detail_profiler_app.region == TEST_REGION

# no default region configured
with patch(
"sagemaker.core.helper.session_helper.Session.boto_region_name", new_callable=PropertyMock
) as region_mock:
region_mock.side_effect = [ValueError()]
"sagemaker.core.interactive_apps.detail_profiler_app.Session"
) as session_mock:
session_mock.side_effect = ValueError()
with pytest.raises(ValueError):
detail_profiler_app = DetailProfilerApp()
15 changes: 8 additions & 7 deletions sagemaker-core/tests/unit/interactive_apps/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

TEST_DOMAIN = "testdomain"
TEST_USER_PROFILE = "testuser"
TEST_REGION = "testregion"
TEST_REGION = "us-west-2"
TEST_NOTEBOOK_METADATA = json.dumps({"DomainId": TEST_DOMAIN, "UserProfileName": TEST_USER_PROFILE})
TEST_PRESIGNED_URL = (
f"https://{TEST_DOMAIN}.studio.{TEST_REGION}.sagemaker.aws/auth?token=FAKETOKEN"
Expand Down Expand Up @@ -824,16 +824,17 @@ def test_tb_init_with_default_region():
"""
# happy case
with patch(
"sagemaker.core.helper.session_helper.Session.boto_region_name", new_callable=PropertyMock
) as region_mock:
region_mock.return_value = TEST_REGION
"sagemaker.core.interactive_apps.base_interactive_app.Session"
) as session_mock:
session_mock.return_value.boto_region_name = TEST_REGION
tb_app = TensorBoardApp()
assert tb_app.region == TEST_REGION

# no default region configured
with patch(
"sagemaker.core.helper.session_helper.Session.boto_region_name", new_callable=PropertyMock
) as region_mock:
region_mock.side_effect = [ValueError()]
"sagemaker.core.interactive_apps.base_interactive_app.Session"
) as session_mock:
session_mock.return_value.boto_region_name = PropertyMock(side_effect=ValueError())
session_mock.side_effect = ValueError()
with pytest.raises(ValueError):
tb_app = TensorBoardApp()
Loading
Loading