Conversation
angeloskath
left a comment
This is awesome!
On the implementation side, I am wondering whether the abstraction (although quite smartly written indeed) hinders performance and maybe readability a bit.
Just to be clear, I don't necessarily have a better suggestion but some observations.
The current approach very nicely
- gathers the gradients in a group until they are large enough to communicate
- splits each gradient over N nodes with a `sum_scatter`
- reshapes and splits back to separate arrays
- runs the optimizer per array shard
- gathers the new parameters in groups until they are large enough to communicate
- `all_gather`s the groups
- reshapes and splits back to separate arrays
This has several niceties, the most important being that each intermediate step retains its semantic meaning: the `sum_scatter` results are still per array, the optimizer state is still per array, etc.
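To make the per-array scheme concrete, here is a minimal single-process sketch in plain NumPy that simulates it on N "ranks". All names (`per_array_step`, `N`, `lr`) are mine, not the actual implementation or MLX API, and the `sum_scatter`/`all_gather` collectives are emulated with slicing and concatenation:

```python
import numpy as np

N = 4    # simulated number of ranks
lr = 0.1 # dummy SGD learning rate

def per_array_step(params, per_rank_grads):
    """Update each array by sharding it over N ranks along axis 0.

    per_rank_grads[r][i] is rank r's gradient for params[i].
    """
    new_params = []
    for p_idx, p in enumerate(params):
        shards = []
        for rank in range(N):
            # Emulate sum_scatter: each rank receives the (averaged)
            # sum of only its axis-0 slice of this gradient.
            rows = np.array_split(np.arange(p.shape[0]), N)[rank]
            g = sum(per_rank_grads[r][p_idx][rows] for r in range(N)) / N
            # Optimizer step per array shard (plain SGD here).
            shards.append(p[rows] - lr * g)
        # Emulate all_gather: concatenate updated shards back together,
        # so the result is still a per-array, semantically meaningful tensor.
        new_params.append(np.concatenate(shards, axis=0))
    return new_params
```

The key property this illustrates is that every intermediate object (gradient shard, optimizer state, updated shard) is still identified with a specific array.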
However, another option would be
- gather the gradients in a group until they are large enough to communicate
- split the whole group (not each gradient) with a `sum_scatter`
- run the optimizer per group shard
- `all_gather` the group
- reshape and split back to separate arrays
The above would make the optimizer state, as well as the `sum_scatter` result, kind of meaningless without extra metadata, but it would also skip a bunch of concatenations and reduce the number of calls to the optimizer, possibly increasing efficiency there as well.
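For comparison, a sketch of this fused variant in the same NumPy simulation style (again, `per_group_step` and the surrounding names are mine, not real API): the whole group is flattened into one buffer and sharded as a single chunk per rank, so shard boundaries ignore array boundaries and the shapes/sizes metadata must be kept on the side to reconstruct the tree afterwards.

```python
import numpy as np

N = 4    # simulated number of ranks
lr = 0.1 # dummy SGD learning rate

def per_group_step(params, per_rank_grads):
    """Update a whole group of arrays by sharding one flat buffer over N ranks."""
    shapes = [p.shape for p in params]
    sizes = [p.size for p in params]
    flat_p = np.concatenate([p.ravel() for p in params])
    # One sum_scatter over the whole group: rank r receives the
    # (averaged) sum of flat chunk r, regardless of array boundaries.
    chunk_idx = np.array_split(np.arange(flat_p.size), N)
    new_chunks = []
    for rank in range(N):
        idx = chunk_idx[rank]
        g = sum(
            np.concatenate([g.ravel() for g in per_rank_grads[r]])[idx]
            for r in range(N)
        ) / N
        # One optimizer call per group shard; the chunk has no per-array
        # meaning without the shapes/sizes metadata above.
        new_chunks.append(flat_p[idx] - lr * g)
    # all_gather the group, then reshape and split back to separate arrays.
    flat_new = np.concatenate(new_chunks)
    out, start = [], 0
    for shape, size in zip(shapes, sizes):
        out.append(flat_new[start:start + size].reshape(shape))
        start += size
    return out
```

Note that the per-array concatenations inside the update loop disappear and the optimizer runs once per rank instead of once per array shard, which is where the hoped-for speedup would come from.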
Probably not worth fusing it unless there are significant speedups. Wdyt?
Thanks for your comment @angeloskath, I fully agree; it actually does not make any sense to reshape it back and reconstruct the whole tree. Regarding sharding each gradient, I think there are two ways of doing it and I don't really know which one is better.
My intuition was that sharding every array on axis 0 (the first variant) could potentially give more overlap between communication and computation, especially for the future version where we shard model parameters. I will try both ways and implement the optimal one.
I think this version is much better than the first one. Thanks @angeloskath!
I tried the other option (basically the same flatten -> concatenate -> split as for averaging gradients) and I did not see any difference between the two.
FSDP without sharding model parameters.
What it does:
- `reduce-scatter` the gradients
- update parameters
- `all-gather`

It is a bit slower than averaging the gradients directly:
For 3.7B qwen pretraining on 8 B200 GPUs:
- fsdp (no model sharding): its_per_sec: 24.0635, toks_per_sec: 197128.5563, peak_memory: 99.1149
- ddp: its_per_sec: 24.4566, toks_per_sec: 200348.4971, peak_memory: 126.8486

And just a test that convergence is identical: