Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sagemaker-core/src/sagemaker/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,8 @@
)
from sagemaker.core.transformer import Transformer # noqa: F401

# Partner App
from sagemaker.core.partner_app.auth_provider import PartnerAppAuthProvider # noqa: F401

# Note: HyperparameterTuner and WarmStartTypes are in sagemaker.train.tuner
# They are not re-exported from core to avoid circular dependencies
16 changes: 16 additions & 0 deletions sagemaker-core/src/sagemaker/core/partner_app/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""__init__ file for sagemaker.core.partner_app"""
from __future__ import absolute_import

from sagemaker.core.partner_app.auth_provider import PartnerAppAuthProvider # noqa: F401
129 changes: 129 additions & 0 deletions sagemaker-core/src/sagemaker/core/partner_app/auth_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

"""The SageMaker partner application SDK auth module"""
from __future__ import absolute_import

import os
import re
from typing import Dict, Tuple

import boto3
from botocore.auth import SigV4Auth
from botocore.credentials import Credentials
from requests.auth import AuthBase
from requests.models import PreparedRequest
from sagemaker.core.partner_app.auth_utils import PartnerAppAuthUtils

SERVICE_NAME = "sagemaker"
AWS_PARTNER_APP_ARN_REGEX = r"arn:aws[a-z\-]*:sagemaker:[a-z0-9\-]*:[0-9]{12}:partner-app\/.*"


class RequestsAuth(AuthBase):
"""Requests authentication class for SigV4 header generation.

This class is used to generate the SigV4 header and add it to the request headers.
"""

def __init__(self, sigv4: SigV4Auth, app_arn: str):
"""Initialize the RequestsAuth class.

Args:
sigv4 (SigV4Auth): SigV4Auth object
app_arn (str): Application ARN
"""
self.sigv4 = sigv4
self.app_arn = app_arn

def __call__(self, request: PreparedRequest) -> PreparedRequest:
"""Callback function to generate the SigV4 header and add it to the request headers.

Args:
request (PreparedRequest): PreparedRequest object

Returns:
PreparedRequest: PreparedRequest object with the SigV4 header added
"""
url, signed_headers = PartnerAppAuthUtils.get_signed_request(
sigv4=self.sigv4,
app_arn=self.app_arn,
url=request.url,
method=request.method,
headers=request.headers,
body=request.body,
)
request.url = url
request.headers.update(signed_headers)

return request


class PartnerAppAuthProvider:
"""The SageMaker partner application SDK auth provider class"""

def __init__(self, credentials: Credentials = None):
"""Initialize the PartnerAppAuthProvider class.

Args:
credentials (Credentials, optional): AWS credentials. Defaults to None.
Raises:
ValueError: If the AWS_PARTNER_APP_ARN environment variable is not set or is invalid.
"""
self.app_arn = os.getenv("AWS_PARTNER_APP_ARN")
if self.app_arn is None:
raise ValueError("Must specify the AWS_PARTNER_APP_ARN environment variable")

app_arn_regex_match = re.search(AWS_PARTNER_APP_ARN_REGEX, self.app_arn)
if app_arn_regex_match is None:
raise ValueError("Must specify a valid AWS_PARTNER_APP_ARN environment variable")

split_arn = self.app_arn.split(":")
self.region = split_arn[3]

self.credentials = (
credentials if credentials is not None else boto3.Session().get_credentials()
)
self.sigv4 = SigV4Auth(self.credentials, SERVICE_NAME, self.region)

def get_signed_request(
self, url: str, method: str, headers: dict, body: object
) -> Tuple[str, Dict[str, str]]:
"""Generate the SigV4 header and add it to the request headers.

Args:
url (str): Request URL
method (str): HTTP method
headers (dict): Request headers
body (object): Request body

Returns:
tuple: (url, headers)
"""
return PartnerAppAuthUtils.get_signed_request(
sigv4=self.sigv4,
app_arn=self.app_arn,
url=url,
method=method,
headers=headers,
body=body,
)

def get_auth(self) -> RequestsAuth:
"""Returns the callback class (RequestsAuth) used for generating the SigV4 header.

