Commit 7ec07ba

author: cloudforge1
【Hackathon 9th No.39】Unit test for moe_expert_ffn_wint2

1 parent 30f9f33 · commit 7ec07ba
1 file changed: 309 additions & 0 deletions

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for the moe_expert_ffn_wint2 custom op.

Tests the CUTLASS Weight-Only INT2 quantized MoE FFN operator:

1) First GEMM: input x dequant(up_gate_proj_weight) -> fc1_out
2) SwiGLU activation: fc1_out -> act_out
3) Second GEMM: act_out x dequant(down_proj_weight) -> output

Reference sources for the WINT2 dequant algorithm:
- Triton kernel: fastdeploy/model_executor/ops/triton_ops/wint2_fused_moe_kernel.py
- CUTLASS layout: fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py
"""

import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import moe_expert_ffn_wint2

paddle.seed(2026)
np.random.seed(2026)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
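

# Step (2) of the pipeline under test is SwiGLU. Below is a minimal NumPy
# reference, for illustration only: it assumes the common convention that the
# first half of the gated intermediate is the gate branch; the kernel's actual
# split order is defined by the CUDA implementation, not by this sketch.
def _swiglu_ref(fc1_out):
    """Hypothetical SwiGLU reference: silu(gate) * up over the last axis."""
    gate, up = np.split(fc1_out, 2, axis=-1)
    return (gate / (1.0 + np.exp(-gate))) * up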


def _cutlass_rearrange(w):
    """Apply the CUTLASS WINT2 weight layout rearrangement.

    Matches CutlassWint2FusedMoeMethod.process_weights_after_loading():
    reshape [E, K//16, 16, N//8, 8] -> transpose [0, 3, 1, 4, 2] -> reshape
    """
    shape = w.shape
    E, Kp, N = shape
    w = w.reshape([E, Kp // 16, 16, N // 8, 8])
    w = paddle.transpose(w, perm=[0, 3, 1, 4, 2])
    return w.reshape(shape)
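

# The rearrangement above is a pure permutation of elements, so it is
# invertible. A sketch of the inverse, handy when building a reference dequant
# for debugging (_cutlass_rearrange_inverse is a hypothetical helper, not part
# of FastDeploy):
def _cutlass_rearrange_inverse(w):
    """Undo _cutlass_rearrange, recovering the pre-rearrangement layout."""
    shape = w.shape
    E, Kp, N = shape
    w = w.reshape([E, N // 8, Kp // 16, 8, 16])
    # perm [0, 2, 4, 1, 3] is the inverse of the forward perm [0, 3, 1, 4, 2]
    w = paddle.transpose(w, perm=[0, 2, 4, 1, 3])
    return w.reshape(shape)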


def _build_inputs(
    num_experts,
    hidden_size,
    inter_size,
    tokens_per_expert,
    dtype="bfloat16",
    use_3d=False,
    zero_input=False,
):
    """Create correctly shaped tensors for moe_expert_ffn_wint2.

    Args:
        num_experts: Number of experts.
        hidden_size: Hidden dimension (must be divisible by 128).
        inter_size: Intermediate size after the SwiGLU split (must be
            divisible by 128).
        tokens_per_expert: List of token counts per expert.
        dtype: "bfloat16" or "float16".
        use_3d: Use 3D input [E, max_tokens, H] instead of 2D.
        zero_input: Set the input to zeros (for the zero-input invariant test).
    """
    gated_inter = inter_size * 2
    total_tokens = sum(tokens_per_expert)

    # --- Input ---
    if use_3d:
        max_tok = max(tokens_per_expert) if tokens_per_expert else 1
        shape = [num_experts, max_tok, hidden_size]
    else:
        shape = [total_tokens, hidden_size]
    if zero_input:
        permute_input = paddle.zeros(shape, dtype=dtype)
    else:
        permute_input = paddle.randn(shape, dtype=dtype)

    # --- Prefix sum ---
    tokens_expert_prefix_sum = paddle.to_tensor(np.cumsum(tokens_per_expert).astype("int64"))

    # --- Packed uint8 weights (four int2 values per byte, hence K // 4),
    # with CUTLASS rearrangement ---
    w_up = _cutlass_rearrange(
        paddle.randint(0, 256, [num_experts, hidden_size // 4, gated_inter], dtype="int32").cast("uint8")
    )
    w_down = _cutlass_rearrange(
        paddle.randint(0, 256, [num_experts, inter_size // 4, hidden_size], dtype="int32").cast("uint8")
    )

    # --- Super scales (channel-wise, input dtype) ---
    super_up = paddle.randn([num_experts, gated_inter], dtype=dtype) * 0.01
    super_down = paddle.randn([num_experts, hidden_size], dtype=dtype) * 0.01

    # --- Local scales (group-wise over 128-element groups, uint8) ---
    local_up = paddle.randint(0, 256, [num_experts, hidden_size // 128, gated_inter], dtype="int32").cast("uint8")
    local_down = paddle.randint(0, 256, [num_experts, inter_size // 128, hidden_size], dtype="int32").cast("uint8")

    # --- Code scale and zero-point (channel-wise, float32) ---
    code_scale_up = paddle.randn([num_experts, gated_inter], dtype="float32") * 0.01
    code_zp_up = paddle.randn([num_experts, gated_inter], dtype="float32") * 0.01
    code_scale_down = paddle.randn([num_experts, hidden_size], dtype="float32") * 0.01
    code_zp_down = paddle.randn([num_experts, hidden_size], dtype="float32") * 0.01

    return dict(
        permute_input=permute_input,
        tokens_expert_prefix_sum=tokens_expert_prefix_sum,
        up_gate_proj_weight=w_up,
        down_proj_weight=w_down,
        up_gate_proj_bias=None,
        up_gate_proj_scale=super_up,
        down_proj_scale=super_down,
        up_gate_proj_local_scale=local_up,
        up_gate_proj_code_scale=code_scale_up,
        up_gate_proj_code_zp=code_zp_up,
        down_proj_local_scale=local_down,
        down_proj_code_scale=code_scale_down,
        down_proj_code_zp=code_zp_down,
    )
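

# WINT2 stores four 2-bit weights per uint8, which is why the packed K
# dimension above is hidden_size // 4 (resp. inter_size // 4). A purely
# illustrative unpack of one byte into its four int2 fields; the bit order
# here is an assumption, and the kernel's real order is fixed by the CUTLASS
# layout and the code scale / zero-point dequant, not by this helper.
def _unpack_int2_ref(packed_byte):
    """Hypothetical reference: split one uint8 into four 2-bit values."""
    return [(packed_byte >> (2 * i)) & 0b11 for i in range(4)]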


def _call_op(inputs, used_in_ep_low_latency=False):
    """Invoke moe_expert_ffn_wint2 with the given inputs dict."""
    return moe_expert_ffn_wint2(
        inputs["permute_input"],
        inputs["tokens_expert_prefix_sum"],
        inputs["up_gate_proj_weight"],
        inputs["down_proj_weight"],
        inputs["up_gate_proj_bias"],
        inputs["up_gate_proj_scale"],
        inputs["down_proj_scale"],
        inputs["up_gate_proj_local_scale"],
        inputs["up_gate_proj_code_scale"],
        inputs["up_gate_proj_code_zp"],
        inputs["down_proj_local_scale"],
        inputs["down_proj_code_scale"],
        inputs["down_proj_code_zp"],
        used_in_ep_low_latency,
    )
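

# Typical call pattern used by the tests below (2D input):
#     inputs = _build_inputs(4, 256, 128, [4, 6, 2, 4])
#     out = _call_op(inputs)  # shape [16, 256], dtype matches the input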


# ===================================================================
# Test Cases
# ===================================================================


class TestMoeExpertFFNWint2(unittest.TestCase):
    """Correctness and regression tests for the WINT2 MoE FFN op."""

    # Small dimensions for fast CI (H and INTER must be divisible by 128)
    E = 4
    H = 256
    INTER = 128
    TOKENS = [4, 6, 2, 4]  # per expert, total = 16

    def setUp(self):
        paddle.set_device("gpu")

    # -- Numerical correctness -----------------------------------------

    def test_zero_input_produces_zero_output(self):
        """Zero input => matmul=0, SwiGLU(0)=0, matmul=0 => output = 0.

        This is a mathematical invariant independent of weight values.
        """
        for dtype in ["bfloat16", "float16"]:
            with self.subTest(dtype=dtype):
                inputs = _build_inputs(
                    self.E,
                    self.H,
                    self.INTER,
                    self.TOKENS,
                    dtype=dtype,
                    zero_input=True,
                )
                out = _call_op(inputs).cast("float32").numpy()
                np.testing.assert_allclose(
                    out,
                    np.zeros_like(out),
                    atol=1e-5,
                    err_msg=f"Zero input must produce zero output ({dtype})",
                )

    def test_determinism(self):
        """Identical inputs must produce bit-identical outputs."""
        inputs = _build_inputs(self.E, self.H, self.INTER, self.TOKENS)
        out1 = _call_op(inputs).cast("float32").numpy()
        out2 = _call_op(inputs).cast("float32").numpy()
        np.testing.assert_array_equal(
            out1,
            out2,
            err_msg="Non-deterministic: two runs with same inputs differ",
        )

    def test_nonzero_input_gives_finite_nonzero_output(self):
        """Random non-zero inputs must produce finite, non-zero values."""
        inputs = _build_inputs(self.E, self.H, self.INTER, self.TOKENS)
        out = _call_op(inputs).cast("float32").numpy()
        self.assertTrue(np.all(np.isfinite(out)), "Output contains NaN or Inf")
        self.assertGreater(
            np.abs(out).max(),
            0,
            "All-zero output from non-zero input",
        )

    # -- Shape and dtype -----------------------------------------------

    def test_output_shape_2d(self):
        """2D input [total_tokens, H] => output shape matches."""
        inputs = _build_inputs(self.E, self.H, self.INTER, self.TOKENS)
        out = _call_op(inputs)
        self.assertEqual(list(out.shape), list(inputs["permute_input"].shape))
        self.assertEqual(out.dtype, inputs["permute_input"].dtype)

    def test_output_shape_3d(self):
        """3D input [E, max_tokens, H] => output shape matches."""
        inputs = _build_inputs(
            self.E,
            self.H,
            self.INTER,
            self.TOKENS,
            use_3d=True,
        )
        out = _call_op(inputs)
        self.assertEqual(list(out.shape), list(inputs["permute_input"].shape))

    def test_dtype_bf16(self):
        """Op supports bfloat16 input/output."""
        inputs = _build_inputs(
            self.E,
            self.H,
            self.INTER,
            self.TOKENS,
            dtype="bfloat16",
        )
        out = _call_op(inputs)
        self.assertEqual(out.dtype, paddle.bfloat16)

    def test_dtype_fp16(self):
        """Op supports float16 input/output."""
        inputs = _build_inputs(
            self.E,
            self.H,
            self.INTER,
            self.TOKENS,
            dtype="float16",
        )
        out = _call_op(inputs)
        self.assertEqual(out.dtype, paddle.float16)

    # -- Edge cases ----------------------------------------------------

    def test_sparse_experts(self):
        """Experts with zero tokens are handled correctly."""
        sparse = [8, 0, 0, 8]
        inputs = _build_inputs(self.E, self.H, self.INTER, sparse)
        out = _call_op(inputs)
        self.assertEqual(list(out.shape), list(inputs["permute_input"].shape))
        self.assertTrue(np.all(np.isfinite(out.cast("float32").numpy())))

    def test_single_token_single_expert(self):
        """Minimal case: 1 expert, 1 token."""
        inputs = _build_inputs(1, self.H, self.INTER, [1])
        out = _call_op(inputs)
        self.assertEqual(list(out.shape), [1, self.H])

    def test_low_latency_mode(self):
        """Low-latency mode (GroupSwigluWithMasked) with 3D input."""
        inputs = _build_inputs(
            self.E,
            self.H,
            self.INTER,
            self.TOKENS,
            use_3d=True,
        )
        out = _call_op(inputs, used_in_ep_low_latency=True)
        self.assertEqual(list(out.shape), list(inputs["permute_input"].shape))
        # In 3D mode, padded slots (beyond each expert's token count) may
        # overflow in the first GEMM before GroupSwigluWithMasked zeros them,
        # causing NaN propagation. Only validate the unpadded positions.
        out_np = out.cast("float32").numpy()
        for i, n_tok in enumerate(self.TOKENS):
            valid = out_np[i, :n_tok, :]
            self.assertTrue(
                np.all(np.isfinite(valid)),
                f"Expert {i}: non-finite values in first {n_tok} valid tokens",
            )

    def test_uneven_tokens(self):
        """Different numbers of tokens per expert."""
        uneven = [1, 5, 3, 7]
        inputs = _build_inputs(self.E, self.H, self.INTER, uneven)
        out = _call_op(inputs)
        self.assertEqual(list(out.shape), [sum(uneven), self.H])
        self.assertTrue(np.all(np.isfinite(out.cast("float32").numpy())))


if __name__ == "__main__":
    unittest.main()