Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sagemaker-serve/src/sagemaker/serve/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3621,6 +3621,12 @@ def from_jumpstart_config(
)
mb_instance.inference_ami_version = deploy_kwargs.get("inference_ami_version")

# Apply network isolation from JumpStart model spec if not set by user via network param
if not mb_instance._enable_network_isolation and deploy_kwargs.get(
"enable_network_isolation"
):
mb_instance._enable_network_isolation = deploy_kwargs["enable_network_isolation"]

return mb_instance

@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="model_builder.transformer")
Expand Down
7 changes: 7 additions & 0 deletions sagemaker-serve/src/sagemaker/serve/model_builder_servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,13 @@ def _build_for_jumpstart(self) -> Model:
if hasattr(init_kwargs, "env") and init_kwargs.env:
self.env_vars.update(init_kwargs.env)

# Apply network isolation from JumpStart model spec if not already set by user
if (
not self._enable_network_isolation
and getattr(init_kwargs, "enable_network_isolation", None) is not None
):
self._enable_network_isolation = init_kwargs.enable_network_isolation

# Handle model artifacts for fine-tuned models
if hasattr(init_kwargs, "model_data") and init_kwargs.model_data:
if (
Expand Down
67 changes: 67 additions & 0 deletions sagemaker-serve/tests/integ/test_jumpstart_network_isolation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import uuid
import pytest
import logging

from sagemaker.serve.model_builder import ModelBuilder
from sagemaker.core.jumpstart.configs import JumpStartConfig
from sagemaker.train.configs import Compute

logger = logging.getLogger(__name__)

MODEL_ID = "huggingface-llm-falcon-7b-bf16"
INSTANCE_TYPE = "ml.g5.2xlarge"
MODEL_NAME_PREFIX = "js-netiso-test"


@pytest.mark.slow_test
def test_jumpstart_build_enables_network_isolation():
"""Integration test verifying JumpStart models are built with EnableNetworkIsolation.

JumpStart model specs define inference_enable_network_isolation=True for most models.
This test validates that ModelBuilder.build() propagates this setting to the
SageMaker Model resource, matching v2 behavior.
"""
logger.info("Starting JumpStart network isolation integration test...")

compute = Compute(instance_type=INSTANCE_TYPE)
jumpstart_config = JumpStartConfig(model_id=MODEL_ID)
model_builder = ModelBuilder.from_jumpstart_config(
jumpstart_config=jumpstart_config, compute=compute
)
unique_id = str(uuid.uuid4())[:8]

core_model = model_builder.build(model_name=f"{MODEL_NAME_PREFIX}-{unique_id}")
logger.info(f"Model created: {core_model.model_name}")

try:
# Verify ModelBuilder picked up network isolation from spec
assert model_builder._enable_network_isolation, (
f"ModelBuilder._enable_network_isolation should be True for {MODEL_ID}, "
f"got {model_builder._enable_network_isolation}"
)

# Verify the actual SageMaker Model resource has EnableNetworkIsolation=True
sm_client = model_builder.sagemaker_session.sagemaker_client
desc = sm_client.describe_model(ModelName=core_model.model_name)
assert desc.get("EnableNetworkIsolation") is True, (
f"SageMaker Model should have EnableNetworkIsolation=True, "
f"got {desc.get('EnableNetworkIsolation')}"
)

logger.info("✅ Network isolation correctly applied to SageMaker Model")
finally:
core_model.delete()
logger.info("Model deleted.")
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(self):
self.framework_version = None
self._is_mlflow_model = False
self.config_name = None
self._enable_network_isolation = False

def _deploy_local_endpoint(self, **kwargs):
return Mock()
Expand Down
28 changes: 28 additions & 0 deletions sagemaker-serve/tests/unit/test_model_builder_coverage_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,34 @@ def test_from_jumpstart_config_basic(self):
self.assertEqual(mb.model, "test-model")
self.assertEqual(mb.model_version, "1.0.0")

@patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs")
def test_from_jumpstart_config_applies_network_isolation(self, mock_deploy_kwargs):
"""Test that enable_network_isolation from deploy kwargs is applied."""
from sagemaker.core.jumpstart.configs import JumpStartConfig
from sagemaker.core.training.configs import Compute

mock_deploy_kwargs.return_value = {
"model_data_download_timeout": 600,
"enable_network_isolation": True,
}

js_config = JumpStartConfig(
model_id="test-model",
model_version="1.0.0"
)

mock_session = Mock()
mock_session.boto_region_name = "us-west-2"

mb = ModelBuilder.from_jumpstart_config(
jumpstart_config=js_config,
role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
compute=Compute(instance_type="ml.g5.xlarge"),
sagemaker_session=mock_session,
)

self.assertTrue(mb._enable_network_isolation)


if __name__ == "__main__":
unittest.main()
65 changes: 65 additions & 0 deletions sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,71 @@ def test_build_for_jumpstart_routes_to_mms(self, mock_prepare, mock_create, mock
mock_create.assert_called_once()


@patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs")
@patch("sagemaker.serve.model_builder.ModelBuilder._create_model")
@patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode")
def test_build_for_jumpstart_applies_network_isolation_from_spec(
self, mock_prepare, mock_create, mock_get_kwargs
):
"""Test that enable_network_isolation from JumpStart model spec is applied."""
mock_init_kwargs = Mock()
mock_init_kwargs.image_uri = (
"763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117"
)
mock_init_kwargs.env = {}
mock_init_kwargs.model_data = "s3://jumpstart-cache/models/model.tar.gz"
mock_init_kwargs.enable_network_isolation = True
mock_get_kwargs.return_value = mock_init_kwargs

mock_model = Mock(spec=Model)
mock_create.return_value = mock_model

builder = ModelBuilder(
model="meta-textgeneration-llama-3-8b",
role_arn=MOCK_ROLE_ARN,
sagemaker_session=self.mock_session,
mode=Mode.SAGEMAKER_ENDPOINT,
)
builder._optimizing = False

builder._build_for_jumpstart()

self.assertTrue(builder._enable_network_isolation)

@patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs")
@patch("sagemaker.serve.model_builder.ModelBuilder._create_model")
@patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode")
def test_build_for_jumpstart_does_not_override_user_network_isolation(
self, mock_prepare, mock_create, mock_get_kwargs
):
"""Test that user-set network isolation is not overridden by spec."""
mock_init_kwargs = Mock()
mock_init_kwargs.image_uri = (
"763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117"
)
mock_init_kwargs.env = {}
mock_init_kwargs.model_data = "s3://jumpstart-cache/models/model.tar.gz"
mock_init_kwargs.enable_network_isolation = False
mock_get_kwargs.return_value = mock_init_kwargs

mock_model = Mock(spec=Model)
mock_create.return_value = mock_model

builder = ModelBuilder(
model="meta-textgeneration-llama-3-8b",
role_arn=MOCK_ROLE_ARN,
sagemaker_session=self.mock_session,
mode=Mode.SAGEMAKER_ENDPOINT,
)
builder._optimizing = False
builder._enable_network_isolation = True # User explicitly set

builder._build_for_jumpstart()

# User's True should not be overridden by spec's False
self.assertTrue(builder._enable_network_isolation)


class TestDeployWrappers(unittest.TestCase):
"""Test deploy wrapper methods."""

Expand Down
Loading