Returns:
RequestsAuth: Callback Object which will calculate the header just before
request submission.
"""

return RequestsAuth(self.sigv4, os.environ["AWS_PARTNER_APP_ARN"])
122 changes: 122 additions & 0 deletions sagemaker-core/src/sagemaker/core/partner_app/auth_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

"""Partner App Auth Utils Module"""

from __future__ import absolute_import

from hashlib import sha256
import functools
from typing import Tuple, Dict

from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest

HEADER_CONNECTION = "Connection"
HEADER_X_AMZ_TARGET = "X-Amz-Target"
HEADER_AUTHORIZATION = "Authorization"
HEADER_PARTNER_APP_SERVER_ARN = "X-SageMaker-Partner-App-Server-Arn"
HEADER_PARTNER_APP_AUTHORIZATION = "X-Amz-Partner-App-Authorization"
HEADER_X_AMZ_CONTENT_SHA_256 = "X-Amz-Content-SHA256"
CALL_PARTNER_APP_API_ACTION = "SageMaker.CallPartnerAppApi"

PAYLOAD_BUFFER = 1024 * 1024
EMPTY_SHA256_HASH = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
UNSIGNED_PAYLOAD = "UNSIGNED-PAYLOAD"


class PartnerAppAuthUtils:
"""Partner App Auth Utils Class"""

@staticmethod
def get_signed_request(
sigv4: SigV4Auth, app_arn: str, url: str, method: str, headers: dict, body: object
) -> Tuple[str, Dict[str, str]]:
"""Generate the SigV4 header and add it to the request headers.

Args:
sigv4 (SigV4Auth): SigV4Auth object
app_arn (str): Application ARN
url (str): Request URL
method (str): HTTP method
headers (dict): Request headers
body (object): Request body
Returns:
tuple: (url, headers)
"""
# Move API key to X-Amz-Partner-App-Authorization
if HEADER_AUTHORIZATION in headers:
headers[HEADER_PARTNER_APP_AUTHORIZATION] = headers[HEADER_AUTHORIZATION]

# App Arn
headers[HEADER_PARTNER_APP_SERVER_ARN] = app_arn

# IAM Action
headers[HEADER_X_AMZ_TARGET] = CALL_PARTNER_APP_API_ACTION

# Body
headers[HEADER_X_AMZ_CONTENT_SHA_256] = PartnerAppAuthUtils.get_body_header(body)

# Connection header is excluded from server-side signature calculation
connection_header = headers[HEADER_CONNECTION] if HEADER_CONNECTION in headers else None

if HEADER_CONNECTION in headers:
del headers[HEADER_CONNECTION]

# Spaces are encoded as %20
url = url.replace("+", "%20")

# Calculate SigV4 header
aws_request = AWSRequest(
method=method,
url=url,
headers=headers,
data=body,
)
sigv4.add_auth(aws_request)

# Reassemble headers
final_headers = dict(aws_request.headers.items())
if connection_header is not None:
final_headers[HEADER_CONNECTION] = connection_header

return (url, final_headers)

@staticmethod
def get_body_header(body: object):
"""Calculate the body header for the SigV4 header.

Args:
body (object): Request body
"""
if body and hasattr(body, "seek"):
position = body.tell()
read_chunksize = functools.partial(body.read, PAYLOAD_BUFFER)
checksum = sha256()
for chunk in iter(read_chunksize, b""):
checksum.update(chunk)
hex_checksum = checksum.hexdigest()
body.seek(position)
return hex_checksum

if body and not isinstance(body, bytes):
# Body is of a class we don't recognize, so don't sign the payload
return UNSIGNED_PAYLOAD

if body:
# The request serialization has ensured that
# request.body is a bytes() type.
return sha256(body).hexdigest()

# Body is None
return EMPTY_SHA256_HASH
13 changes: 13 additions & 0 deletions sagemaker-core/tests/unit/sagemaker/core/partner_app/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
Loading