Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/cloudai/systems/slurm/single_sbatch_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,10 @@ def unroll_dse(self, tr: TestRun) -> Generator[TestRun, None, None]:
yield next_tr

def get_global_env_vars(self) -> str:
    """
    Build the ``export`` lines shared by the whole sbatch script.

    The command generation strategy of the first test run supplies both the
    sbatch-level variables (``get_sbatch_env_vars()``) and the test-specific
    ones (``final_env_vars``). On a key clash the test-specific value wins.

    Returns:
        str: newline-separated ``export KEY=VALUE`` statements.
    """
    tr = self.test_scenario.test_runs[0]
    cmd_gen = cast(SlurmCommandGenStrategy, self.get_cmd_gen_strategy(self.system, tr))
    # The right-hand operand of `|` takes precedence, so final_env_vars override the sbatch-level values.
    env_vars = cmd_gen.get_sbatch_env_vars() | cmd_gen.final_env_vars
    return "\n".join(f"export {key}={value}" for key, value in env_vars.items())

def gen_sbatch_content(self) -> str:
content: list[str] = ["#!/bin/bash", *self.get_sbatch_directives(), ""]
Expand Down
122 changes: 72 additions & 50 deletions src/cloudai/systems/slurm/slurm_command_gen_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,22 @@ def nodelist_in_use(self) -> bool:
_, nodes = self.get_cached_nodes_spec()
return len(nodes) > 0

def get_sbatch_env_vars(self) -> dict[str, str]:
    """
    Collect the environment variables exported at the sbatch-script level.

    ``SLURM_JOB_MASTER_NODE`` is always present. ``SLURM_HOSTFILE`` is added only
    when ``get_nodes_related_directives()`` reports a hostfile.

    Returns:
        dict[str, str]: variable names mapped to the shell text of their values.
    """
    sbatch_env = {"SLURM_JOB_MASTER_NODE": "$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1)"}

    _directives, hostfile = self.get_nodes_related_directives()
    if hostfile is None:
        return sbatch_env

    sbatch_env["SLURM_HOSTFILE"] = str(hostfile)
    return sbatch_env

def write_env_vars(self):
    """Write ``final_env_vars`` as ``export`` statements to ``env_vars.sh`` under the test run output path."""
    env_file = self.test_run.output_path / "env_vars.sh"
    with env_file.open("w") as script:
        script.writelines(f'export {name}="{val}"\n' for name, val in self.final_env_vars.items())

@abstractmethod
def _container_mounts(self) -> list[str]:
"""Return CommandGenStrategy specific container mounts for the test run."""
Expand Down Expand Up @@ -231,9 +247,7 @@ def _gen_srun_command(self) -> str:
nsys_command_parts = self.gen_nsys_command()
test_command_parts = self.generate_test_command()

with (self.test_run.output_path / "env_vars.sh").open("w") as f:
for key, value in self.final_env_vars.items():
f.write(f'export {key}="{value}"\n')
self.write_env_vars()

