
[CUDA] Fsdp (easy) #3130

Merged
nastya236 merged 15 commits into ml-explore:main from nastya236:fsdp
Mar 1, 2026

Conversation

Collaborator

@nastya236 nastya236 commented Feb 14, 2026

FSDP without sharding model parameters.

What it does:

  • reduce-scatter the gradients
  • clip the gradients if needed
  • update a slice of the parameters
  • all-gather the updated parameters
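A minimal sketch of these four steps, simulating the ranks with NumPy in place of the real communication primitives (the reduce-scatter and all-gather below are stand-ins, not the mlx API; the SGD-style step and clipping threshold are illustrative assumptions):

```python
import numpy as np

N = 4                       # simulated ranks
lr, max_norm = 0.1, 1.0     # hypothetical hyperparameters
params = np.arange(8.0)     # replicated parameters, length divisible by N

# Each rank computes its own local gradient over the full parameter vector.
local_grads = [np.full(8, r + 1.0) for r in range(N)]

# 1. reduce-scatter: average the gradients; rank r keeps only shard r.
avg = np.sum(local_grads, axis=0) / N
grad_shards = np.split(avg, N)

# 2. clip by the global norm of the full averaged gradient.
scale = min(1.0, max_norm / (np.linalg.norm(avg) + 1e-6))

# 3. each rank updates only its own slice of the parameters.
param_shards = np.split(params, N)
new_shards = [param_shards[r] - lr * scale * grad_shards[r] for r in range(N)]

# 4. all-gather: every rank reassembles the identical full parameter vector.
params = np.concatenate(new_shards)
```

The point of the sketch is that only one shard's worth of optimizer work happens per rank, while the parameters end up bitwise identical everywhere after the all-gather.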

It is a bit slower than averaging the gradients directly:

| # Params | Comm size | DDP (ms) | FSDP (ms) | DDP bf16 (ms) | FSDP bf16 (ms) |
|---|---|---|---|---|---|
| 0.3B | 16 MiB | 6.96 | 6.94 | 5.54 | 5.38 |
| 0.3B | 32 MiB | 6.81 | 6.90 | 5.64 | 5.38 |
| 2.7B | 16 MiB | 34.06 | 34.63 | 24.78 | 25.94 |
| 2.7B | 32 MiB | 33.54 | 34.79 | 24.65 | 27.89 |
| 5.4B | 16 MiB | 63.07 | 65.23 | 45.59 | 47.44 |
| 5.4B | 32 MiB | 63.07 | 65.27 | 45.96 | 47.57 |

For 3.7B Qwen pretraining on 8 B200 GPUs:

fsdp (no model sharding): its_per_sec: 24.0635, toks_per_sec: 197128.5563, peak_memory: 99.1149
ddp: its_per_sec: 24.4566, toks_per_sec: 200348.4971, peak_memory: 126.8486

And just a test that convergence is identical:

[Screenshot 2026-02-14 at 21 28 07]

@nastya236 nastya236 changed the title Fsdp Fsdp (easy) Feb 14, 2026
@nastya236 nastya236 marked this pull request as ready for review February 14, 2026 22:11
@nastya236 nastya236 changed the title Fsdp (easy) [CUDA] Fsdp (easy) Feb 14, 2026
Member

@angeloskath angeloskath left a comment


This is awesome!

On the implementation side, I am wondering whether the abstraction (although quite smartly written indeed) hinders performance and maybe readability a bit.

Just to be clear, I don't necessarily have a better suggestion but some observations.

The current approach very nicely

  • gathers the gradients in a group until they are large enough to communicate
  • splits each gradient over N nodes with a sum_scatter
  • reshapes and splits back into separate arrays
  • runs the optimizer per array shard
  • gathers the new parameters in groups until they are large enough to communicate
  • all_gathers the groups
  • reshapes and splits back into separate arrays

This has several niceties, the most important being that each intermediate step retains its semantic meaning: the sum_scatter results are still per array, the optimizer state is still per array, etc.

However, another option would be

  • gathers the gradients in a group until they are large enough to communicate
  • splits the whole group (not each gradient) with a sum_scatter
  • runs the optimizer per group shard
  • all_gathers the group
  • reshapes and splits back into separate arrays

The above would make the optimizer state, as well as the sum_scatter result, kind of meaningless without extra metadata, but it would also skip a bunch of concatenations and reduce the number of optimizer calls, possibly increasing efficiency there as well.
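To make the trade-off concrete, here is a toy sketch of the group-level option with NumPy standing in for the collectives (the padding and the offsets bookkeeping are assumptions about what such an implementation would need, not the actual code):

```python
import numpy as np

N = 2  # simulated ranks
# Local gradients for three arrays of different shapes (identical per rank here).
grads = {"w1": np.ones((2, 3)), "w2": np.ones(4), "w3": np.ones((2, 2))}

# Pack the whole group into one flat buffer, padded to a multiple of N.
flat = np.concatenate([g.ravel() for g in grads.values()])
pad = (-len(flat)) % N
buf = np.concatenate([flat, np.zeros(pad)])

# sum_scatter over the whole group: rank r receives one contiguous chunk of
# the summed buffer. Since all simulated ranks hold identical gradients,
# the "sum" over ranks is just N * buf.
chunks = np.split(N * buf, N)

# The optimizer now steps on chunks[r] directly; recovering which elements
# belong to w1/w2/w3 afterwards requires extra metadata like these offsets.
offsets = np.cumsum([0] + [g.size for g in grads.values()])
```

The shard each rank operates on crosses array boundaries, which is exactly why the intermediate state loses its per-array meaning without the offsets.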

Probably not worth fusing it unless there are significant speedups. Wdyt?

Collaborator Author

nastya236 commented Feb 24, 2026

Thanks for your comment @angeloskath, I fully agree: it actually does not make sense to reshape everything back and reconstruct the whole tree. Regarding sharding each gradient, I think there are two ways of doing it and I don't really know which one is better:

  1. Each gradient is reshaped to (N, -1) before concatenating, so each rank gets a slice of every gradient. We still run one sum_scatter per group.
    Something like this:
  [[w1_rank0 | w2_rank0 | w3_rank0],
   [w1_rank1 | w2_rank1 | w3_rank1],
   [w1_rank2 | w2_rank2 | w3_rank2]]  →  sum_scatter gives each rank its row
  rank 0: [w1_rank0 | w2_rank0 | w3_rank0]
  rank 1: [w1_rank1 | w2_rank1 | w3_rank1]
  rank 2: [w1_rank2 | w2_rank2 | w3_rank2]
  2. Flatten everything, concatenate, and slice horizontally so each node receives a contiguous chunk of gradients.
    Something like this:
  [w1_flat | w2_flat | w3_flat]  →  sum_scatter splits into N equal chunks                                  
  rank 0: [w1_all | w2_start...]
  rank 1: [...w2_end | w3_start...]                                                                         
  rank 2: [...w3_end]
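A small NumPy illustration of the two layouts, with hypothetical gradient sizes chosen to be divisible by N so that variant 1 works without padding:

```python
import numpy as np

N = 3
# Three flattened "gradients" of sizes 6, 9, and 3 (each divisible by N).
w1, w2, w3 = np.arange(6.0), np.arange(6.0, 15.0), np.arange(15.0, 18.0)

# Variant 1: reshape each gradient to (N, -1), concatenate along axis 1.
# Row r is what rank r receives: a slice of *every* gradient.
grid = np.concatenate([w.reshape(N, -1) for w in (w1, w2, w3)], axis=1)
rank0_v1 = grid[0]  # [w1[0:2] | w2[0:3] | w3[0:1]]

# Variant 2: flatten and concatenate, then split into N equal chunks.
# Rank r receives a contiguous range, so whole gradients can land on one rank.
flat = np.concatenate([w1, w2, w3])
rank0_v2 = np.split(flat, N)[0]  # all six elements of w1
```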

My intuition was that sharding every array on axis 0 (the first variant) could potentially give more overlap between communication and computation, especially for a future version where we shard the model parameters. I will try both ways and implement the optimal one.

Collaborator Author

I think this version is much better than the first one. Thanks @angeloskath!
Now we:

  • group local gradients until the group reaches a size threshold
  • reshape and concatenate along axis 0
  • sum_scatter so all nodes receive a slice of the group
  • clip if max_norm is passed
  • perform a step (without pytree reconstruction)
  • all_gather the updated parameters
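The grouping step above can be sketched as a greedy packer (a hypothetical helper written for illustration, not the actual implementation):

```python
import numpy as np

def group_by_size(arrays, threshold_bytes):
    """Greedily pack arrays into groups; close a group once it reaches the threshold."""
    groups, current, size = [], [], 0
    for a in arrays:
        current.append(a)
        size += a.nbytes
        if size >= threshold_bytes:
            groups.append(current)
            current, size = [], 0
    if current:  # flush the last, possibly under-sized group
        groups.append(current)
    return groups

# Example: ten float32 arrays of 1 KiB each, grouped with a 4 KiB threshold.
arrays = [np.zeros(256, dtype=np.float32) for _ in range(10)]
groups = group_by_size(arrays, 4096)
```

Each group is then reshaped, concatenated, and handed to one sum_scatter call, so the threshold directly controls the communication message size.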

I tried the other option (basically the same flatten -> concatenate -> split as for average_gradients) and did not see any difference between average_gradients, fsdp (version 1), and fsdp (version 2) in terms of tokens per second for a 4B Qwen on 8 B200 GPUs (all run at ~200k tokens per second).

Member

@angeloskath angeloskath left a comment


Looks great!

@nastya236 nastya236 merged commit 72e04f7 into ml-explore:main Mar 1, 2026
16 checks passed