Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import os
import re
import shutil
import stat
import subprocess
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Optional
Expand Down Expand Up @@ -58,15 +57,34 @@
SM_STUDIO_LOCAL_MODE = "SM_STUDIO_LOCAL_MODE"


def _rmtree(path):
def _rmtree(path, image=None, is_studio=False):
"""Remove a directory tree, handling root-owned files from Docker containers."""
def _onerror(func, path, exc_info):
if isinstance(exc_info[1], PermissionError):
os.chmod(path, stat.S_IRWXU)
func(path)
else:
raise exc_info[1]
shutil.rmtree(path, onerror=_onerror)
try:
shutil.rmtree(path)
except PermissionError:
# Files created by Docker containers are owned by root.
# Use docker to chmod as root, then retry shutil.rmtree.
if image is None:
logger.warning(
"Failed to clean up root-owned files in %s. "
"You may need to remove them manually with: sudo rm -rf %s",
path, path,
)
raise
try:
cmd = ["docker", "run", "--rm"]
if is_studio:
cmd += ["--network", "sagemaker"]
cmd += ["-v", f"{path}:/delete", image, "chmod", "-R", "777", "/delete"]
subprocess.run(cmd, check=True, capture_output=True)
shutil.rmtree(path)
except Exception:
logger.warning(
"Failed to clean up root-owned files in %s. "
"You may need to remove them manually with: sudo rm -rf %s",
path, path,
)
raise


class _LocalContainer(BaseModel):
Expand Down Expand Up @@ -221,12 +239,12 @@ def train(
# Print our Job Complete line
logger.info("Local training job completed, output artifacts saved to %s", artifacts)

_rmtree(os.path.join(self.container_root, "input"))
_rmtree(os.path.join(self.container_root, "shared"))
_rmtree(os.path.join(self.container_root, "input"), self.image, self.is_studio)
_rmtree(os.path.join(self.container_root, "shared"), self.image, self.is_studio)
for host in self.hosts:
_rmtree(os.path.join(self.container_root, host))
_rmtree(os.path.join(self.container_root, host), self.image, self.is_studio)
for folder in self._temporary_folders:
_rmtree(os.path.join(self.container_root, folder))
_rmtree(os.path.join(self.container_root, folder), self.image, self.is_studio)
return artifacts

def retrieve_artifacts(
Expand Down
44 changes: 31 additions & 13 deletions sagemaker-train/src/sagemaker/train/local/local_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import os
import re
import shutil
import stat
import subprocess
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Optional
Expand Down Expand Up @@ -66,15 +65,34 @@
SM_STUDIO_LOCAL_MODE = "SM_STUDIO_LOCAL_MODE"


def _rmtree(path):
def _rmtree(path, image=None, is_studio=False):
"""Remove a directory tree, handling root-owned files from Docker containers."""
def _onerror(func, path, exc_info):
if isinstance(exc_info[1], PermissionError):
os.chmod(path, stat.S_IRWXU)
func(path)
else:
raise exc_info[1]
shutil.rmtree(path, onerror=_onerror)
try:
shutil.rmtree(path)
except PermissionError:
# Files created by Docker containers are owned by root.
# Use docker to chmod as root, then retry shutil.rmtree.
if image is None:
logger.warning(
"Failed to clean up root-owned files in %s. "
"You may need to remove them manually with: sudo rm -rf %s",
path, path,
)
raise
try:
cmd = ["docker", "run", "--rm"]
if is_studio:
cmd += ["--network", "sagemaker"]
cmd += ["-v", f"{path}:/delete", image, "chmod", "-R", "777", "/delete"]
subprocess.run(cmd, check=True, capture_output=True)
shutil.rmtree(path)
except Exception:
logger.warning(
"Failed to clean up root-owned files in %s. "
"You may need to remove them manually with: sudo rm -rf %s",
path, path,
)
raise


class _LocalContainer(BaseModel):
Expand Down Expand Up @@ -229,12 +247,12 @@ def train(
# Print our Job Complete line
logger.info("Local training job completed, output artifacts saved to %s", artifacts)

_rmtree(os.path.join(self.container_root, "input"))
_rmtree(os.path.join(self.container_root, "shared"))
_rmtree(os.path.join(self.container_root, "input"), self.image, self.is_studio)
_rmtree(os.path.join(self.container_root, "shared"), self.image, self.is_studio)
for host in self.hosts:
_rmtree(os.path.join(self.container_root, host))
_rmtree(os.path.join(self.container_root, host), self.image, self.is_studio)
for folder in self._temporary_folders:
_rmtree(os.path.join(self.container_root, folder))
_rmtree(os.path.join(self.container_root, folder), self.image, self.is_studio)
return artifacts

def retrieve_artifacts(
Expand Down
80 changes: 80 additions & 0 deletions sagemaker-train/tests/unit/train/local/test_local_container.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from unittest.mock import patch, call
import pytest

from sagemaker.train.local.local_container import _rmtree

IMAGE = "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.1-cpu-py310"


class TestRmtree:
"""Test cases for _rmtree function."""

@patch("sagemaker.train.local.local_container.shutil.rmtree")
def test_rmtree_success(self, mock_rmtree):
"""Normal case — shutil.rmtree succeeds."""
_rmtree("/tmp/test", IMAGE)
mock_rmtree.assert_called_once_with("/tmp/test")

@patch("sagemaker.train.local.local_container.shutil.rmtree")
@patch("sagemaker.train.local.local_container.subprocess.run")
def test_rmtree_permission_error_docker_chmod_fallback(self, mock_run, mock_rmtree):
"""PermissionError triggers docker chmod then retry."""
mock_rmtree.side_effect = [PermissionError("Permission denied"), None]

_rmtree("/tmp/test", IMAGE)

mock_run.assert_called_once_with(
["docker", "run", "--rm", "-v", "/tmp/test:/delete", IMAGE, "chmod", "-R", "777", "/delete"],
check=True,
capture_output=True,
)
assert mock_rmtree.call_count == 2

@patch("sagemaker.train.local.local_container.shutil.rmtree")
@patch("sagemaker.train.local.local_container.subprocess.run")
def test_rmtree_studio_adds_network(self, mock_run, mock_rmtree):
"""In Studio, docker run includes --network sagemaker."""
mock_rmtree.side_effect = [PermissionError("Permission denied"), None]

_rmtree("/tmp/test", IMAGE, is_studio=True)

mock_run.assert_called_once_with(
[
"docker", "run", "--rm",
"--network", "sagemaker",
"-v", "/tmp/test:/delete", IMAGE,
"chmod", "-R", "777", "/delete",
],
check=True,
capture_output=True,
)

@patch("sagemaker.train.local.local_container.shutil.rmtree")
@patch("sagemaker.train.local.local_container.subprocess.run")
def test_rmtree_docker_fallback_fails_raises(self, mock_run, mock_rmtree):
"""If docker fallback also fails, the exception propagates."""
mock_rmtree.side_effect = PermissionError("Permission denied")
mock_run.side_effect = Exception("docker failed")

with pytest.raises(Exception, match="docker failed"):
_rmtree("/tmp/test", IMAGE)

@patch("sagemaker.train.local.local_container.shutil.rmtree")
def test_rmtree_no_image_raises(self, mock_rmtree):
"""PermissionError without image raises immediately."""
mock_rmtree.side_effect = PermissionError("Permission denied")

with pytest.raises(PermissionError):
_rmtree("/tmp/test")
Loading