[Distributed] hardware=gpu does not correctly configure process-per-node mode with Slurm #3433

@olupton

Description

Bug report

The logic for distributed initialisation with hardware=gpu:

def initialize_jax_for_gpu(raw_keys):
  """Jax distributed initialize for GPUs."""
  if os.environ.get("JAX_COORDINATOR_IP") is not None:
    coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP"))
    coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT"))
    devices = os.getenv("CUDA_VISIBLE_DEVICES")
    if devices is not None:
      try:
        devices = [int(x) for x in devices.split(",")]
      except (ValueError, TypeError) as e:
        max_logging.log(f"Error parsing CUDA_VISIBLE_DEVICES: {e}")
        devices = None
    jax.distributed.initialize(
        coordinator_address=f"{coordinator_ip}:{coordinator_port}",
        num_processes=int(os.getenv("NNODES")),
        process_id=int(os.getenv("NODE_RANK")),
        initialization_timeout=raw_keys["jax_distributed_initialization_timeout"],
        local_device_ids=devices,
    )
    max_logging.log(f"JAX global devices: {jax.devices()}")

is targeted at running in process-per-node mode (i.e. one process driving all GPUs in the machine). However, if CUDA_VISIBLE_DEVICES is not set explicitly, then local_device_ids is None and JAX falls through to auto-detection, which assumes process-per-GPU mode on Slurm. The result is that only the 0th GPU on each node is used.

Given that this MaxText code path is quite explicitly targeted at process-per-node mode, it would be more user-friendly if it did not require the user to set CUDA_VISIBLE_DEVICES explicitly.
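One possible fallback (a sketch, not a tested patch): when CUDA_VISIBLE_DEVICES is unset, derive the full per-node device list from Slurm's SLURM_GPUS_ON_NODE variable, or as a last resort by counting GPUs with nvidia-smi. The function name infer_local_device_ids and the nvidia-smi fallback are assumptions for illustration; the env-var parsing mirrors the existing code above.

```python
import os
import subprocess


def infer_local_device_ids():
  """Return explicit local device IDs for process-per-node mode.

  Sketch of a fallback when CUDA_VISIBLE_DEVICES is not set: assume one
  process drives all GPUs on the node and enumerate them 0..N-1.
  """
  devices = os.getenv("CUDA_VISIBLE_DEVICES")
  if devices is not None:
    # Same parsing as the existing initialize_jax_for_gpu code.
    return [int(x) for x in devices.split(",")]
  # Under Slurm, SLURM_GPUS_ON_NODE reports the GPUs allocated on this node.
  gpus_on_node = os.getenv("SLURM_GPUS_ON_NODE")
  if gpus_on_node is not None:
    return list(range(int(gpus_on_node)))
  # Last resort (hypothetical): count GPUs via nvidia-smi.
  out = subprocess.run(
      ["nvidia-smi", "--list-gpus"], capture_output=True, text=True, check=True
  )
  return list(range(len(out.stdout.strip().splitlines())))
```

The result could then be passed as local_device_ids to jax.distributed.initialize so that process-per-node mode no longer depends on the user exporting CUDA_VISIBLE_DEVICES.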

There is also hardware=gpu_multiprocess, which defers to JAX's default distributed initialisation and, on a Slurm cluster, correctly yields a process-per-GPU configuration.

Logs/Output

srun --container-image=... --container-remap-root sh -c 'CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NNODES=2 NODE_RANK=${SLURM_PROCID} JAX_COORDINATOR_PORT=2222 JAX_COORDINATOR_IP=... python3 -m maxtext.trainers.pre_train.train ... hardware=gpu ici_fsdp_parallelism=8 ...'

will work, while

srun --container-image=... --container-remap-root sh -c 'NNODES=2 NODE_RANK=${SLURM_PROCID} JAX_COORDINATOR_PORT=2222 JAX_COORDINATOR_IP=... python3 -m maxtext.trainers.pre_train.train ... hardware=gpu ici_fsdp_parallelism=8 ...'

will yield

AssertionError: Number of devices per slice 1 does not match the product of the ICI parallelism 8

Environment Information

No response

Additional Context

No response

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)