Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 110 additions & 30 deletions src/cloudai/workloads/nixl_ep/slurm_command_gen_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

from .nixl_ep import GENERATED_PLAN_FILE_NAME, NixlEPCmdArgs, NixlEPTestDefinition

LAUNCHER_SCRIPT_FILE_NAME = "nixl-ep-launch.sh"


@dataclass(frozen=True)
class NixlEPLaunch:
Expand Down Expand Up @@ -64,31 +66,25 @@ def num_processes_per_node(self) -> int:
raise ValueError("NIXL EP Slurm command generation requires num_processes_per_node to be an integer.")
return num_processes_per_node

def _append_sbatch_directives(self, batch_script_content: list[str]) -> None:
super()._append_sbatch_directives(batch_script_content)
batch_script_content.extend(
[
"",
"nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )",
"nodes_array=($nodes)",
"master_node=${nodes_array[0]}",
"master_ip=$(srun --nodes=1 --ntasks=1 -w \"$master_node\" hostname --ip-address | awk '{print $1}')",
"",
"echo Nodes: $SLURM_JOB_NODELIST",
"echo Num Nodes: ${#nodes[@]}",
"echo Master Node: $master_node",
"echo Master IP: $master_ip",
"",
]
)

@property
def env_vars_path(self) -> Path:
return self.test_run.output_path / "env_vars.sh"

@property
def launcher_script_path(self) -> Path:
return self.test_run.output_path / LAUNCHER_SCRIPT_FILE_NAME

def node_log_path(self, node_idx: int) -> Path:
return self.test_run.output_path / f"nixl-ep-node-{node_idx}.log"

@property
def stderr_path(self) -> Path:
return self.test_run.output_path / "stderr.txt"

@property
def master_ip_path(self) -> Path:
return self.test_run.output_path / "nixl-ep-master-ip.txt"

def resolve_plan_path(self) -> str:
return str((self.test_run.output_path / GENERATED_PLAN_FILE_NAME).absolute())