full_test_cmd = (
f'bash -c "source {(self.test_run.output_path / "env_vars.sh").absolute()}; '
Expand Down Expand Up @@ -274,22 +288,10 @@ def gen_srun_prefix(self, use_pretest_extras: bool = False, with_num_nodes: bool
def generate_test_command(self) -> List[str]:
    """Return the parts of the test command; this implementation provides none."""
    no_command: List[str] = []
    return no_command

def _add_reservation(self, batch_script_content: List[str]):
    """
    Append a reservation directive when the extra srun arguments name one.

    Args:
        batch_script_content (List[str]): content of the batch script; extended in place.

    Returns:
        List[str]: the same list, with ``#SBATCH --reservation=...`` appended if a reservation was found.
    """
    flag = "--reservation "
    srun_args = self.system.extra_srun_args
    if not srun_args or flag not in srun_args:
        return batch_script_content

    # The reservation name is the text between the first flag occurrence and the next space.
    name = srun_args.partition(flag)[2].partition(" ")[0]
    batch_script_content.append(f"#SBATCH --reservation={name}")
    return batch_script_content
def _get_reservation(self) -> str | None:
if self.system.extra_srun_args and "--reservation " in self.system.extra_srun_args:
return self.system.extra_srun_args.split("--reservation ", 1)[1].split(" ", 1)[0]
return None

def _ranks_mapping_cmd(self) -> str:
return " ".join(
Expand Down Expand Up @@ -352,6 +354,11 @@ def _write_sbatch_script(self, srun_command: str) -> str:
]

self._append_sbatch_directives(batch_script_content)
batch_script_content.append("")
batch_script_content.extend([self._format_env_vars(self.get_sbatch_env_vars())])

if sbatch_prefix := self._gen_sbatch_prefix():
batch_script_content.extend(sbatch_prefix)

batch_script_content.extend([self._format_env_vars(self.final_env_vars)])

Expand All @@ -368,50 +375,65 @@ def _write_sbatch_script(self, srun_command: str) -> str:

return f"sbatch {batch_script_path}"

def _get_sbatch_directives(self) -> dict[str, str]:
    """
    Collect the SBATCH directives for the test run.

    Keys are option names. ``_append_sbatch_directives`` renders a key without a
    leading dash as ``--key=value`` (or ``--key`` when the value is empty) and
    emits a key that already starts with ``-`` verbatim, so every key must use
    the exact spelling sbatch expects.

    Returns:
        dict[str, str]: directives in the order they are written to the script.
    """
    directives = {}

    if reservation := self._get_reservation():
        directives["reservation"] = reservation

    output_dir = self.test_run.output_path.absolute()
    directives["output"] = output_dir / "stdout.txt"
    directives["error"] = output_dir / "stderr.txt"
    directives["partition"] = self.system.default_partition

    if self.system.account:
        directives["account"] = self.system.account

    if self.system.distribution:
        directives["distribution"] = self.system.distribution

    directives.update(self.get_nodes_related_directives()[0])

    if self.system.gpus_per_node and self.system.supports_gpu_directives:
        directives["gpus-per-node"] = self.system.gpus_per_node
        directives["gres"] = f"gpu:{self.system.gpus_per_node}"

    if self.system.ntasks_per_node:
        # Hyphenated on purpose: the key is emitted as the option name, and sbatch spells it --ntasks-per-node.
        directives["ntasks-per-node"] = self.system.ntasks_per_node

    if self.test_run.time_limit:
        directives["time"] = self.test_run.time_limit

    # Extra arguments are already complete options; store them as keys with an empty value.
    for arg in self.system.extra_sbatch_args:
        directives[arg] = ""

    return directives
def _append_sbatch_directives(self, batch_script_content: List[str]) -> None:
    """
    Render every collected directive as an ``#SBATCH`` line and add it to the script.

    Args:
        batch_script_content (List[str]): script lines collected so far; extended in place.
    """
    for option, setting in self._get_sbatch_directives().items():
        if option.startswith("-"):
            # The option is already spelled out; strip() removes the trailing space left by an empty setting.
            line = f"#SBATCH {option} {setting}".strip()
        elif setting:
            line = f"#SBATCH --{option}={setting}"
        else:
            line = f"#SBATCH --{option}"
        batch_script_content.append(line)

def _gen_sbatch_prefix(self) -> list[str]:
    """Return extra script lines for the sbatch script; this implementation provides none."""
    no_lines: list[str] = []
    return no_lines

if self.system.distribution:
content.append(f"#SBATCH --distribution={self.system.distribution}")
def get_nodes_related_directives(self) -> tuple[dict, Optional[Path]]:
directives = {}
num_nodes, node_list = self.get_cached_nodes_spec()

if node_list:
content.append(f"#SBATCH --nodelist={','.join(node_list)}")
directives["nodelist"] = ",".join(node_list)

hostfile = (self.test_run.output_path / "hostfile.txt").absolute()
with hostfile.open("w") as hf:
Expand All @@ -420,14 +442,14 @@ def _append_nodes_related_directives(self, content: List[str]) -> Optional[Path]
for _ in range(tasks):
hf.write(f"{node}\n")

return hostfile
return directives, hostfile

content.append(f"#SBATCH -N {num_nodes}")
directives["-N"] = num_nodes

if self.test_run.exclude_nodes:
content.append(f"#SBATCH --exclude={','.join(self.test_run.exclude_nodes)}")
directives["exclude"] = ",".join(self.test_run.exclude_nodes)

return None
return directives, None

def _format_env_vars(self, env_vars: Dict[str, Any]) -> str:
"""
Expand Down
1 change: 0 additions & 1 deletion src/cloudai/workloads/common/nixl.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,6 @@ def final_env_vars(self) -> dict[str, str | list[str]]:
env_vars = super().final_env_vars
env_vars["NIXL_ETCD_NAMESPACE"] = "/nixl/kvbench/$(uuidgen)"
env_vars["NIXL_ETCD_ENDPOINTS"] = '"$SLURM_JOB_MASTER_NODE:2379"'
env_vars["SLURM_JOB_MASTER_NODE"] = "$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1)"
return env_vars

@final_env_vars.setter
Expand Down
46 changes: 15 additions & 31 deletions src/cloudai/workloads/deepep/slurm_command_gen_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,37 +25,21 @@
class DeepEPSlurmCommandGenStrategy(SlurmCommandGenStrategy):
"""Command generation strategy for DeepEP benchmark on Slurm systems."""

def _append_head_node_detection(self, batch_script_content: List[str]) -> None:
    """
    Add the bash snippet that detects the head node IP for torchrun.

    Args:
        batch_script_content: script lines collected so far; extended in place.
    """
    resolve_head_node = [
        "nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )",
        "nodes_array=($nodes)",
        "head_node=${nodes_array[0]}",
        'head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)',
    ]
    report = [
        "echo Nodes: $SLURM_JOB_NODELIST",
        "echo Num Nodes: ${#nodes[@]}",
        "echo Head Node IP: $head_node_ip",
    ]
    # Blank lines separate the two groups from each other and from the surrounding script.
    batch_script_content.extend(["", *resolve_head_node, "", *report, ""])

def _append_sbatch_directives(self, batch_script_content: List[str]) -> None:
    """
    Emit the inherited SBATCH directives, then the head node detection snippet.

    Args:
        batch_script_content: script lines collected so far; extended in place.
    """
    script_lines = batch_script_content
    super()._append_sbatch_directives(script_lines)
    # Head node detection goes after the directives.
    self._append_head_node_detection(script_lines)
def _gen_sbatch_prefix(self) -> list[str]:
    """Return the inherited prefix lines followed by bash commands that detect the head node IP for torchrun."""
    inherited = super()._gen_sbatch_prefix()
    head_node_lines = [
        "",
        "nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )",
        "nodes_array=($nodes)",
        "head_node=${nodes_array[0]}",
        'head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)',
        "",
        "echo Nodes: $SLURM_JOB_NODELIST",
        "echo Num Nodes: ${#nodes[@]}",
        "echo Head Node IP: $head_node_ip",
        "",
    ]
    return [*inherited, *head_node_lines]

def _container_mounts(self) -> List[str]:
"""Return container mounts specific to DeepEP benchmark."""
Expand Down
Loading
Loading