242 changes: 178 additions & 64 deletions fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""

import os
from typing import Callable

import paddle
@@ -52,6 +53,30 @@
m_grouped_fp8_gemm_nt_masked = None


def dump_tensor(*args):
"""打印张量的关键信息(类型、形状、 dtype、步长),用于调试精度问题"""
import inspect

frame = inspect.currentframe().f_back
try:
call = inspect.getframeinfo(frame).code_context[0]
names = call[call.find("(") + 1 : call.rfind(")")].split(",")
except Exception:
names = [f"arg{i}" for i in range(len(args))]

print(100 * "*")
for i, x in enumerate(args):
name = names[i].strip() if i < len(names) else f"arg{i}"
name_len = min(20, len(name))
print(
f"[{name[:name_len]:<23}] "
f"type={type(x).__name__:<12} "
f"shape={tuple(x.shape)!s:<18} "
f"dtype={str(x.dtype):<20} "
f"strides={x.strides}"
)
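# Usage sketch (illustrative; numpy stands in for paddle.Tensor): the printed
# argument names are recovered from the call site via inspect, so dump_tensor
# works for any object exposing .shape/.dtype/.strides. Caveat: the name
# recovery splits the call text on ",", so nested calls such as
# dump_tensor(f(a, b)) will mislabel the columns.
#
#     import numpy as np
#     a = np.zeros((4, 8), dtype=np.float32)
#     b = a.T  # non-contiguous view with reversed strides
#     dump_tensor(a, b)  # one labeled line per argument: [a], [b]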


def m_grouped_fp8_gemm_nt_contiguous_custom_python_op_infermeta(
permute_input: "paddle.static.MetaTensor",
permute_scale: "paddle.static.MetaTensor",
@@ -99,9 +124,11 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
(permute_input.shape[0], layer_added_weight_attrs_0.shape[1]),
dtype=paddle.bfloat16,
)
# if disable_ue8m0_cast:
if permute_scale.strides[0] != 1:
permute_scale = permute_scale.transpose([1, 0])
dump_tensor(permute_input, permute_scale, m_indices)
# disable_ue8m0_cast is False for SM100
m_grouped_fp8_gemm_nt_contiguous(
(permute_input, permute_scale),
@@ -134,12 +161,16 @@ def m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
dtype=paddle.bfloat16,
)
# disable_ue8m0_cast is False for SM100

dump_tensor(ffn_in_x, ffn_in_x_scale_tensor)
m_grouped_fp8_gemm_nt_contiguous(
(ffn_in_x, ffn_in_x_scale_tensor),
(layer_added_weight_attrs_1, layer_added_scale_attrs_1),
ffn_out,
m_indices,
)

dump_tensor(ffn_out)
return ffn_out
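# Illustrative sketch of the grouped "contiguous" GEMM contract used above
# (assumptions: weights are [E, N, K] in NT layout; fp8 scales ignored). Rows of
# permute_input are sorted by expert, and m_indices[i] names the expert that
# owns row i, so the kernel applies weight matrix m_indices[i] to row i:
import numpy as np


def grouped_gemm_nt_reference(x, weights, m_indices):
    """x: [M, K]; weights: [E, N, K]; m_indices: [M] ints, sorted by expert."""
    out = np.zeros((x.shape[0], weights.shape[1]), dtype=np.float32)
    for e in range(weights.shape[0]):
        rows = np.nonzero(m_indices == e)[0]  # contiguous run of rows for expert e
        if rows.size:
            out[rows] = x[rows] @ weights[e].T  # per-expert dense GEMM
    return out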


@@ -256,6 +287,7 @@ def apply_ep_prefill(
gate: nn.Layer,
topk_ids_hookfunc: Callable = None,
) -> paddle.Tensor:
print("apply_ep_prefill")
"""
Apply the EP prefill method.
"""
Expand Down Expand Up @@ -317,30 +349,53 @@ def apply_ep_prefill(
(recv_x, recv_x_scale) = recv_x

token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
print("num_experts:", layer.num_local_experts)
dump_tensor(recv_x, recv_x_scale, recv_topk_weights, token_nums_this_rank[1])
if bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "1"))):
recv_topk_idx = recv_topk_idx.astype(paddle.int32)
(
permute_input,
permute_indices_per_token, # == zipped_expertwise_rowmap
dst_weights,
permute_scale,
m_indices,
) = paddle.nn.functional.moe_permute(
hidden_states=recv_x,
scale=recv_x_scale,
expert_routemap_topk=recv_topk_idx,
expert_prob_topk=recv_topk_weights,
num_experts=layer.num_local_experts,
tokens_per_expert=token_nums_this_rank[1].tolist(),
padding_alignment=128,
return_expert_indices=True,
do_gather=True,
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
else:
(
permute_input,
permute_scale,
permute_indices_per_token,
recv_num_tokens_per_expert_list_cumsum,
recv_num_tokens_per_expert_list_padded_cumsum,
dst_weights,
dst_indices,
cumsum_idx_gpu,
m_indices,
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch_fp8(
recv_x,
recv_x_scale,
recv_topk_idx,
recv_topk_weights,
token_nums_this_rank[0],
token_nums_this_rank[1],
True, # use_in_ep
token_all_num,
)

assert permute_input.shape[0] == token_all_num

if permute_scale.strides[0] != 1:
permute_scale = permute_scale.transpose([1, 0]).contiguous().transpose([1, 0])
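# Layout note (sketch; the column-major requirement is an assumption about the
# grouped GEMM kernel): the transpose/contiguous/transpose round-trip keeps the
# logical shape of permute_scale but rewrites its memory to column-major, and
# the strides[0] != 1 check fires only when the scales are not already in that
# layout. numpy shows the same effect (numpy strides are in bytes, paddle's in
# elements):
#
#     import numpy as np
#     s = np.ones((16, 4), dtype=np.float32)   # row-major: strides == (16, 4)
#     s_cm = s.T.copy().T                      # same shape, column-major layout
#     assert s_cm.shape == s.shape and s_cm.strides == (4, 64)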

# up_gate_proj
@@ -387,21 +442,34 @@ def apply_ep_prefill(
m_indices,
)
del ffn_in_x
if bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "1"))):
print("use_phi_moe_permute", 60 * "*")
tmp_ffn_out, out_probs = paddle.nn.functional.moe_unpermute(
hidden_states_unzipped=ffn_out,
zipped_expertwise_rowmap=permute_indices_per_token,
expert_routemap_topk=recv_topk_idx,
token_prob_unzipped=dst_weights,
total_zipped_tokens=recv_x.shape[0],
num_experts=layer.num_local_experts,
# use_mix_precision=False,
using_weighted_combine=True,
)

else:
# prmt back per rank
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
ffn_out,
dst_weights,
permute_indices_per_token,
dst_indices,
None, # down_proj_bias
False, # norm_topk_prob
1.0,
)
del ffn_out
else:
tmp_ffn_out = paddle.empty([0, hidden_size], paddle.bfloat16)

dump_tensor(tmp_ffn_out)
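# Note (assumption): the two FD_USE_PHI_MOE_PERMUTE branches above are intended
# to be numerically equivalent: moe_unpermute and ep_moe_expert_combine both
# scatter the dispatched rows back to token order and weighted-sum over the
# top-k experts; see the combine_reference sketch at the end of this diff for
# that contract in plain numpy.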
# 5. EP combine
event = deep_ep.Buffer.capture()
let_another_thread_run()
@@ -422,6 +490,7 @@ def apply_ep_decode(
"""
Apply the EP decoder method.
"""
print("apply_ep_decode")
gate_out = gate(x.cast("float32"))
# 1. Select topk experts and weights
topk_idx, topk_weights = self.ep_decoder_runner.moe_select(layer, gate_out)
@@ -498,6 +567,7 @@ def apply_tp(
Use DeepGEMM to compute the fused MoE; below is the TP compute path.
"""
print("apply_tp")
gate_out = gate(x.cast("float32"))

if layer.topk_method == "noaux_tc":
@@ -527,7 +597,6 @@ def apply_tp(
if not fastdeploy.envs.FD_USE_PHI_FP8_QUANT:
recv_x, recv_x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(x, 128)
else:

recv_x, recv_x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
x,
using_pow2_scale=self.quant_config.deepgemm_scale_ue8m0,
@@ -540,26 +609,54 @@ def apply_tp(
else recv_x_scale.T[: recv_x.shape[0]]
)

dump_tensor(recv_x, recv_x_scale, topk_ids, topk_weights, tmp[1])
print()
if bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "1"))):
topk_ids = topk_ids.astype(paddle.int32)
print("tp_phi_moe_permute")

print("layer.num_experts:", layer.num_experts)
(
permute_input,
permute_indices_per_token, # == zipped_expertwise_rowmap
dst_weights,
permute_scale,
m_indices,
) = paddle.nn.functional.moe_permute(
hidden_states=recv_x,
scale=recv_x_scale,
expert_routemap_topk=topk_ids,
expert_prob_topk=topk_weights,
num_experts=layer.num_experts,
tokens_per_expert=tmp[0].tolist(),
padding_alignment=128,
return_expert_indices=True,
using_tp_alloc=True,
do_gather=True,
using_ue8m0_scale=self.quant_config.deepgemm_scale_ue8m0,
)
else:
(
permute_input,
permute_scale,
permute_indices_per_token,
recv_num_tokens_per_expert_list_cumsum,
recv_num_tokens_per_expert_list_padded_cumsum,
dst_weights,
dst_indices,
cumsum_idx_gpu,
m_indices,
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch_fp8(
recv_x,
recv_x_scale,
topk_ids,
topk_weights,
tmp[0],
tmp[1],
False, # use_in_ep
-1,
)
dump_tensor(permute_input, permute_indices_per_token, dst_weights, permute_scale, m_indices)

ffn_out = m_grouped_fp8_gemm_nt_contiguous_custom_python_op(
permute_input,
@@ -574,14 +671,31 @@ def apply_tp(
)


print("tp_phi_moe_unpermute")
dump_tensor(ffn_out, permute_indices_per_token, topk_ids, dst_weights)
if bool(int(os.getenv("FD_USE_PHI_MOE_PERMUTE", "1"))):
print("total_zipped_tokens:", ffn_out.shape[0])
print("num_experts:", layer.num_experts)

tmp_ffn_out, out_probs = paddle.nn.functional.moe_unpermute(
hidden_states_unzipped=ffn_out,
zipped_expertwise_rowmap=permute_indices_per_token,
expert_routemap_topk=topk_ids,
token_prob_unzipped=dst_weights,
total_zipped_tokens=recv_x.shape[0],
num_experts=layer.num_experts,
# use_mix_precision=False,
using_weighted_combine=True,
)
else:
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
ffn_out,
dst_weights,
permute_indices_per_token,
dst_indices,
None,
False, # norm_topk_prob
1.0,
)
dump_tensor(tmp_ffn_out)
return tmp_ffn_out
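# Illustrative contract (names here are hypothetical, not FastDeploy APIs) for
# what moe_unpermute / ep_moe_expert_combine compute above: scatter each
# dispatched row back to its source token and take the gate-probability-weighted
# sum over the top-k expert slots.
import numpy as np


def combine_reference(ffn_out, row_of, probs):
    """ffn_out: [R, H] dispatched rows; row_of: [T, K] dispatched-row index per
    (token, expert) slot, -1 when a slot is not routed to this rank;
    probs: [T, K] gate probabilities."""
    T, K = row_of.shape
    out = np.zeros((T, ffn_out.shape[1]), dtype=ffn_out.dtype)
    for t in range(T):
        for k in range(K):
            r = row_of[t, k]
            if r >= 0:
                out[t] += probs[t, k] * ffn_out[r]  # weighted scatter-add
    return out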