Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions sagemaker-core/src/sagemaker/core/local/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from sagemaker.core.config.config_schema import CONTAINER_CONFIG, LOCAL
import sagemaker.core
from sagemaker.core.common_utils import custom_extractall_tarfile
from sagemaker.serve.model_builder import DIR_PARAM_NAME, SAGEMAKER_OUTPUT_LOCATION

CONTAINER_PREFIX = "algo"
STUDIO_HOST_NAME = "sagemaker-local"
Expand Down Expand Up @@ -277,7 +278,7 @@ def train(self, input_data_config, output_data_config, hyperparameters, environm
)
# If local, source directory needs to be updated to mounted /opt/ml/code path
hyperparameters = self._update_local_src_path(
hyperparameters, key=sagemaker.serve.model_builder.DIR_PARAM_NAME
hyperparameters, key=DIR_PARAM_NAME
)

# Create the configuration files for each container that we will create
Expand Down Expand Up @@ -343,15 +344,15 @@ def serve(self, model_dir, environment):
volumes = self._prepare_serving_volumes(model_dir)

# If the user script was passed as a file:// mount it to the container.
if sagemaker.serve.model_builder.DIR_PARAM_NAME.upper() in environment:
script_dir = environment[sagemaker.serve.model_builder.DIR_PARAM_NAME.upper()]
if DIR_PARAM_NAME.upper() in environment:
script_dir = environment[DIR_PARAM_NAME.upper()]
parsed_uri = urlparse(script_dir)
if parsed_uri.scheme == "file":
host_dir = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
volumes.append(_Volume(host_dir, "/opt/ml/code"))
# Update path to mount location
environment = environment.copy()
environment[sagemaker.serve.model_builder.DIR_PARAM_NAME.upper()] = "/opt/ml/code"
environment[DIR_PARAM_NAME.upper()] = "/opt/ml/code"

if _ecr_login_if_needed(self.sagemaker_session.boto_session, self.image):
_pull_image(self.image)
Expand Down Expand Up @@ -582,8 +583,8 @@ def _prepare_training_volumes(

# If there is a training script directory and it is a local directory,
# mount it to the container.
if sagemaker.serve.model_builder.DIR_PARAM_NAME in hyperparameters:
training_dir = json.loads(hyperparameters[sagemaker.serve.model_builder.DIR_PARAM_NAME])
if DIR_PARAM_NAME in hyperparameters:
training_dir = json.loads(hyperparameters[DIR_PARAM_NAME])
parsed_uri = urlparse(training_dir)
if parsed_uri.scheme == "file":
host_dir = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
Expand All @@ -594,7 +595,7 @@ def _prepare_training_volumes(
parsed_uri = urlparse(output_data_config["S3OutputPath"])
if (
parsed_uri.scheme == "file"
and sagemaker.serve.model_builder.SAGEMAKER_OUTPUT_LOCATION in hyperparameters
and SAGEMAKER_OUTPUT_LOCATION in hyperparameters
):
dir_path = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
intermediate_dir = os.path.join(dir_path, "output", "intermediate")
Expand Down
Loading