Description
Bug report
The logic for distributed initialisation with `hardware=gpu`:
maxtext/src/maxtext/utils/max_utils.py
Lines 246 to 266 in 37ded59
```python
def initialize_jax_for_gpu(raw_keys):
  """Jax distributed initialize for GPUs."""
  if os.environ.get("JAX_COORDINATOR_IP") is not None:
    coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP"))
    coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT"))
    devices = os.getenv("CUDA_VISIBLE_DEVICES")
    if devices is not None:
      try:
        devices = [int(x) for x in devices.split(",")]
      except (ValueError, TypeError) as e:
        max_logging.log(f"Error parsing CUDA_VISIBLE_DEVICES: {e}")
        devices = None
    jax.distributed.initialize(
        coordinator_address=f"{coordinator_ip}:{coordinator_port}",
        num_processes=int(os.getenv("NNODES")),
        process_id=int(os.getenv("NODE_RANK")),
        initialization_timeout=raw_keys["jax_distributed_initialization_timeout"],
        local_device_ids=devices,
    )
    max_logging.log(f"JAX global devices: {jax.devices()}")
```
is targeted at running in process-per-node mode (i.e. one process driving all GPUs on the node). However, if CUDA_VISIBLE_DEVICES is not set explicitly, `local_device_ids` is passed as None and JAX falls through to its auto-detection, which on a Slurm cluster assumes process-per-GPU mode. The result is that only the 0th GPU on each node is used.
Given that this MaxText code path is quite explicitly targeted at process-per-node mode, it would be more user-friendly if it did not require the user to set CUDA_VISIBLE_DEVICES explicitly.
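One possible shape for a fix, as a minimal sketch: when CUDA_VISIBLE_DEVICES is unset, default to claiming every GPU on the node instead of deferring to auto-detection. The `GPUS_PER_NODE` variable and the `resolve_local_device_ids` helper below are hypothetical, purely for illustration; a real fix might instead count devices via NVML or use Slurm's own GPU-count variables.

```python
def resolve_local_device_ids(env):
    """Hypothetical helper: derive local_device_ids for process-per-node mode.

    If CUDA_VISIBLE_DEVICES is set, parse it as before. Otherwise, rather
    than returning None (which triggers JAX's process-per-GPU auto-detection
    on Slurm), assume this process owns all GPUs on the node, with the count
    taken from an illustrative GPUS_PER_NODE variable.
    """
    devices = env.get("CUDA_VISIBLE_DEVICES")
    if devices is not None:
        try:
            return [int(x) for x in devices.split(",")]
        except (ValueError, TypeError):
            return None
    gpus_per_node = env.get("GPUS_PER_NODE")
    if gpus_per_node is not None:
        # Process-per-node default: claim device ids 0..N-1 on this node.
        return list(range(int(gpus_per_node)))
    return None  # last resort: fall back to auto-detection
```

With this, the first `srun` invocation below would behave the same without the user having to spell out `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7` by hand.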
There is also hardware=gpu_multiprocess that defers to JAX's default distributed initialisation and, on a Slurm cluster, correctly yields a process-per-GPU configuration.
Logs/Output
```shell
srun --container-image=... --container-remap-root sh -c 'CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NNODES=2 NODE_RANK=${SLURM_PROCID} JAX_COORDINATOR_PORT=2222 JAX_COORDINATOR_IP=... python3 -m maxtext.trainers.pre_train.train ... hardware=gpu ici_fsdp_parallelism=8 ...'
```
will work, while
```shell
srun --container-image=... --container-remap-root sh -c 'NNODES=2 NODE_RANK=${SLURM_PROCID} JAX_COORDINATOR_PORT=2222 JAX_COORDINATOR_IP=... python3 -m maxtext.trainers.pre_train.train ... hardware=gpu ici_fsdp_parallelism=8 ...'
```
will yield
```
AssertionError: Number of devices per slice 1 does not match the product of the ICI parallelism 8
```
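The arithmetic behind the failure, as a toy sketch (the real check lives in MaxText's mesh-construction code; `check_ici` here is hypothetical and just mirrors the message): with only GPU 0 visible to each process, each slice sees 1 device, while `ici_fsdp_parallelism=8` requires the ICI parallelism product to equal the per-slice device count.

```python
def check_ici(devices_per_slice, ici_parallelism):
    """Illustrative version of the mesh-shape check that produces the error above."""
    product = 1
    for p in ici_parallelism:
        product *= p
    assert devices_per_slice == product, (
        f"Number of devices per slice {devices_per_slice} does not match "
        f"the product of the ICI parallelism {product}"
    )

check_ici(8, [8])    # all 8 GPUs visible per process: check passes
# check_ici(1, [8])  # only GPU 0 visible: raises the AssertionError above
```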
Environment Information
No response
Additional Context
No response