Fix NaN weights on non-rank-0 FSDP processes #45050
ArthurZucker merged 7 commits into huggingface:main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
If this PR is approved, I think
ArthurZucker left a comment
Are we sure it's the shard-and-distribute that's failing, and not the rank-0 tensor?
```python
for key in missing_keys - self.all_tied_weights_keys.keys():
    param = self.get_parameter_or_buffer(key)
    param_device = get_device(device_map, key, valid_torch_device=True)
    value = torch.empty_like(param, device=param_device)
```
I am pretty sure the fix is here instead, no? If a key is missing on rank 0, we broadcast NaNs.
Thanks a lot for your critical review, @ArthurZucker 🤗
I think for the FSDP `cpu_ram_efficient_loading` bug reproduced in huggingface/trl#5386, non-rank-0 returns early in `_move_missing_keys_from_meta_to_device()`, so it never reaches the line 4527 you pointed out above. The problematic tensors are the non-rank-0 placeholders created in the FSDP early-return branch, which are then marked initialized and skipped by `_initialize_missing_keys()`. I agree line 4527 may deserve a separate look, but it does not seem to be the path hit by this repro.
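For context on why those uninitialized placeholders can be NaN: bfloat16 uses 1 sign bit, 8 exponent bits, and 7 mantissa bits, and any pattern with an all-ones exponent and a nonzero mantissa is NaN. This can be checked without torch with a small pure-Python sketch (the helper name is mine, for illustration only):

```python
# Check whether a 16-bit pattern is NaN under the bfloat16 layout:
# 1 sign bit, 8 exponent bits, 7 mantissa bits (illustrative helper).
def is_bf16_nan(bits: int) -> bool:
    exponent = (bits >> 7) & 0xFF  # bits 14..7
    mantissa = bits & 0x7F         # bits 6..0
    return exponent == 0xFF and mantissa != 0

print(is_bf16_nan(0x7FC0))  # True: the NaN pattern mentioned in the PR description
print(is_bf16_nan(0x3F80))  # False: 0x3F80 is 1.0 in bfloat16
```

So any 2-byte region of reused memory that happens to hold such a pattern becomes a NaN weight the moment it is reinterpreted as bfloat16.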
But we agree that FSDP is supposed to broadcast into them, no?
Of course the non-zero ranks don't reach the line I pointed at, but I am saying that if we fix that line, what's broadcast should be good to go?
The reason I don't want this is that it's costly!
I tested just fixing this line:
```diff
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4526,7 +4526,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         for key in missing_keys - self.all_tied_weights_keys.keys():
             param = self.get_parameter_or_buffer(key)
             param_device = get_device(device_map, key, valid_torch_device=True)
-            value = torch.empty_like(param, device=param_device)
+            value = torch.zeros_like(param, device=param_device)
             # For TP, we may need to shard the param
             if device_mesh is not None:
                 shard_and_distribute_module(
```

and the error is not fixed:
/pytorch/aten/src/ATen/native/cuda/TensorCompare.cu:112: _assert_async_cuda_kernel: block: [0,0,0], thread: [0,0,0] Assertion `probability tensor contains either `inf`, `nan` or element < 0` failed.
[rank0]: Traceback (most recent call last):
[rank0]: File "/fsx/albert/dev/trl/trl/scripts/rloo.py", line 165, in <module>
[rank0]: main(script_args, training_args, model_args, dataset_args)
[rank0]: File "/fsx/albert/dev/trl/trl/scripts/rloo.py", line 137, in main
[rank0]: trainer.train()
[rank0]: File "/fsx/albert/dev/transformers/src/transformers/trainer.py", line 1424, in train
[rank0]: return inner_training_loop(
[rank0]: File "/fsx/albert/dev/transformers/src/transformers/trainer.py", line 1506, in _inner_training_loop
[rank0]: self._run_epoch(
[rank0]: File "/fsx/albert/dev/transformers/src/transformers/trainer.py", line 1734, in _run_epoch
[rank0]: tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]: File "/fsx/albert/dev/trl/trl/trainer/rloo_trainer.py", line 748, in training_step
[rank0]: output = super().training_step(model, inputs, num_items_in_batch)
[rank0]: File "/fsx/albert/dev/transformers/src/transformers/trainer.py", line 1900, in training_step
[rank0]: inputs = self._prepare_inputs(inputs)
[rank0]: File "/fsx/albert/dev/trl/trl/extras/profiling.py", line 202, in wrapper
[rank0]: return func(self, *args, **kwargs)
[rank0]: File "/fsx/albert/dev/trl/trl/trainer/rloo_trainer.py", line 777, in _prepare_inputs
[rank0]: generation_batch = self._generate_and_score_completions(generation_batch)
[rank0]: File "/fsx/albert/dev/trl/trl/trainer/rloo_trainer.py", line 1118, in _generate_and_score_completions
[rank0]: prompt_ids_list, completion_ids_list, completions = self._generate(prompts)
[rank0]: File "/fsx/albert/dev/trl/trl/trainer/rloo_trainer.py", line 1043, in _generate
[rank0]: completion_ids = self._generate_single_turn(prompt_ids, images, multimodal_fields)
[rank0]: File "/fsx/albert/dev/trl/trl/trainer/rloo_trainer.py", line 1016, in _generate_single_turn
[rank0]: prompt_completion_ids = unwrapped_model.generate(
[rank0]: File "/fsx/albert/dev/trl/.venv/lib/python3.10/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py", line 689, in wrapped_method
[rank0]: out = orig_method(*args, **kwargs)
[rank0]: File "/fsx/albert/dev/trl/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: File "/fsx/albert/dev/transformers/src/transformers/generation/utils.py", line 2543, in generate
[rank0]: result = decoding_method(
[rank0]: File "/fsx/albert/dev/transformers/src/transformers/generation/utils.py", line 2791, in _sample
[rank0]: next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
[rank0]: torch.AcceleratorError: CUDA error: device-side assert triggered
[rank0]: Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
[rank0]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Rank-0 does NOT go through line 4527 for regular params:

- `missing_keys` is empty (all params are in the checkpoint): `missing_keys = set()`, so the loop at 4524-4527 iterates over nothing.
- Rank-0 loads correct values from disk and broadcasts correct values.
- Fixing `empty_like` → `zeros_like` at 4527 changes nothing at all for the failing test.

The broadcast is NOT the mechanism that made the test pass before PR #44473:

- Before PR #44473 (fix FSDP loading with meta devices): non-rank-0 has `empty_like` garbage → `_init_weights` runs → non-rank-0 gets random (not correct) weights → test passes.
- After PR #44473: non-rank-0 has `empty_like` garbage → `_init_weights` skipped (params marked `_is_hf_initialized=True`) → non-rank-0 keeps NaN → test fails.

If the broadcast were correctly overwriting non-rank-0's values in both cases, the test outcome would be identical before and after PR #44473 (correct weights either way). The fact that before PR #44473 the test passed with random weights (not correct weights) proves the broadcast is NOT fixing non-rank-0's `state_dict` params.
The fix must be at 4512-4519 to ensure non-rank-0 has valid placeholder values before `_init_weights` is skipped.
Regarding the cost concern: this is a one-time cost at model initialization (not per training step), on non-rank-0 processes only, in a code path that only activates when `cpu_ram_efficient_loading` is in use. It is dwarfed by the cost of loading the model itself.
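To make the skip mechanism concrete, here is a minimal self-contained sketch (plain-Python stand-ins with names mirroring transformers' convention, not the real classes) of how a parameter pre-marked `_is_hf_initialized = True` keeps its garbage value when initialization is skipped:

```python
import math

class FakeParam:
    """Stand-in for a model parameter (illustrative, not torch.nn.Parameter)."""
    def __init__(self, value):
        self.value = value
        self._is_hf_initialized = False

def initialize_weights(params):
    # Mimics the described behavior: parameters flagged as already
    # initialized are skipped entirely.
    for p in params:
        if getattr(p, "_is_hf_initialized", False):
            continue
        p.value = 0.02  # stand-in for _init_weights' random init

# Non-rank-0 placeholder holding garbage (NaN), pre-marked as initialized:
p = FakeParam(float("nan"))
p._is_hf_initialized = True
initialize_weights([p])
print(math.isnan(p.value))  # True: the NaN survives because init was skipped
```

With `_is_hf_initialized = False`, the same call would overwrite the NaN with a valid value, which is the pre-#44473 behavior the argument above describes.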
Hi @ArthurZucker, do you agree with my arguments above? Is it OK to merge this PR? Thanks.
Thanks for your review, @ArthurZucker. I don't know the
Init with zeros instead of empty in `_move_missing_keys_from_meta_to_device`

Fix NaN weights on non-rank-0 FSDP processes by using `zeros_like` instead of `empty_like` in `_move_missing_keys_from_meta_to_device`.

Follow-up to:

See the related downstream issue in `trl`: probability tensor contains either `inf`, `nan` or element < 0 (trl#5386). I have checked this fix downstream and it fixes the `trl` issue:

What does this PR do?

When using FSDP with `cpu_ram_efficient_loading`, non-rank-0 processes do not load model weights from disk. In `_move_missing_keys_from_meta_to_device`, placeholder tensors are created for those ranks using `torch.empty_like`. Uninitialized memory (especially in bfloat16) can contain NaN values (e.g. `0x7FC0` is NaN in bfloat16).

Before #44473, `_initialize_missing_keys` would call `initialize_weights()` unconditionally, so `_init_weights` would replace those garbage tensors with random-but-valid values before the FSDP rank-0 broadcast overwrote them with correct weights.

After #44473, `_initialize_missing_keys` pre-marks all `state_dict` parameters as `_is_hf_initialized = True` on non-rank-0 processes, causing `initialize_weights()` to skip them entirely. This is intentional (the rank-0 broadcast provides correct values) but it removes the safe fallback. If the tensors contain NaN before the broadcast, anything that touches the model in that window (or any edge case where broadcast behavior differs) produces NaN outputs.

In practice this manifests as a CUDA device-side assert during generation:

because `torch.multinomial` is called on logits derived from NaN weights.

Solution

Replace `torch.empty_like` with `torch.zeros_like` in `_move_missing_keys_from_meta_to_device`. Zero is a valid floating-point value in all dtypes (fp32, fp16, bf16), so placeholder tensors are always safe before the broadcast overwrites them.

Code Agent Policy
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@winglian, who created the PR, and @Cyrilvallez, who reviewed it.