docker/Dockerfile: 26 changes (7 additions, 19 deletions)
@@ -78,25 +78,13 @@ RUN if [ "${ENABLE_NIXL}" = "1" ] || [ "${ENABLE_DEEPEP}" = "1" ]; then \
RUN if [ "${ENABLE_DEEPEP}" = "1" ]; then \
set -e; \
ln -sf /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so; \
NVSHMEM_VERSION=3.3.9; \
CUDA_ARCHS=90; \
wget https://developer.download.nvidia.com/compute/redist/nvshmem/${NVSHMEM_VERSION}/source/nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz \
&& tar -xf nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz && mv nvshmem_src nvshmem \
&& cd nvshmem \
&& rm -f /root/nvshmem_src_cuda12-all-all-${NVSHMEM_VERSION}.tar.gz \
&& NVSHMEM_SHMEM_SUPPORT=0 \
NVSHMEM_UCX_SUPPORT=0 \
NVSHMEM_USE_NCCL=0 \
NVSHMEM_MPI_SUPPORT=0 \
NVSHMEM_IBGDA_SUPPORT=1 \
NVSHMEM_PMIX_SUPPORT=0 \
NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \
NVSHMEM_USE_GDRCOPY=1 \
cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=/root/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCHS} \
&& cmake --build build --target install -j64; \
DEEPEP_COMMIT=b6ce310bb0b75079682d09bc2ebc063a074fbd58; \
cd /root && git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout ${DEEPEP_COMMIT} && cd ..; \
cd /root/DeepEP && NVSHMEM_DIR=/root/nvshmem/install python setup.py install; \
python -m pip install --upgrade --no-deps \
"nvidia-nccl-cu12==2.30.4" \
"nvidia-nvshmem-cu12==3.5.21"; \
cd /root && git clone https://github.com/deepseek-ai/DeepEP.git && cd DeepEP && git checkout b306af06afd412c88e51e71802951606e40b7358; \
ln -sf /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so.3 /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nvshmem/lib/libnvshmem_host.so; \
ln -sf /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nccl/lib/libnccl.so.2 /opt/conda/lib/python${PYTHON_VERSION}/site-packages/nvidia/nccl/lib/libnccl.so; \
pip install --no-build-isolation .; \
fi

