Summary
transformer_engine.pytorch.ops.GroupedLinear fails with MXFP8BlockScaling in a Mixtral MoE expert-parallel run when per-expert token splits are not divisible by 32.

Repro context
Recipe: MXFP8BlockScaling
Module: transformer_engine.pytorch.ops.GroupedLinear
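A minimal sketch of the failing setup follows. The expert count, feature dimensions, and GroupedLinear constructor arguments are assumptions inferred from the traceback shapes, not copied from the actual repro script; the call signature mirrors the te_mixtral.py call site.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling

# Hypothetical sizes chosen to match the (tokens, 4096) shapes in the error.
num_experts, in_features, out_features = 4, 4096, 14336

# Constructor arguments are assumed to mirror te.pytorch.GroupedLinear
# (num_gemms, in_features, out_features).
experts_gate_up = te.ops.GroupedLinear(num_experts, in_features, out_features)

# Per-expert token counts after routing; none is a multiple of 32
# (e.g. 2283 % 32 == 11), which trips the MXFP8 assertion.
split_sizes = [2283, 1441, 1178, 1225]
tokens = torch.randn(sum(split_sizes), in_features,
                     device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True, fp8_recipe=MXFP8BlockScaling()):
    out = experts_gate_up(tokens, split_sizes)  # RuntimeError in tex.split_quantize
```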
[rank0]: File "/lustre/fsw/coreai_prod_infbench/faradawny/TransformerEngine/docs/examples/te_mixtral/te_mixtral.py", line 687, in _expert_ffn
[rank0]: gate_up_output = self.experts_gate_up(tokens, split_sizes)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/ops/op.py", line 522, in forward
[rank0]: return OperationFuser([self])(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/ops/basic/grouped_linear.py", line 739, in fuser_forward
[rank0]: xs = tex.split_quantize(x, split_sizes_int, input_quantizers)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: /workspace/TransformerEngine/transformer_engine/pytorch/csrc/quantizer.cpp:1668 in function get_scale_shape: Assertion failed: last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0. MXFP8 requires tensor dims that are divisible by 32 (got shape=(2283,4096))
Other ranks fail similarly with shapes like (1441,4096), (1178,4096), and (1225,4096).
Expected
The Sequential Ops grouped path should either handle MXFP8 padding internally for each split, or clearly document the divisibility requirement and a recommended workaround for MoE token splits whose per-expert token counts are not multiples of 32.
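A user-side workaround sketch, assuming only the token dimension violates the constraint (the feature dimension 4096 is already a multiple of 32): zero-pad each expert's chunk of tokens up to the next multiple of 32 before the grouped op, then drop the padding rows from the output. The helper names below are hypothetical, not TE APIs.

```python
import torch

MXFP8_BLOCK = 32  # block size the quantizer assertion checks on both dims

def pad_splits(tokens: torch.Tensor, split_sizes: list[int]):
    """Zero-pad each per-expert chunk along dim 0 to a multiple of MXFP8_BLOCK."""
    chunks, padded_sizes = [], []
    for chunk in torch.split(tokens, split_sizes, dim=0):
        pad = (-chunk.shape[0]) % MXFP8_BLOCK
        if pad:
            chunk = torch.cat([chunk, chunk.new_zeros(pad, chunk.shape[1])], dim=0)
        chunks.append(chunk)
        padded_sizes.append(chunk.shape[0])
    return torch.cat(chunks, dim=0), padded_sizes

def unpad_splits(out: torch.Tensor, padded_sizes: list[int], split_sizes: list[int]):
    """Drop the zero-padding rows from each per-expert output chunk."""
    chunks = torch.split(out, padded_sizes, dim=0)
    return torch.cat([c[:n] for c, n in zip(chunks, split_sizes)], dim=0)
```

At the _expert_ffn call site this would look roughly like padded, padded_sizes = pad_splits(tokens, split_sizes) followed by gate_up_output = unpad_splits(self.experts_gate_up(padded, padded_sizes), padded_sizes, split_sizes). The zero rows waste some FLOPs but keep every split MXFP8-compatible.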