@@ -52,6 +52,8 @@


 class ImageRetriever:
+    _config = SageMakerConfig()
+
     @staticmethod
     def retrieve_hugging_face_uri(
         region: str,
@@ -110,7 +112,7 @@ def retrieve_hugging_face_uri(
         args = dict(locals())
         for name, val in args.items():
             if name in CONFIGURABLE_ATTRIBUTES and not val:
-                default_value = SageMakerConfig.resolve_value_from_config(
+                default_value = ImageRetriever._config.resolve_value_from_config(
                     config_path=_simple_path(
                         SAGEMAKER, MODULES, IMAGE_RETRIEVER, to_camel_case(name)
                     )
@@ -499,7 +501,7 @@ def retrieve(
         args = dict(locals())
         for name, val in args.items():
             if name in CONFIGURABLE_ATTRIBUTES and not val:
-                default_value = SageMakerConfig.resolve_value_from_config(
+                default_value = ImageRetriever._config.resolve_value_from_config(
                     config_path=_simple_path(
                         SAGEMAKER, MODULES, IMAGE_RETRIEVER, to_camel_case(name)
                     )
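Review note: the swap from the static `SageMakerConfig.resolve_value_from_config(...)` call to the shared class-level `_config` instance means the retriever resolves intelligent defaults through one config-manager object created at class-definition time. A minimal sketch of the pattern, with the config manager stubbed out — the stub and its stored values are assumptions for illustration; only the `config_path=` call shape comes from this diff:

```python
# Sketch of the class-level config pattern used above (stubbed, not the
# real sagemaker.core.config.config_manager implementation).
class SageMakerConfig:
    def __init__(self):
        # Pretend an admin config file was loaded at construction time.
        self._values = {"SageMaker.Modules.ImageRetriever.ImageScope": "inference"}

    def resolve_value_from_config(self, config_path=None):
        return self._values.get(config_path)


class ImageRetriever:
    # One shared config manager, created once when the class is defined,
    # instead of resolving through the class on every call.
    _config = SageMakerConfig()

    @staticmethod
    def retrieve(image_scope=None):
        if not image_scope:
            image_scope = ImageRetriever._config.resolve_value_from_config(
                config_path="SageMaker.Modules.ImageRetriever.ImageScope"
            )
        return image_scope


print(ImageRetriever.retrieve())  # -> "inference", filled from the config default
```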
@@ -43,7 +43,7 @@
 REMOTE_FUNCTION_WORKSPACE = "sm_rf_user_ws"
 BASE_CHANNEL_PATH = "/opt/ml/input/data"
 FAILURE_REASON_PATH = "/opt/ml/output/failure"
-JOB_OUTPUT_DIRS = ["/opt/ml/input", "/opt/ml/output", "/opt/ml/model", "/tmp"]
+JOB_OUTPUT_DIRS = ["/opt/ml/input", "/opt/ml/output", "/opt/ml/model", "/tmp", "/opt/ml/checkpoints"]
 PRE_EXECUTION_SCRIPT_NAME = "pre_exec.sh"
 JOB_REMOTE_FUNCTION_WORKSPACE = "sagemaker_remote_function_workspace"
 SCRIPT_AND_DEPENDENCIES_CHANNEL_NAME = "pre_exec_script_and_dependencies"
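Adding `/opt/ml/checkpoints` to `JOB_OUTPUT_DIRS` puts the standard SageMaker checkpoint directory alongside the input/output/model paths the remote-function runtime already treats as job directories. A hedged sketch of a function that relies on that path being writable inside the job container — the `@remote` import path follows the classic SageMaker Python SDK, and the checkpoint payload is purely illustrative:

```python
import json
import os

from sagemaker.remote_function import remote  # import path assumed from the classic SDK

CHECKPOINT_DIR = "/opt/ml/checkpoints"  # newly listed in JOB_OUTPUT_DIRS


@remote(instance_type="ml.m5.xlarge")
def train_with_checkpoints(epochs: int) -> int:
    # Write periodic state under the checkpoint path; with the constant
    # updated, the runtime treats this directory like the other job dirs.
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    for epoch in range(epochs):
        with open(os.path.join(CHECKPOINT_DIR, f"epoch-{epoch}.json"), "w") as f:
            json.dump({"epoch": epoch}, f)  # illustrative payload only
    return epochs
```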
@@ -12,7 +12,6 @@
 from sagemaker.core.config.config_manager import SageMakerConfig


-@pytest.mark.skip("Disabling this for now, Need to be fixed")
 @pytest.mark.integ
 def test_retrieve_image_uri():
     image_uri = ImageRetriever.retrieve("clarify", "us-west-2")
@@ -56,7 +55,6 @@ def test_retrieve_image_uri():
     )


-@pytest.mark.skip("Disabling this for now, Need to be fixed")
 @pytest.mark.integ
 def test_retrieve_pytorch_uri():
     image_uri = ImageRetriever.retrieve_pytorch_uri(
@@ -72,7 +70,6 @@ def test_retrieve_pytorch_uri():
     )


-@pytest.mark.skip("Disabling this for now, Need to be fixed")
 @pytest.mark.integ
 def test_retrieve_hugging_face_uri():
     image_uri = ImageRetriever.retrieve_hugging_face_uri(
@@ -88,18 +85,16 @@ def test_retrieve_hugging_face_uri():
         ":2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04"


-@pytest.mark.skip("Disabling this for now, Need to be fixed")
 @pytest.mark.integ
 def test_retrieve_base_python_image_uri():
-    image_uri = ImageRetriever.retrieve_base_python_image_uri()
+    image_uri = ImageRetriever.retrieve_base_python_image_uri(region="us-west-2")
     assert image_uri == "236514542706.dkr.ecr.us-west-2.amazonaws.com/sagemaker-base-python-310:1.0"


-@pytest.mark.skip("Disabling this for now, Need to be fixed")
 @pytest.mark.integ
 @patch.object(SageMakerConfig, "resolve_value_from_config")
 def test_retrieve_image_uri_intelligent_default(mock_load_config):
-    def custom_return(config_path):
+    def custom_return(config_path=None, **kwargs):
         if config_path == _simple_path(
             SAGEMAKER, PYTHON_SDK, MODULES, IMAGE_RETRIEVER, "ImageScope"
         ):
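The widened `custom_return(config_path=None, **kwargs)` signature matters because `patch.object(SageMakerConfig, "resolve_value_from_config")` replaces the method with a plain `MagicMock`, which is not a descriptor: calls through the `ImageRetriever._config` instance arrive without `self` and with `config_path` passed by keyword, so the `side_effect` must accept it as a keyword and tolerate extras. A self-contained sketch (the stub class and `lookup` helper are hypothetical stand-ins):

```python
from unittest.mock import patch


class SageMakerConfig:  # hypothetical stand-in for the real config manager
    def resolve_value_from_config(self, config_path=None):
        raise RuntimeError("real lookup; never reached in the test")


def lookup(path):
    # Mirrors the production call shape: instance call, keyword argument.
    return SageMakerConfig().resolve_value_from_config(config_path=path)


def custom_return(config_path=None, **kwargs):
    # Keyword-friendly side_effect; **kwargs absorbs any future arguments.
    return {"SageMaker.Modules.ImageRetriever.ImageScope": "training"}.get(config_path)


with patch.object(SageMakerConfig, "resolve_value_from_config") as mock_resolve:
    mock_resolve.side_effect = custom_return
    assert lookup("SageMaker.Modules.ImageRetriever.ImageScope") == "training"
```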
4 changes: 0 additions & 4 deletions sagemaker-core/tests/integ/remote_function/test_decorator.py
@@ -118,8 +118,6 @@ def divide(x, y):
     divide(10, 2)


-# TODO: add VPC settings, update SageMakerRole with KMS permissions
-@pytest.mark.skip
 def test_advanced_job_setting(
     sagemaker_session, dummy_container_without_error, cpu_instance_type, s3_kms_key
 ):
@@ -552,7 +550,6 @@ def my_func():
     assert client_error_message in str(error)


-@pytest.mark.skip
 def test_decorator_with_spark_job(sagemaker_session, cpu_instance_type):
     @remote(
         role=ROLE,
@@ -578,7 +575,6 @@ def test_spark_transform():
     test_spark_transform()


-@pytest.mark.skip
 def test_decorator_auto_capture(sagemaker_session, auto_capture_test_container):
     """
     This test runs a docker container. The Container invocation will execute a python script
9 changes: 0 additions & 9 deletions sagemaker-train/tests/integ/train/test_benchmark_evaluator.py
@@ -72,7 +72,6 @@
 }


-@pytest.mark.skip(reason="Temporarily skipped - moved from tests/integ/sagemaker/modules/evaluate/")
 class TestBenchmarkEvaluatorIntegration:
     """Integration tests for BenchmarkEvaluator with fine-tuned model package"""

@@ -286,16 +285,12 @@ def test_benchmark_subtasks_validation(self):

         logger.info("Subtask validation tests passed")

-    @pytest.mark.skip(reason="Base model only evaluation - to be enabled when needed")
     def test_benchmark_evaluation_base_model_only(self):
         """
         Test benchmark evaluation with base model only (no fine-tuned model).

         This test uses a JumpStart model ID directly instead of a model package ARN.
         Configuration from commented section in benchmark_demo.ipynb.
-
-        Note: This test is currently skipped. Remove the @pytest.mark.skip decorator
-        when you want to enable it.
         """
         # Get benchmarks
         Benchmark = get_benchmarks()
@@ -339,16 +334,12 @@ def test_benchmark_evaluation_base_model_only(self):
         assert execution.status.overall_status == "Succeeded"
         logger.info("Base model only evaluation completed successfully")

-    @pytest.mark.skip(reason="Nova model evaluation - to be enabled when needed")
     def test_benchmark_evaluation_nova_model(self):
         """
         Test benchmark evaluation with Nova model.

         This test uses a Nova fine-tuned model package in us-east-1 region.
         Configuration from commented section in benchmark_demo.ipynb.
-
-        Note: This test is currently skipped. Remove the @pytest.mark.skip decorator
-        when you want to enable it.
         """
         # Get benchmarks
         Benchmark = get_benchmarks()
@@ -55,7 +55,6 @@
 }


-@pytest.mark.skip(reason="Temporarily skipped - moved from tests/integ/sagemaker/modules/evaluate/")
 class TestCustomScorerEvaluatorIntegration:
     """Integration tests for CustomScorerEvaluator with custom evaluator"""

@@ -233,16 +232,12 @@ def test_custom_scorer_evaluator_validation(self):

         logger.info("Validation tests passed")

-    @pytest.mark.skip(reason="Built-in metric evaluation - to be enabled when needed")
     def test_custom_scorer_with_builtin_metric(self):
         """
         Test custom scorer evaluation with built-in metric.

         This test uses a built-in metric (PRIME_MATH) instead of a custom evaluator ARN.
         Configuration adapted from commented section in custom_scorer_demo.ipynb.
-
-        Note: This test is currently skipped. Remove the @pytest.mark.skip decorator
-        when you want to enable it.
         """
         # Get built-in metrics
         BuiltInMetric = get_builtin_metrics()
@@ -285,7 +280,6 @@ def test_custom_scorer_with_builtin_metric(self):
         assert execution.status.overall_status == "Succeeded"
         logger.info("Built-in metric evaluation completed successfully")

-    @pytest.mark.skip(reason="Base model only evaluation - not working yet per notebook")
     def test_custom_scorer_base_model_only(self):
         """
         Test custom scorer evaluation with base model only (no fine-tuned model).
188 changes: 116 additions & 72 deletions sagemaker-train/tests/integ/train/test_dpo_trainer_integration.py
@@ -14,87 +14,131 @@
 from __future__ import absolute_import

 import time
+import logging
+import traceback
+import random
+import boto3
+from sagemaker.core.helper.session_helper import Session
 from sagemaker.train.dpo_trainer import DPOTrainer
 from sagemaker.train.common import TrainingType
 import pytest

+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.DEBUG)


-@pytest.mark.skip(reason="Skipping GPU resource intensive test")
 def test_dpo_trainer_lora_complete_workflow(sagemaker_session):
     """Test complete DPO training workflow with LORA."""
-    # Create DPOTrainer instance with comprehensive configuration
-    trainer = DPOTrainer(
-        model="meta-textgeneration-llama-3-2-1b-instruct",
-        training_type=TrainingType.LORA,
-        model_package_group="sdk-test-finetuned-models",
-        training_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/dpo-oss-test-data/0.0.1",
-        s3_output_path="s3://mc-flows-sdk-testing/output/",
-        accept_eula=True
-    )
-
-    # Customize hyperparameters for quick training
-    trainer.hyperparameters.max_epochs = 1
-
-    # Create training job
-    training_job = trainer.train(wait=False)
-
-    # Manual wait loop to avoid resource_config issue
-    max_wait_time = 3600  # 1 hour timeout
-    poll_interval = 30  # Check every 30 seconds
-    start_time = time.time()
-
-    while time.time() - start_time < max_wait_time:
-        training_job.refresh()
-        status = training_job.training_job_status
-
-        if status in ["Completed", "Failed", "Stopped"]:
-            break
-
-        time.sleep(poll_interval)
-
-    # Verify job completed successfully
-    assert training_job.training_job_status == "Completed"
-    assert hasattr(training_job, 'output_model_package_arn')
-    assert training_job.output_model_package_arn is not None
-
-
-@pytest.mark.skip(reason="Skipping GPU resource intensive test")
logger.info("=== START test_dpo_trainer_lora_complete_workflow ===")
logger.info(f"sagemaker_session region: {sagemaker_session.boto_region_name}")

try:
# Create DPOTrainer instance with comprehensive configuration
trainer = DPOTrainer(
model="meta-textgeneration-llama-3-2-1b-instruct",
training_type=TrainingType.LORA,
model_package_group="sdk-test-finetuned-models",
training_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/dpo-oss-test-data/0.0.1",
s3_output_path="s3://mc-flows-sdk-testing/output/",
accept_eula=True
)
logger.info(f"DPOTrainer created: model={trainer.model}, training_type={trainer.training_type}")

# Customize hyperparameters for quick training
trainer.hyperparameters.max_epochs = 1
logger.info(f"Set max_epochs=1")

# Create training job
logger.info("Calling trainer.train(wait=False)...")
training_job = trainer.train(wait=False)
logger.info(f"Training job created: {training_job}")

# Manual wait loop to avoid resource_config issue
max_wait_time = 3600 # 1 hour timeout
poll_interval = 30 # Check every 30 seconds
start_time = time.time()

while time.time() - start_time < max_wait_time:
training_job.refresh()
status = training_job.training_job_status
elapsed = int(time.time() - start_time)
logger.info(f"[{elapsed}s] Training job status: {status}")

if status in ["Completed", "Failed", "Stopped"]:
break

time.sleep(poll_interval)

logger.info(f"Final training job status: {training_job.training_job_status}")
if training_job.training_job_status == "Failed":
failure_reason = getattr(training_job, 'failure_reason', 'N/A')
logger.error(f"Training job FAILED. Failure reason: {failure_reason}")

# Verify job completed successfully
assert training_job.training_job_status == "Completed"
assert hasattr(training_job, 'output_model_package_arn')
assert training_job.output_model_package_arn is not None
logger.info(f"output_model_package_arn: {training_job.output_model_package_arn}")
except Exception as e:
logger.error(f"test_dpo_trainer_lora_complete_workflow FAILED: {type(e).__name__}: {e}")
logger.error(traceback.format_exc())
raise
logger.info("=== END test_dpo_trainer_lora_complete_workflow - PASSED ===")


 def test_dpo_trainer_with_validation_dataset(sagemaker_session):
     """Test DPO trainer with both training and validation datasets."""

-    dpo_trainer = DPOTrainer(
-        model="meta-textgeneration-llama-3-2-1b-instruct",
-        training_type=TrainingType.LORA,
-        model_package_group="sdk-test-finetuned-models",
-        training_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/dpo-oss-test-data/0.0.1",
-        validation_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/dpo-oss-test-data/0.0.1",
-        s3_output_path="s3://mc-flows-sdk-testing/output/",
-        accept_eula=True
-    )
-
-    # Customize hyperparameters for quick training
-    dpo_trainer.hyperparameters.max_epochs = 1
-
-    training_job = dpo_trainer.train(wait=False)
-
-    # Manual wait loop
-    max_wait_time = 3600
-    poll_interval = 30
-    start_time = time.time()
-
-    while time.time() - start_time < max_wait_time:
-        training_job.refresh()
-        status = training_job.training_job_status
-
-        if status in ["Completed", "Failed", "Stopped"]:
-            break
-
-        time.sleep(poll_interval)
-
-    # Verify job completed successfully
-    assert training_job.training_job_status == "Completed"
-    assert hasattr(training_job, 'output_model_package_arn')
-    assert training_job.output_model_package_arn is not None
logger.info("=== START test_dpo_trainer_with_validation_dataset ===")
logger.info(f"sagemaker_session region: {sagemaker_session.boto_region_name}")

try:
dpo_trainer = DPOTrainer(
model="meta-textgeneration-llama-3-2-1b-instruct",
training_type=TrainingType.LORA,
model_package_group="sdk-test-finetuned-models",
training_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/dpo-oss-test-data/0.0.1",
validation_dataset="arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/DataSet/dpo-oss-test-data/0.0.1",
s3_output_path="s3://mc-flows-sdk-testing/output/",
accept_eula=True
)
logger.info(f"DPOTrainer created with validation dataset")

# Customize hyperparameters for quick training
dpo_trainer.hyperparameters.max_epochs = 1
logger.info(f"Set max_epochs=1")

logger.info("Calling dpo_trainer.train(wait=False)...")
training_job = dpo_trainer.train(wait=False)
logger.info(f"Training job created: {training_job}")

# Manual wait loop
max_wait_time = 3600
poll_interval = 30
start_time = time.time()

while time.time() - start_time < max_wait_time:
training_job.refresh()
status = training_job.training_job_status
elapsed = int(time.time() - start_time)
logger.info(f"[{elapsed}s] Training job status: {status}")

if status in ["Completed", "Failed", "Stopped"]:
break

time.sleep(poll_interval)

logger.info(f"Final training job status: {training_job.training_job_status}")
if training_job.training_job_status == "Failed":
failure_reason = getattr(training_job, 'failure_reason', 'N/A')
logger.error(f"Training job FAILED. Failure reason: {failure_reason}")

# Verify job completed successfully
assert training_job.training_job_status == "Completed"
assert hasattr(training_job, 'output_model_package_arn')
assert training_job.output_model_package_arn is not None
logger.info(f"output_model_package_arn: {training_job.output_model_package_arn}")
except Exception as e:
logger.error(f"test_dpo_trainer_with_validation_dataset FAILED: {type(e).__name__}: {e}")
logger.error(traceback.format_exc())
raise
logger.info("=== END test_dpo_trainer_with_validation_dataset - PASSED ===")