RUN if [ "${ENABLE_NIXL}" = "1" ]; then \
@@ -48,6 +48,7 @@ def __init__(
self.quant_method = quant_method
assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now."
self.enable_ep_moe = get_env_start_args().enable_ep_moe
self.quant_method = self._maybe_upgrade_quant_method_for_ep_moe(self.quant_method)
self.n_routed_experts = n_routed_experts
self.num_fused_shared_experts = num_fused_shared_experts
self._init_config(network_config)
@@ -66,6 +67,27 @@ def __init__(
self.lock = threading.Lock()
self._create_weight()

def _maybe_upgrade_quant_method_for_ep_moe(self, quant_method: QuantizationMethod) -> QuantizationMethod:
if not self.enable_ep_moe:
return quant_method

if quant_method.method_name == "none":
from lightllm.common.quantization.registry import QUANTMETHODS

logger.info(
"enable_ep_moe requires FP8 MoE expert weights; "
"auto-upgrading fused_moe quantization from `none` to `deepgemm-fp8w8a8-b128`."
)
quant_method = QUANTMETHODS.get("deepgemm-fp8w8a8-b128")

if quant_method.method_name != "deepgemm-fp8w8a8-b128":
raise ValueError(
f"enable_ep_moe currently only supports `deepgemm-fp8w8a8-b128` for fused_moe, "
f"but got `{quant_method.method_name}`."
)

return quant_method

def _init_config(self, network_config: Dict[str, Any]):
self.n_group = network_config.get("n_group", 0)
self.use_grouped_topk = self.n_group > 0
@@ -4,7 +4,10 @@
from lightllm.distributed import dist_group_manager
from lightllm.common.triton_utils.autotuner import Autotuner
from lightllm.common.quantization.quantize_method import WeightPack
from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank
from lightllm.utils.envs_utils import (
get_deepep_num_max_dispatch_tokens_per_rank_prefill,
get_deepep_num_max_dispatch_tokens_per_rank_decode,
)
from lightllm.common.basemodel.triton_kernel.fused_moe.grouped_fused_moe_ep import (
fused_experts_impl,
masked_group_gemm,
@@ -20,6 +23,9 @@


class FuseMoeDeepGEMM(FuseMoeTriton):
def _get_ep_num_sms(self) -> int:
return getattr(dist_group_manager, "ep_num_sms", None) or 0

def _select_experts(
self,
input_tensor: torch.Tensor,
@@ -73,14 +79,15 @@ def _fused_experts(
w13_weight, w13_scale = w13.weight, w13.weight_scale
w2_weight, w2_scale = w2.weight, w2.weight_scale
use_fp8_w8a8 = self.quant_method.method_name != "none"
buffer = dist_group_manager.ep_buffer if is_prefill else dist_group_manager.ep_low_latency_buffer
output = fused_experts_impl(
hidden_states=input_tensor,
w1=w13_weight,
w2=w2_weight,
topk_weights=topk_weights,
topk_idx=topk_ids.to(torch.long),
num_experts=self.total_expert_num_contain_redundancy, # total number of experts, including redundant experts
buffer=dist_group_manager.ep_buffer,
buffer=buffer,
is_prefill=is_prefill,
use_fp8_w8a8=use_fp8_w8a8,
use_fp8_all2all=use_fp8_w8a8,
@@ -116,13 +123,13 @@ def low_latency_dispatch(
)

topk_idx = topk_idx.to(torch.long)
num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank()
num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode()
use_fp8_w8a8 = self.quant_method.method_name != "none"
recv_x, masked_m, handle, event, hook = dist_group_manager.ep_buffer.low_latency_dispatch(
hidden_states,
topk_idx,
num_max_dispatch_tokens_per_rank,
self.total_expert_num_contain_redundancy,
recv_x, masked_m, handle, event, hook = dist_group_manager.ep_low_latency_buffer.low_latency_dispatch(
topk_idx=topk_idx,
x=hidden_states,
num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
num_experts=self.total_expert_num_contain_redundancy,
use_fp8=use_fp8_w8a8,
async_finish=False,
return_recv_hook=True,
@@ -169,38 +176,26 @@ def dispatch(
overlap_event: Optional[Any] = None,
):
buffer = dist_group_manager.ep_buffer
# get_dispatch_layout
(
num_tokens_per_rank,
num_tokens_per_rdma_rank,
num_tokens_per_expert,
is_token_in_rank,
previous_event,
) = buffer.get_dispatch_layout(
topk_idx,
self.total_expert_num_contain_redundancy,
previous_event=overlap_event,
async_finish=True,
allocate_on_comm_stream=True,
)
recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch(
num_max_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_prefill()
recv_x, recv_topk_idx, recv_topk_weights, handle, event = buffer.dispatch(
qinput_tensor,
topk_idx=topk_idx,
topk_weights=topk_weights,
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=num_tokens_per_expert,
previous_event=previous_event,
async_finish=True,
allocate_on_comm_stream=True,
num_experts=self.total_expert_num_contain_redundancy,
num_max_tokens_per_rank=num_max_tokens_per_rank,
expert_alignment=128,
num_sms=self._get_ep_num_sms(),
previous_event=overlap_event,
async_with_compute_stream=True,
allocate_on_comm_stream=True,
do_cpu_sync=True,
do_handle_copy=False,
)

def hook():
event.current_stream_wait()

return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, hook
return recv_x, recv_topk_idx, recv_topk_weights, handle.num_recv_tokens_per_expert_list, handle, hook

def masked_group_gemm(
self,
@@ -310,7 +305,7 @@ def low_latency_combine(
topk_weights: torch.Tensor,
handle: Any,
):
combined_x, event_overlap, hook = dist_group_manager.ep_buffer.low_latency_combine(
combined_x, event_overlap, hook = dist_group_manager.ep_low_latency_buffer.low_latency_combine(
gemm_out_b, topk_idx, topk_weights, handle, async_finish=False, return_recv_hook=True
)
return combined_x, hook
@@ -326,8 +321,9 @@ def combine(
gemm_out_b,
handle,
topk_weights=None,
async_finish=True,
num_sms=self._get_ep_num_sms(),
previous_event=overlap_event,
async_with_compute_stream=True,
allocate_on_comm_stream=True,
)

@@ -1,10 +1,7 @@
"""Fused MoE kernel."""
import os
import torch
import triton
import triton.language as tl
from typing import Any, Callable, Dict, Optional, Tuple
import torch.distributed as dist
from lightllm.utils.log_utils import init_logger
from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd
from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul_mix_quant_ep import (
@@ -15,9 +12,11 @@
tma_align_input_scale,
)
from lightllm.common.basemodel.triton_kernel.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank
from lightllm.utils.envs_utils import (
get_deepep_num_max_dispatch_tokens_per_rank_prefill,
get_deepep_num_max_dispatch_tokens_per_rank_decode,
)
from lightllm.common.triton_utils.autotuner import Autotuner
import numpy as np

logger = init_logger(__name__)

@@ -66,14 +65,14 @@ def fused_experts_impl(
topk_weights: torch.Tensor, # [M, topk]
topk_idx: torch.Tensor, # [M, topk]
num_experts: int,
buffer: "Buffer",
buffer: Any,
is_prefill: bool,
use_fp8_w8a8: bool = False,
use_fp8_all2all: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
previous_event: Optional["EventOverlap"] = None,
previous_event: Optional[EventOverlap] = None,
Contributor review comment (severity: high):

Using EventOverlap directly in the type hint will cause a NameError at import time if deep_ep is not installed, as the import is wrapped in a try...except block. Please use a string literal for the type hint to maintain compatibility with environments where deep_ep might be missing.

Suggested change
previous_event: Optional[EventOverlap] = None,
previous_event: Optional["EventOverlap"] = None,

):
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
@@ -99,39 +98,27 @@ def fused_experts_impl(
combined_x = None
if is_prefill:
qinput_tensor, input_scale = per_token_group_quant_fp8(hidden_states, block_size_k, dtype=w1.dtype)

# get_dispatch_layout
(
num_tokens_per_rank,
num_tokens_per_rdma_rank,
num_tokens_per_expert,
is_token_in_rank,
previous_event,
) = buffer.get_dispatch_layout(
topk_idx, num_experts, previous_event=previous_event, async_finish=False, allocate_on_comm_stream=False
)

allocate_on_comm_stream = previous_event is not None
# normal dispatch
# recv_x [receive_num_tokens, hidden], recv_x_scale [receive_num_tokens, hidden // block_size]
# recv_topk_idx [receive_num_tokens, topk_num]
# recv_topk_weights [receive_num_tokens, topk_num]
# num_recv_tokens_per_expert_list: list of length [cur_node_expert_num], padded with expert_alignment=128
recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch(
recv_x, recv_topk_idx, recv_topk_weights, handle, _ = buffer.dispatch(
(qinput_tensor, input_scale),
topk_idx=topk_idx,
topk_weights=topk_weights,
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=num_tokens_per_expert,
previous_event=previous_event,
async_finish=False,
allocate_on_comm_stream=False,
num_experts=num_experts,
num_max_tokens_per_rank=get_deepep_num_max_dispatch_tokens_per_rank_prefill(),
expert_alignment=128,
previous_event=previous_event,
allocate_on_comm_stream=allocate_on_comm_stream,
do_cpu_sync=True,
do_handle_copy=False,
)

# scatter
all_tokens = sum(num_recv_tokens_per_expert_list) # total number of tokens after padding
all_tokens = sum(handle.num_recv_tokens_per_expert_list) # total number of tokens after padding
# gather_out shape [receive_num_tokens, hidden]
gather_out = torch.empty_like(recv_x[0], device=hidden_states.device, dtype=hidden_states.dtype)
if all_tokens > 0:
Expand All @@ -149,7 +136,7 @@ def fused_experts_impl(
output_index = torch.empty_like(recv_topk_idx)

num_recv_tokens_per_expert = torch.tensor(
num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu"
handle.num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu"
).cuda(non_blocking=True)

expert_start_loc = torch.empty_like(num_recv_tokens_per_expert)
@@ -202,13 +189,12 @@ def fused_experts_impl(
gather_out,
handle,
topk_weights=None,
async_finish=False,
previous_event=previous_event,
allocate_on_comm_stream=False,
allocate_on_comm_stream=allocate_on_comm_stream,
)
else:
# low latency dispatch
num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank()
num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode()
expected_m = triton.cdiv(hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1], num_experts)
recv_x, masked_m, handle, event, hook = buffer.low_latency_dispatch(
hidden_states,