
Fix NaN weights on non-rank-0 FSDP processes #45050

Merged
ArthurZucker merged 7 commits into huggingface:main from albertvillanova:fix-trl-5386
Apr 13, 2026

Conversation

@albertvillanova (Member) commented Mar 27, 2026

Fix NaN weights on non-rank-0 FSDP processes by using zeros_like instead of empty_like in _move_missing_keys_from_meta_to_device

Follow-up to:

See the related downstream issue in trl: huggingface/trl#5386

I have checked this fix downstream and it fixes the trl issue:

tests/distributed/test_distributed.py::TestDistributed::test_rloo[fsdp2] PASSED

What does this PR do?

When using FSDP with cpu_ram_efficient_loading, non-rank-0 processes do not load model weights from disk. In _move_missing_keys_from_meta_to_device, placeholder tensors are created for those ranks using torch.empty_like. Uninitialized memory (especially in bfloat16) can contain NaN values (e.g. 0x7FC0 is NaN in bfloat16).
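As a torch-free illustration of why uninitialized bytes can decode to NaN: bfloat16 is the upper half of an IEEE-754 float32, so a 16-bit pattern can be decoded with the standard library alone (the helper below is hypothetical, written just for this note):

```python
import math
import struct

def bf16_bits_to_float(bits: int) -> float:
    """Reinterpret a 16-bit bfloat16 pattern as a Python float.

    bfloat16 is the top 16 bits of an IEEE-754 float32, so shifting
    left by 16 and unpacking as float32 recovers the value exactly.
    """
    return struct.unpack(">f", struct.pack(">I", bits << 16))[0]

# 0x7FC0: exponent bits all ones, non-zero mantissa -> a quiet NaN in bfloat16
print(math.isnan(bf16_bits_to_float(0x7FC0)))  # True

# For contrast, 0x3F80 is the bfloat16 encoding of 1.0
print(bf16_bits_to_float(0x3F80))  # 1.0
```

Any memory region that happens to contain such a pattern will read back as NaN when viewed as a bfloat16 tensor.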

Before #44473, _initialize_missing_keys would call initialize_weights() unconditionally, so _init_weights would replace those garbage tensors with random-but-valid values before the FSDP rank-0 broadcast overwrote them with correct weights.

After #44473, _initialize_missing_keys pre-marks all state_dict parameters as _is_hf_initialized = True on non-rank-0 processes, causing initialize_weights() to skip them entirely. This is intentional (the rank-0 broadcast provides correct values) but it removes the safe fallback. If the tensors contain NaN before the broadcast, anything that touches the model in that window (or any edge case where broadcast behavior differs) produces NaN outputs.

In practice this manifests as a CUDA device-side assert during generation:

RuntimeError: probability tensor contains either inf, nan or element < 0

because torch.multinomial is called on logits derived from NaN weights.
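A toy sketch in plain Python (not the actual generation code, which operates on tensors) of how a single NaN upstream contaminates the whole sampling distribution:

```python
import math

def softmax(logits):
    # A plain softmax; any NaN logit poisons the normalizing sum,
    # so every resulting "probability" becomes NaN
    exps = [math.exp(x) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# One NaN weight upstream is enough to contaminate all the logits
probs = softmax([0.5, float("nan"), 1.2])
print(all(math.isnan(p) for p in probs))  # True
```

Sampling from such a distribution is exactly what the CUDA assert in torch.multinomial rejects.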

Solution

Replace torch.empty_like with torch.zeros_like in _move_missing_keys_from_meta_to_device. Zero is a valid floating-point value in all dtypes (fp32, fp16, bf16), so placeholder tensors are always safe before the broadcast overwrites them.
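The difference can be sketched with NumPy, whose empty_like/zeros_like semantics mirror the torch calls (NumPy is used here only as a stand-in, so this runs without a GPU or torch install):

```python
import numpy as np

# Stand-in for a parameter that non-rank-0 processes must materialize
# as a placeholder before the rank-0 broadcast overwrites it
param = np.ones((2, 3), dtype=np.float32)

# np.empty_like mirrors torch.empty_like: the buffer is uninitialized,
# so its contents are whatever bytes happened to be in memory
unsafe = np.empty_like(param)

# np.zeros_like mirrors torch.zeros_like: every element is a valid zero
safe = np.zeros_like(param)
print(np.isfinite(safe).all())  # True, regardless of dtype
```

Zero-filling guarantees the placeholder is finite; an uninitialized buffer offers no such guarantee.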

Code Agent Policy

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@winglian, who created the original PR, and @Cyrilvallez, who reviewed it

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@albertvillanova (Member Author)

If this PR is approved, I think transformers may require a patch release 5.4.1.

@ArthurZucker (Collaborator) left a comment

Are we sure the shard-and-distribute path is failing and not the rank-0 tensor?

for key in missing_keys - self.all_tied_weights_keys.keys():
param = self.get_parameter_or_buffer(key)
param_device = get_device(device_map, key, valid_torch_device=True)
value = torch.empty_like(param, device=param_device)
Collaborator

I am pretty sure the fix is here instead, no? If the key is missing on rank 0, we broadcast NaNs.

Member Author

Thanks a lot for your critical review, @ArthurZucker 🤗

I think for the FSDP cpu_ram_efficient_loading bug reproduced in huggingface/trl#5386, non-rank-0 returns early in _move_missing_keys_from_meta_to_device(), so it never reaches the line 4527 you pointed out above. The problematic tensors are the non-rank-0 placeholders created in the FSDP early-return branch, which are then marked initialized and skipped by _initialize_missing_keys(). I agree line 4527 may deserve a separate look, but it does not seem to be the path hit by this repro.

Collaborator

But we agree that FSDP is supposed to broadcast into them, no?

Collaborator

Of course the non-zero ranks are not going through the line I pointed at, but I am saying: if we fix what I pointed at, what's broadcast should be good to go?

Collaborator

The reason I don't want this is because it's costly!

Member Author

I tested just fixing this line:

--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4526,7 +4526,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
         for key in missing_keys - self.all_tied_weights_keys.keys():
             param = self.get_parameter_or_buffer(key)
             param_device = get_device(device_map, key, valid_torch_device=True)
-            value = torch.empty_like(param, device=param_device)
+            value = torch.zeros_like(param, device=param_device)
             # For TP, we may need to shard the param
             if device_mesh is not None:
                 shard_and_distribute_module(

and the error is not fixed:

/pytorch/aten/src/ATen/native/cuda/TensorCompare.cu:112: _assert_async_cuda_kernel: block: [0,0,0], thread: [0,0,0] Assertion `probability tensor contains either `inf`, `nan` or element < 0` failed.
[rank0]: Traceback (most recent call last):
[rank0]:   File "/fsx/albert/dev/trl/trl/scripts/rloo.py", line 165, in <module>
[rank0]:     main(script_args, training_args, model_args, dataset_args)
[rank0]:   File "/fsx/albert/dev/trl/trl/scripts/rloo.py", line 137, in main
[rank0]:     trainer.train()
[rank0]:   File "/fsx/albert/dev/transformers/src/transformers/trainer.py", line 1424, in train
[rank0]:     return inner_training_loop(
[rank0]:   File "/fsx/albert/dev/transformers/src/transformers/trainer.py", line 1506, in _inner_training_loop
[rank0]:     self._run_epoch(
[rank0]:   File "/fsx/albert/dev/transformers/src/transformers/trainer.py", line 1734, in _run_epoch
[rank0]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]:   File "/fsx/albert/dev/trl/trl/trainer/rloo_trainer.py", line 748, in training_step
[rank0]:     output = super().training_step(model, inputs, num_items_in_batch)
[rank0]:   File "/fsx/albert/dev/transformers/src/transformers/trainer.py", line 1900, in training_step
[rank0]:     inputs = self._prepare_inputs(inputs)
[rank0]:   File "/fsx/albert/dev/trl/trl/extras/profiling.py", line 202, in wrapper
[rank0]:     return func(self, *args, **kwargs)
[rank0]:   File "/fsx/albert/dev/trl/trl/trainer/rloo_trainer.py", line 777, in _prepare_inputs
[rank0]:     generation_batch = self._generate_and_score_completions(generation_batch)
[rank0]:   File "/fsx/albert/dev/trl/trl/trainer/rloo_trainer.py", line 1118, in _generate_and_score_completions
[rank0]:     prompt_ids_list, completion_ids_list, completions = self._generate(prompts)
[rank0]:   File "/fsx/albert/dev/trl/trl/trainer/rloo_trainer.py", line 1043, in _generate
[rank0]:     completion_ids = self._generate_single_turn(prompt_ids, images, multimodal_fields)
[rank0]:   File "/fsx/albert/dev/trl/trl/trainer/rloo_trainer.py", line 1016, in _generate_single_turn
[rank0]:     prompt_completion_ids = unwrapped_model.generate(
[rank0]:   File "/fsx/albert/dev/trl/.venv/lib/python3.10/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py", line 689, in wrapped_method
[rank0]:     out = orig_method(*args, **kwargs)
[rank0]:   File "/fsx/albert/dev/trl/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/fsx/albert/dev/transformers/src/transformers/generation/utils.py", line 2543, in generate
[rank0]:     result = decoding_method(
[rank0]:   File "/fsx/albert/dev/transformers/src/transformers/generation/utils.py", line 2791, in _sample
[rank0]:     next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
[rank0]: torch.AcceleratorError: CUDA error: device-side assert triggered
[rank0]: Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
[rank0]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

@albertvillanova (Member Author) commented Apr 3, 2026

Rank-0 does NOT go through line 4527 for regular params:

  • missing_keys is empty (all params are in the checkpoint): missing_keys=set()
  • The loop at 4524-4527 iterates over nothing.
  • Rank-0 loads correct values from disk and broadcasts correct values.
  • Fixing empty_like → zeros_like at 4527 changes nothing at all for the failing test.

The broadcast is NOT the mechanism that made the test pass before PR #44473:

If the broadcast were correctly overwriting non-rank-0's values in both cases, the test outcome would be identical before and after PR #44473 (correct weights either way). The fact that before PR #44473 the test passed with random weights (not correct weights) proves the broadcast is NOT fixing non-rank-0's state_dict params.

The fix must be at 4512-4519 to ensure non-rank-0 has valid placeholder values before _init_weights is skipped.


Regarding the cost concern: this is a one-time cost at model initialization (not per training step), on non-rank-0 processes only, in a code path that only activates when cpu_ram_efficient_loading is in use. It is dwarfed by the cost of loading the model itself.

Member Author

Hi @ArthurZucker, do you agree with my arguments above? Is it OK to merge this PR? Thanks.

@ArthurZucker (Collaborator) left a comment

Okay! ty for investigating

@ArthurZucker ArthurZucker added this pull request to the merge queue Apr 13, 2026
Merged via the queue into huggingface:main with commit ff49f7c Apr 13, 2026
28 checks passed
@albertvillanova (Member Author)

Thanks for your review, @ArthurZucker.

I don't know the transformers policy, but are you planning to include this fix in:

  • patch release v5.4.1? The bug was introduced in v5.4.0
  • patch release v5.5.4?
  • next minor release v5.6.0?

sirzechs66 pushed a commit to sirzechs66/transformers that referenced this pull request Apr 18, 2026
Init with zeros instead of empty in _move_missing_keys_from_meta_to_device
