Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions nemo_run/core/execution/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,8 @@ class ResourceRequest:
het_group_indices: Optional[list[int]] = None
segment: Optional[int] = None
network: Optional[str] = None
#: Template name to use for Ray jobs (e.g., "ray.sub.j2" or "ray_enroot.sub.j2")
ray_template: str = "ray.sub.j2"

#: Set by the executor; cannot be initialized
job_name: str = field(init=False, default="nemo-job")
Expand Down
11 changes: 8 additions & 3 deletions nemo_run/run/torchx_backend/schedulers/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@
)
from torchx.specs.api import is_terminal

from nemo_run.config import RUNDIR_NAME, USE_WITH_RAY_CLUSTER_KEY, from_dict, get_nemorun_home
from nemo_run.config import (
RUNDIR_NAME,
USE_WITH_RAY_CLUSTER_KEY,
from_dict,
get_nemorun_home,
)
from nemo_run.core.execution.base import Executor
from nemo_run.core.execution.slurm import SlurmBatchRequest, SlurmExecutor, SlurmJobDetails
from nemo_run.core.tunnel.client import LocalTunnel, PackagingJob, SSHTunnel, Tunnel
Expand Down Expand Up @@ -125,8 +130,8 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t
)

command = [app.roles[0].entrypoint] + app.roles[0].args
# Allow selecting Ray template via environment variable
ray_template_name = os.environ.get("NEMO_RUN_SLURM_RAY_TEMPLATE", "ray.sub.j2")
# Use Ray template from executor configuration
ray_template_name = executor.ray_template
req = SlurmRayRequest(
name=app.roles[0].name,
launch_cmd=["sbatch", "--requeue", "--parsable"],
Expand Down
24 changes: 15 additions & 9 deletions test/run/torchx_backend/schedulers/test_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,8 @@ def test_schedule_with_dependencies(slurm_scheduler, slurm_executor):
mock_tunnel.run.assert_called_once()


def test_ray_template_env_var(slurm_scheduler, slurm_executor):
"""Test that NEMO_RUN_SLURM_RAY_TEMPLATE environment variable selects the correct template."""
def test_ray_template_executor(slurm_scheduler, slurm_executor, temp_dir):
"""Test that executor.ray_template selects the correct template."""
from nemo_run.config import USE_WITH_RAY_CLUSTER_KEY
from nemo_run.run.ray.slurm import SlurmRayRequest

Expand All @@ -387,20 +387,26 @@ def test_ray_template_env_var(slurm_scheduler, slurm_executor):
):
slurm_scheduler.tunnel = mock.MagicMock()

# Test default template name
# Test default template name (ray.sub.j2)
assert slurm_executor.ray_template == "ray.sub.j2"
with mock.patch("nemo_run.core.execution.utils.fill_template") as mock_fill:
mock_fill.return_value = "#!/bin/bash\n# Mock script"
dryrun_info = slurm_scheduler._submit_dryrun(app_def, slurm_executor)
assert isinstance(dryrun_info.request, SlurmRayRequest)
assert dryrun_info.request.template_name == "ray.sub.j2"

# Test custom template name via environment variable
with (
mock.patch.dict(os.environ, {"NEMO_RUN_SLURM_RAY_TEMPLATE": "ray_enroot.sub.j2"}),
mock.patch("nemo_run.core.execution.utils.fill_template") as mock_fill,
):
# Test custom template name via executor
custom_executor = SlurmExecutor(
account="test_account",
job_dir=temp_dir,
nodes=1,
ntasks_per_node=1,
tunnel=LocalTunnel(job_dir=temp_dir),
ray_template="ray_enroot.sub.j2",
)
with mock.patch("nemo_run.core.execution.utils.fill_template") as mock_fill:
mock_fill.return_value = "#!/bin/bash\n# Mock script"
dryrun_info = slurm_scheduler._submit_dryrun(app_def, slurm_executor)
dryrun_info = slurm_scheduler._submit_dryrun(app_def, custom_executor)
assert isinstance(dryrun_info.request, SlurmRayRequest)
assert dryrun_info.request.template_name == "ray_enroot.sub.j2"

Expand Down
Loading