Hi team,
Thank you for sharing your code.
Could you please elaborate on how you perform the Expert Parallelism described in the README: "We shard MLP experts across all devices to fit Hessians into VRAM, required for GPTQ calibration. Each process stores only a fraction of expert layers and corresponding Hessians."?
In the current codebase, the number of block keys in the main process is equal to the number of keys on the other GPUs (see https://github.com/IST-DASLab/MoE-Quant/blob/master/quant.py#L181 and https://github.com/IST-DASLab/MoE-Quant/blob/master/quant.py#L187). Thus, each rank stores an entire copy of the state dict, in contrast to the idea of expert parallelism, where the MLP experts are sharded.
Ask: Could you please share how we can confirm that the experts are actually sharded across the ranks?
For an 8-GPU setup, here's the output for block_idx = 3 (starting from 0):
Main process num states: 110
Device: 1 num states: 110
Device: 2 num states: 110
Device: 3 num states: 110
Device: 4 num states: 110
Device: 5 num states: 110
Device: 6 num states: 110
Device: 7 num states: 110
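For reference, the counts above come from a check along these lines (a minimal sketch: `report_num_states` and `block_state_dict` are hypothetical names for illustration, not actual identifiers from quant.py):

import torch.distributed as dist

def report_num_states(block_state_dict: dict) -> None:
    # Hypothetical helper: `block_state_dict` stands for whatever per-block
    # state dict quant.py builds around L181/L187. Under expert parallelism
    # we would expect the key counts to differ across ranks, since each rank
    # should hold only its shard of the MLP experts (plus shared layers).
    rank = dist.get_rank()
    label = "Main process" if rank == 0 else f"Device: {rank}"
    print(f"{label} num states: {len(block_state_dict)}")

With the current code, both branches print 110 on every rank, which is what prompted the question above.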
Tagging @Godofnothing @eldarkurtic for visibility. Thank you.