Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions sagemaker-mlops/src/sagemaker/mlops/workflow/emr_serverless_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The step definitions for EMR Serverless workflow."""
from __future__ import absolute_import

from typing import Any, Dict, List, Union, Optional

from sagemaker.core.helper.pipeline_variable import RequestType
from sagemaker.core.workflow.properties import Properties
from sagemaker.mlops.workflow.retry import StepRetryPolicy
from sagemaker.mlops.workflow.step_collections import StepCollection
from sagemaker.mlops.workflow.steps import ConfigurableRetryStep, Step, StepTypeEnum, CacheConfig


class EMRServerlessJobConfig:
"""Config for EMR Serverless job."""

def __init__(
self,
job_driver: Dict,
execution_role_arn: str,
configuration_overrides: Optional[Dict] = None,
execution_timeout_minutes: Optional[int] = None,
name: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
): # pylint: disable=too-many-positional-arguments
"""Create a definition for EMR Serverless job configuration.

Args:
job_driver (Dict): The job driver for the job run.
execution_role_arn (str): The execution role ARN for the job run.
configuration_overrides (Dict, optional): Configuration overrides for the job run.
execution_timeout_minutes (int, optional): The maximum duration for the job run.
name (str, optional): The optional job run name.
tags (Dict[str, str], optional): The tags assigned to the job run.
"""
self.job_driver = job_driver
self.execution_role_arn = execution_role_arn
self.configuration_overrides = configuration_overrides
self.execution_timeout_minutes = execution_timeout_minutes
self.name = name
self.tags = tags

def to_request(self, application_id: Optional[str] = None) -> RequestType:
"""Convert EMRServerlessJobConfig object to request dict."""
config = {"executionRoleArn": self.execution_role_arn, "jobDriver": self.job_driver}
if application_id is not None:
config["applicationId"] = application_id
if self.configuration_overrides is not None:
config["configurationOverrides"] = self.configuration_overrides
if self.execution_timeout_minutes is not None:
config["executionTimeoutMinutes"] = self.execution_timeout_minutes
if self.name is not None:
config["name"] = self.name
if self.tags is not None:
config["tags"] = self.tags
return config


ERR_STR_WITH_BOTH_APP_ID_AND_APP_CONFIG = (
"EMRServerlessStep {step_name} cannot have both application_id and application_config. "
"To use EMRServerlessStep with application_config, "
"application_id must be explicitly set to None."
)

ERR_STR_WITHOUT_APP_ID_AND_APP_CONFIG = (
"EMRServerlessStep {step_name} must have either application_id or application_config"
)


class EMRServerlessStep(ConfigurableRetryStep):
"""EMR Serverless step for workflow with configurable retry policies."""

def __init__(
self,
name: str,
display_name: str,
description: str,
job_config: EMRServerlessJobConfig,
application_id: Optional[str] = None,
application_config: Optional[Dict[str, Any]] = None,
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
cache_config: Optional[CacheConfig] = None,
retry_policies: Optional[List[StepRetryPolicy]] = None,
): # pylint: disable=too-many-positional-arguments
"""Constructs an `EMRServerlessStep`.

Args:
name (str): The name of the EMR Serverless step.
display_name (str): The display name of the EMR Serverless step.
description (str): The description of the EMR Serverless step.
job_config (EMRServerlessJobConfig): Job configuration for the EMR Serverless job.
application_id (str, optional): The ID of the existing EMR Serverless application.
application_config (Dict[str, Any], optional): Configuration for creating a new
EMR Serverless application.
depends_on (List[Union[str, Step, StepCollection]], optional): A list of
`Step`/`StepCollection` names or `Step` instances or `StepCollection` instances
that this `EMRServerlessStep` depends on.
cache_config (CacheConfig, optional): A `sagemaker.workflow.steps.CacheConfig` instance.
retry_policies (List[StepRetryPolicy], optional): A list of retry policies.
"""
super().__init__(
name=name,
step_type=StepTypeEnum.EMR_SERVERLESS,
display_name=display_name,
description=description,
depends_on=depends_on,
retry_policies=retry_policies,
)

if application_id is None and application_config is None:
raise ValueError(ERR_STR_WITHOUT_APP_ID_AND_APP_CONFIG.format(step_name=name))

if application_id is not None and application_config is not None:
raise ValueError(ERR_STR_WITH_BOTH_APP_ID_AND_APP_CONFIG.format(step_name=name))

emr_serverless_args = {
"ExecutionRoleArn": job_config.execution_role_arn, # Top-level role (used by backend)
"JobConfig": job_config.to_request(
application_id
), # Role also in JobConfig (structure requirement)
}

if application_id is not None:
emr_serverless_args["ApplicationId"] = application_id
elif application_config is not None:
emr_serverless_args["ApplicationConfig"] = application_config

self.args = emr_serverless_args
self.cache_config = cache_config

root_property = Properties(
step_name=name, step=self, shape_name="GetJobRunResponse", service_name="emr-serverless"
)
self._properties = root_property

@property
def arguments(self) -> RequestType:
"""The arguments dict that is used to call EMR Serverless APIs."""
return self.args

@property
def properties(self) -> RequestType:
"""A Properties object representing the EMR Serverless GetJobRunResponse model."""
return self._properties

def to_request(self) -> RequestType:
"""Updates the dictionary with cache configuration and retry policies."""
request_dict = super().to_request()
if self.cache_config:
request_dict.update(self.cache_config.config)
return request_dict
23 changes: 13 additions & 10 deletions sagemaker-mlops/src/sagemaker/mlops/workflow/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import attr

from sagemaker.core.local.local_session import LocalSagemakerClient

# Primitive imports (stay in core)
from sagemaker.core.workflow.entities import Entity
from sagemaker.core.helper.pipeline_variable import RequestType
Expand All @@ -30,6 +31,7 @@
)
from sagemaker.core.helper.pipeline_variable import PipelineVariable
from sagemaker.core.workflow.functions import Join, JsonGet

# Orchestration imports (now in mlops)
from sagemaker.mlops.workflow.retry import RetryPolicy
from sagemaker.core.workflow.step_outputs import StepOutput
Expand Down Expand Up @@ -57,6 +59,7 @@ class StepTypeEnum(Enum):
QUALITY_CHECK = "QualityCheck"
CLARIFY_CHECK = "ClarifyCheck"
EMR = "EMR"
EMR_SERVERLESS = "EMRServerless"
FAIL = "Fail"
AUTOML = "AutoML"

Expand Down Expand Up @@ -417,6 +420,7 @@ def __init__(

if step_args:
from sagemaker.core.workflow.utilities import validate_step_args_input

# Lazy import to avoid circular dependency
from sagemaker.train.model_trainer import ModelTrainer

Expand All @@ -436,7 +440,7 @@ def __init__(
def arguments(self) -> RequestType:
"""The arguments dictionary that is used to call `create_training_job`.

NOTE: The `CreateTrainingJob` request is not quite the args list that workflow needs.
NOTE: The `CreateTrainingJob` request is not quite the args list that workflow needs.
"""
from sagemaker.core.workflow.utilities import execute_job_functions
from sagemaker.core.workflow.utilities import _pipeline_config
Expand All @@ -451,7 +455,7 @@ def arguments(self) -> RequestType:
request_dict = model_trainer.sagemaker_session.context.args
else:
raise ValueError("step_args input is required.")

if "HyperParameters" in request_dict:
request_dict["HyperParameters"].pop("sagemaker_job_name", None)

Expand Down Expand Up @@ -606,11 +610,13 @@ def __init__(
raise ValueError("step_args is required for ProcessingStep.")

from sagemaker.core.workflow.utilities import validate_step_args_input


validate_step_args_input(
step_args=step_args,
expected_caller={Processor.run.__name__, LocalSagemakerClient().create_processing_job.__name__},
expected_caller={
Processor.run.__name__,
LocalSagemakerClient().create_processing_job.__name__,
},
error_message=f"The step_args of ProcessingStep must be obtained from processor.run() or in local mode, not {step_args.caller_name}",
)

Expand Down Expand Up @@ -638,7 +644,7 @@ def arguments(self) -> RequestType:
# populate request dict with args
processor = self.step_args.func_args[0]
request_dict = processor.sagemaker_session.context.args

# Continue to pop job name if not explicitly opted-in via config
request_dict = trim_request_dict(request_dict, "ProcessingJobName", _pipeline_config)

Expand All @@ -663,8 +669,6 @@ def to_request(self) -> RequestType:
return request_dict




class TuningStep(ConfigurableRetryStep):
"""`TuningStep` for SageMaker Pipelines Workflows."""

Expand Down Expand Up @@ -698,7 +702,7 @@ def __init__(
name, StepTypeEnum.TUNING, display_name, description, depends_on, retry_policies
)

if not step_args :
if not step_args:
raise ValueError("step_args is required for TuningStep.")

from sagemaker.core.workflow.utilities import validate_step_args_input
Expand Down Expand Up @@ -737,7 +741,6 @@ def arguments(self) -> RequestType:
# populate request dict with args
tuner = self.step_args.func_args[0]
request_dict = tuner.sagemaker_session.context.args


# Continue to pop job name if not explicitly opted-in via config
request_dict = trim_request_dict(
Expand Down Expand Up @@ -785,4 +788,4 @@ def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = "") ->
self.properties.TrainingJobSummaries[top_k].TrainingJobName,
"output/model.tar.gz",
],
)
)
Loading
Loading