Expand Down Expand Up @@ -231,7 +227,7 @@ def generate_wait_for_master_services_function(self) -> str:
}}"""

def _launch_srun_prefix(self, node_idx: int) -> str:
target_arg = "--nodelist=$SLURM_JOB_MASTER_NODE" if node_idx == 0 else f"--relative={node_idx}"
target_arg = f'--nodelist="${{nodes_array[{node_idx}]}}"'
parts = [
*self.gen_srun_prefix(with_num_nodes=False),
"--overlap",
Expand All @@ -248,7 +244,10 @@ def _render_launch(self, launch: NixlEPLaunch) -> str:
log_file = self.node_log_path(launch.node_idx).absolute()
open_mode_arg = " --open-mode=append" if launch.append_output else ""
script = f"source {shlex.quote(str(env_file))}; {command}".replace('"', '\\"')
return f'{self._launch_srun_prefix(launch.node_idx)}{open_mode_arg} --output={log_file} bash -c "{script}"'
return (
f"{self._launch_srun_prefix(launch.node_idx)}{open_mode_arg} "
f'--output={log_file} --error={log_file} bash -c "{script}"'
)

def generate_wait_for_phase_completion_function(self) -> str:
timeout = self.phase_transition_timeout_seconds
Expand Down Expand Up @@ -290,7 +289,7 @@ def _write_env_vars_file(self) -> None:
def _background_launches_lines(self, launches: tuple[NixlEPLaunch, ...]) -> list[str]:
lines: list[str] = []
for launch in launches:
lines.extend([self._render_launch(launch) + " &", "worker_pids+=($!)"])
lines.extend([self._render_launch(launch) + " &", "active_srun_count=$((active_srun_count + 1))"])
return lines

@staticmethod
Expand All @@ -308,8 +307,13 @@ def _wait_for_workers_lines(cls) -> list[str]:
return [
"",
"rc=0",
'for pid in "${worker_pids[@]}"; do',
' wait "$pid" || rc=$?',
'while [ "$active_srun_count" -gt 0 ]; do',
" wait -n",
" wait_rc=$?",
" active_srun_count=$((active_srun_count - 1))",
' if [ "$wait_rc" -ne 0 ] && [ "$rc" -eq 0 ]; then',
" rc=$wait_rc",
" fi",
"done",
"",
*cls._finish_with_rc_lines(),
Expand Down Expand Up @@ -355,12 +359,12 @@ def _initial_stage_lines(self, stage: NixlEPStage, has_followers: bool) -> list[
primary_launch = stage.launches[0]
master_service_lines = self._wait_for_master_services_lines() if has_followers else []
header_lines = [
"worker_pids=()",
"active_srun_count=0",
"",
'echo "Starting initial NIXL EP stage on the master node..."',
self._render_launch(primary_launch) + " &",
"primary_pid=$!",
"worker_pids+=($primary_pid)",
"active_srun_count=$((active_srun_count + 1))",
]
return header_lines + master_service_lines + self._initial_follower_launch_lines(stage)

Expand All @@ -375,20 +379,85 @@ def _followup_stage_lines(self, stage: NixlEPStage) -> list[str]:
]
return header_lines + self._background_launches_lines(stage.launches)

def _gen_srun_command(self) -> str:
self._write_env_vars_file()
self._write_plan_file()
def _launcher_prologue_lines(self) -> list[str]:
num_nodes, node_list = self.get_cached_nodes_spec()
if node_list:
node_setup_lines = [f"nodes_array=( {' '.join(shlex.quote(node) for node in node_list)} )"]
else:
node_setup_lines = [
"nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )",
f'nodes_array=("${{nodes[@]:0:{num_nodes}}}")',
]

master_ip_path = shlex.quote(str(self.master_ip_path.absolute()))
stderr_path = shlex.quote(str(self.stderr_path.absolute()))
return [
"#!/bin/bash",
"",
*node_setup_lines,
"master_node=${nodes_array[0]}",
'export SLURM_JOB_MASTER_NODE="${SLURM_JOB_MASTER_NODE:-$master_node}"',
(
f'srun --nodes=1 --ntasks=1 -N1 --nodelist="$master_node" '
f"--output={master_ip_path} --error={stderr_path} hostname --ip-address"
),
f"master_ip=$(awk '{{print $1}}' {master_ip_path})",
"",
'echo "Nodes: $SLURM_JOB_NODELIST"',
'echo "Num Nodes: ${#nodes_array[@]}"',
'echo "Master Node: $master_node"',
'echo "Master IP: $master_ip"',
"",
]

@staticmethod
def _cleanup_function_lines() -> list[str]:
return [
"cleanup_nixl_ep() {",
" local pids",
' pids="$(jobs -pr)"',
' if [ -z "$pids" ]; then',
" return 0",
" fi",
' echo "Cleaning up NIXL EP background launches..."',
" kill -TERM $pids >/dev/null 2>&1 || true",
" sleep 2",
' pids="$(jobs -pr)"',
' if [ -n "$pids" ]; then',
" kill -KILL $pids >/dev/null 2>&1 || true",
" fi",
" wait >/dev/null 2>&1 || true",
"}",
"",
"on_nixl_ep_signal() {",
' local rc="$1"',
" cleanup_nixl_ep",
' exit "$rc"',
"}",
"",
"trap cleanup_nixl_ep EXIT",
"trap 'on_nixl_ep_signal 130' INT",
"trap 'on_nixl_ep_signal 143' TERM",
"",
]

def _launcher_body(self) -> str:
stages = [stage for stage in self.plan_stages if stage.launches]
if not stages:
raise ValueError("NIXL EP plan does not launch any non-negative ranks.")

first_stage = stages[0]
if len(stages) == 1 and len(first_stage.launches) == 1:
return self._render_single_stage(first_stage)
lines = [
*self._launcher_prologue_lines(),
*self._cleanup_function_lines(),
self._render_single_stage(first_stage),
]
return "\n".join(lines)

has_followers = self._has_follower_launches(stages)
lines = self._plan_helper_function_lines(
lines = [*self._launcher_prologue_lines(), *self._cleanup_function_lines()]
lines += self._plan_helper_function_lines(
has_followers=has_followers,
has_multiple_stages=len(stages) > 1,
)
Expand All @@ -399,3 +468,14 @@ def _gen_srun_command(self) -> str:

lines += self._wait_for_workers_lines()
return "\n".join(lines)

def _write_launcher_script(self) -> None:
self.launcher_script_path.parent.mkdir(parents=True, exist_ok=True)
self.launcher_script_path.write_text(self._launcher_body() + "\n", encoding="utf-8")
self.launcher_script_path.chmod(0o755)

def _gen_srun_command(self) -> str:
self._write_env_vars_file()
self._write_plan_file()
self._write_launcher_script()
return f"bash {shlex.quote(str(self.launcher_script_path.absolute()))}"
124 changes: 124 additions & 0 deletions tests/ref_data/nixl-ep-launch.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
#!/bin/bash

nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
nodes_array=("${nodes[@]:0:3}")
master_node=${nodes_array[0]}
export SLURM_JOB_MASTER_NODE="${SLURM_JOB_MASTER_NODE:-$master_node}"
srun --nodes=1 --ntasks=1 -N1 --nodelist="$master_node" --output=__OUTPUT_DIR__/output/nixl-ep-master-ip.txt --error=__OUTPUT_DIR__/output/stderr.txt hostname --ip-address
master_ip=$(awk '{print $1}' __OUTPUT_DIR__/output/nixl-ep-master-ip.txt)

echo "Nodes: $SLURM_JOB_NODELIST"
echo "Num Nodes: ${#nodes_array[@]}"
echo "Master Node: $master_node"
echo "Master IP: $master_ip"

cleanup_nixl_ep() {
local pids
pids="$(jobs -pr)"
if [ -z "$pids" ]; then
return 0
fi
echo "Cleaning up NIXL EP background launches..."
kill -TERM $pids >/dev/null 2>&1 || true
sleep 2
pids="$(jobs -pr)"
if [ -n "$pids" ]; then
kill -KILL $pids >/dev/null 2>&1 || true
fi
wait >/dev/null 2>&1 || true
}

on_nixl_ep_signal() {
local rc="$1"
cleanup_nixl_ep
exit "$rc"
}

trap cleanup_nixl_ep EXIT
trap 'on_nixl_ep_signal 130' INT
trap 'on_nixl_ep_signal 143' TERM

wait_for_master_services() {
local timeout=90
local interval=1
local end_time=$(($(date +%s) + timeout))

while [ "$(date +%s)" -lt "$end_time" ]; do
if timeout 1 bash -c ": > /dev/tcp/$master_ip/9999" >/dev/null 2>&1; then
echo "NIXL EP master services are ready on $master_ip"
return 0
fi
sleep "$interval"
done

echo "Timed out waiting for NIXL EP master services on $master_ip"
return 1
}

wait_for_phase_completion() {
local phase="$1"
local log_file="$2"
local primary_pid="$3"
local timeout=150
local interval=1
local end_time=$(($(date +%s) + timeout))

while [ "$(date +%s)" -lt "$end_time" ]; do
if [ -f "$log_file" ] && grep -Fq -- "-> end phase $phase" "$log_file"; then
echo "Detected completion of phase $phase in $log_file"
return 0
fi
if [ -f "$log_file" ] && grep -Fq -- "no plan phases were found for rank" "$log_file"; then
echo "Detected an early NIXL EP failure while waiting for phase $phase"
return 1
fi
if ! kill -0 "$primary_pid" >/dev/null 2>&1; then
echo "Primary NIXL EP launch exited before phase $phase completed"
return 1
fi
sleep "$interval"
done

echo "Timed out waiting for phase $phase to complete"
return 1
}

active_srun_count=0

echo "Starting initial NIXL EP stage on the master node..."
srun --export=ALL --mpi=pmix --container-image=docker.io/nvidia/nixl-ep:latest --container-mounts=__OUTPUT_DIR__/output:/cloudai_run_results,__INSTALL_DIR__:/cloudai_install,__OUTPUT_DIR__/output --overlap --nodelist="${nodes_array[0]}" --ntasks-per-node=1 --ntasks=1 -N1 --output=__OUTPUT_DIR__/output/nixl-ep-node-0.log --error=__OUTPUT_DIR__/output/nixl-ep-node-0.log bash -c "source __OUTPUT_DIR__/output/env_vars.sh; python3 /workspace/nixl/examples/device/ep/tests/elastic/elastic.py --plan __OUTPUT_DIR__/output/nixl-ep-plan.json --num-processes 4 --disable-ll-nvlink --hidden-dim 8192 --kineto --num-experts-per-rank 4 --num-tokens 256 --num-topk 6" &
primary_pid=$!
active_srun_count=$((active_srun_count + 1))

echo "Waiting for NIXL EP master services..."
wait_for_master_services || exit 1

echo "Waiting for phase 0 before starting phase 1..."
wait_for_phase_completion "0" "__OUTPUT_DIR__/output/nixl-ep-node-0.log" "$primary_pid" || exit 1

echo "Starting launches for phase 1..."
srun --export=ALL --mpi=pmix --container-image=docker.io/nvidia/nixl-ep:latest --container-mounts=__OUTPUT_DIR__/output:/cloudai_run_results,__INSTALL_DIR__:/cloudai_install,__OUTPUT_DIR__/output --overlap --nodelist="${nodes_array[1]}" --ntasks-per-node=1 --ntasks=1 -N1 --open-mode=append --output=__OUTPUT_DIR__/output/nixl-ep-node-1.log --error=__OUTPUT_DIR__/output/nixl-ep-node-1.log bash -c "source __OUTPUT_DIR__/output/env_vars.sh; python3 /workspace/nixl/examples/device/ep/tests/elastic/elastic.py --plan __OUTPUT_DIR__/output/nixl-ep-plan.json --num-processes 4 --tcp-server $master_ip --disable-ll-nvlink --hidden-dim 8192 --kineto --num-experts-per-rank 4 --num-tokens 256 --num-topk 6" &
active_srun_count=$((active_srun_count + 1))

echo "Waiting for phase 2 before starting phase 3..."
wait_for_phase_completion "2" "__OUTPUT_DIR__/output/nixl-ep-node-0.log" "$primary_pid" || exit 1

echo "Starting launches for phase 3..."
srun --export=ALL --mpi=pmix --container-image=docker.io/nvidia/nixl-ep:latest --container-mounts=__OUTPUT_DIR__/output:/cloudai_run_results,__INSTALL_DIR__:/cloudai_install,__OUTPUT_DIR__/output --overlap --nodelist="${nodes_array[2]}" --ntasks-per-node=1 --ntasks=1 -N1 --open-mode=append --output=__OUTPUT_DIR__/output/nixl-ep-node-2.log --error=__OUTPUT_DIR__/output/nixl-ep-node-2.log bash -c "source __OUTPUT_DIR__/output/env_vars.sh; python3 /workspace/nixl/examples/device/ep/tests/elastic/elastic.py --plan __OUTPUT_DIR__/output/nixl-ep-plan.json --num-processes 2 --tcp-server $master_ip --disable-ll-nvlink --hidden-dim 8192 --kineto --num-experts-per-rank 4 --num-tokens 256 --num-topk 6" &
active_srun_count=$((active_srun_count + 1))

rc=0
while [ "$active_srun_count" -gt 0 ]; do
wait -n
wait_rc=$?
active_srun_count=$((active_srun_count - 1))
if [ "$wait_rc" -ne 0 ] && [ "$rc" -eq 0 ]; then
rc=$wait_rc
fi
done

if [ "$rc" -eq 0 ]; then
echo "All NIXL EP launches completed successfully"
fi

exit $rc
Loading
Loading