98 changes: 72 additions & 26 deletions dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py
@@ -456,6 +456,7 @@ class GraphParams:
workspaces: dict[int, torch.Tensor]
handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
attn_params: dict[int, list[tuple]]
is_mla: bool


_graph_params: Optional[GraphParams] = None
@@ -470,6 +471,7 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
{size: None for size in aclgraph_capture_sizes},
{size: [] for size in aclgraph_capture_sizes},
{size: [] for size in aclgraph_capture_sizes},
False,
)


@@ -490,6 +492,7 @@ def clear_graph_params():
_graph_params.handles[k].clear()
for k in list(_graph_params.events.keys()):
_graph_params.events[k].clear()
_graph_params.is_mla = None

_graph_params.workspaces.clear()
finally:
@@ -503,30 +506,73 @@ def update_attn_params(update_stream, forward_meta, runtime_size):
graph_params.handles[runtime_size],
graph_params.events[runtime_size],
):
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = param
seq_lens = forward_meta.input_buffers["kv_seqlens"]
with torch.npu.stream(update_stream):
torch.npu.graph_task_update_begin(update_stream, handle)
torch.ops.atb._npu_paged_attention(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output,
if graph_params.is_mla:
update_decode_attention_mla_params(
update_stream, forward_meta, param, handle, event
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
else:
update_decode_attention_params(
update_stream, forward_meta, param, handle, event
)


def update_decode_attention_params(update_stream, forward_meta, param, handle, event):
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
kv_seq_len,
output,
) = param
kv_seq_len = forward_meta.input_buffers["kv_seqlens"]
with torch.npu.stream(update_stream):
torch.npu.graph_task_update_begin(update_stream, handle)
torch.ops.atb._npu_paged_attention(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=kv_seq_len,
out=output,
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)


def update_decode_attention_mla_params(
update_stream, forward_meta, param, handle, event
):
(
query,
key_cache,
num_kv_heads,
num_q_heads,
scale_value,
block_table,
kv_seq_len,
mla_vheadsize,
attn_output,
) = param
kv_seq_len = forward_meta.input_buffers["kv_seqlens"]
with torch.npu.stream(update_stream):
torch.npu.graph_task_update_begin(update_stream, handle)
torch.ops.atb._npu_paged_attention_mla(
query=query,
key_cache=key_cache,
num_kv_heads=num_kv_heads,
num_heads=num_q_heads,
scale_value=scale_value,
block_table=block_table,
context_lens=kv_seq_len,
mla_vheadsize=mla_vheadsize,
out=attn_output,
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
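
Note on the replay path added above: between decode steps only the context lengths change, so both update helpers reuse every tensor captured at graph-capture time and just swap in the live `kv_seqlens` buffer from `forward_meta` before re-recording the kernel inside a `graph_task_update_begin`/`end` window. The MLA variant differs from the dense one only in its captured tuple: it has no separate value cache and carries an extra `mla_vheadsize`, matching `_npu_paged_attention_mla`. The sketch below is a plain-Python restatement of that dispatch, with the NPU stream/task calls stubbed out; `GraphParamsSketch` and the callables are illustrative names, not part of this change.

```python
# Minimal sketch (not the NPU code path): walks the per-size captured state the
# same way update_attn_params does and dispatches on is_mla. The real code runs
# on an NPU update stream and re-records torch.ops.atb._npu_paged_attention or
# _npu_paged_attention_mla for each captured handle.
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


@dataclass
class GraphParamsSketch:
    attn_params: Dict[int, List[tuple]] = field(default_factory=dict)
    handles: Dict[int, List[Any]] = field(default_factory=dict)
    events: Dict[int, List[Any]] = field(default_factory=dict)
    is_mla: bool = False


def update_attn_params_sketch(
    gp: GraphParamsSketch,
    runtime_size: int,
    live_kv_seqlens: Any,
    update_mla: Callable[[tuple, Any, Any, Any], None],
    update_dense: Callable[[tuple, Any, Any, Any], None],
) -> None:
    # One (param, handle, event) triple was captured per attention call of this size.
    for param, handle, event in zip(
        gp.attn_params[runtime_size],
        gp.handles[runtime_size],
        gp.events[runtime_size],
    ):
        if gp.is_mla:
            update_mla(param, handle, event, live_kv_seqlens)
        else:
            update_dense(param, handle, event, live_kv_seqlens)
```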
25 changes: 25 additions & 0 deletions dlinfer/framework/lmdeploy_ext/device/ascend.py
@@ -37,6 +37,10 @@
)
from lmdeploy.utils import get_logger

# dp / ep (data parallel / expert parallel) support
from lmdeploy.pytorch.backends.dlinfer.ascend import AscendOpsBackend
from dlinfer.utils.type_annotation import DlinferDistContext


def rl_update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor):
"""Update weights."""
@@ -980,3 +984,24 @@ def get_assignment_batch(
cache_engine.CacheEngine = AscendCacheEngine
executor_base.CacheEngine = AscendCacheEngine
model_agent.CacheEngine = AscendCacheEngine


def get_max_tokens_accros_dp():
return AscendOpsBackend.max_tokens_accros_dp


def get_pad_size(dist_ctx: DlinferDistContext, actual_size: int):
@functools.lru_cache(maxsize=1024)
def inner(max_tokens_accros_dp: int, ep_size: int, tp_size: int, actual_size: int):
if ep_size > 1:
paded_size = (max_tokens_accros_dp + tp_size - 1) // tp_size * tp_size
pad_size = paded_size - actual_size
return pad_size
return 0

return inner(
get_max_tokens_accros_dp(),
dist_ctx.ep_size,
dist_ctx.tp_size,
actual_size,
)
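
The padding rule in `get_pad_size` only applies when expert parallelism is enabled: each data-parallel rank rounds the largest token count seen across DP ranks up to a multiple of `tp_size` and pads its own batch to that size, which (presumably) keeps every rank's shape identical when entering the expert-parallel kernels. A standalone sketch of the same arithmetic, with illustrative numbers:

```python
# Standalone restatement of the padding rule above (function name is local to
# this example). When ep_size > 1, pad up to the smallest multiple of tp_size
# that covers the largest token count across DP ranks; otherwise no padding.
def pad_size_for(max_tokens_across_dp: int, ep_size: int, tp_size: int, actual_size: int) -> int:
    if ep_size <= 1:
        return 0
    padded = (max_tokens_across_dp + tp_size - 1) // tp_size * tp_size  # ceil to a tp_size multiple
    return padded - actual_size


# Two DP ranks hold 5 and 7 tokens, tp_size=4, ep_size=2 -> both pad to 8 tokens.
assert pad_size_for(7, ep_size=2, tp_size=4, actual_size=5) == 3
assert pad_size_for(7, ep_size=2, tp_size=4, actual_size=7) == 1
assert pad_size_for(7, ep_size=1, tp_size=4, actual_size=5) == 0  # no EP -> no padding
```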
32 changes: 22 additions & 10 deletions dlinfer/ops/llm.py
@@ -1,9 +1,17 @@
# Copyright (c) 2024, DeepLink. All rights reserved.
import torch
from typing import List
from dlinfer.vendor import vendor_ops_registry
from dlinfer.utils.type_annotation import Tensor, Optional, Sequence, Tuple
from dlinfer.utils.type_annotation import (
Tensor,
Optional,
Sequence,
Tuple,
DlinferDistContext,
linear_w8a8_scale_type,
dynamic_quant_scale_type,
)
from dlinfer.graph.custom_op import register_custom_op
from dlinfer.vendor import linear_w8a8_scale_type, dynamic_quant_scale_type


__all__ = [
@@ -465,7 +473,7 @@ def silu_and_mul(


def moe_gating_topk_softmax_impl_abstract_func(
router_logits: Tensor, topk: int
router_logits: Tensor, topk: int, dist_ctx: DlinferDistContext = None
) -> Tuple[Tensor, Tensor]:
routing_weights = router_logits.new_empty((*router_logits.shape[:-1], topk))
selected_experts = router_logits.new_empty(
@@ -474,11 +482,13 @@ def moe_gating_topk_softmax_impl_abstract_func(
return routing_weights, selected_experts


@register_custom_op(
"dlinfer::moe_gating_topk_softmax",
impl_abstract_func=moe_gating_topk_softmax_impl_abstract_func,
)
def moe_gating_topk_softmax(router_logits: Tensor, topk: int) -> Tuple[Tensor, Tensor]:
# @register_custom_op(
# "dlinfer::moe_gating_topk_softmax",
# impl_abstract_func=moe_gating_topk_softmax_impl_abstract_func,
# )
def moe_gating_topk_softmax(
router_logits: Tensor, topk: int, dist_ctx: DlinferDistContext
) -> Tuple[Tensor, Tensor]:
"""
Given the router_logits of the experts, this computes the softmax probability distribution
over the experts and then selects the topk values and their corresponding indices.
@@ -492,7 +502,7 @@ def moe_gating_topk_softmax(router_logits: Tensor, topk: int) -> Tuple[Tensor, T
- The router weight of selected experts.
- The index of selected experts.
"""
return vendor_ops_registry["moe_gating_topk_softmax"](router_logits, topk)
return vendor_ops_registry["moe_gating_topk_softmax"](router_logits, topk, dist_ctx)


# TODO only for internlm on transformers lib.
@@ -585,7 +595,7 @@ def weight_quant_matmul(
)


@register_custom_op("dlinfer::fused_moe", ["hidden_states"])
# @register_custom_op("dlinfer::fused_moe", ["hidden_states"])
def fused_moe(
hidden_states: Tensor,
gate_up_weights: Tensor,
@@ -594,6 +604,7 @@ def fused_moe(
topk_ids: Tensor,
topk: int,
renormalize: bool,
dist_ctx: DlinferDistContext,
) -> Tensor:
"""
Implement the Fused Mixture of Experts (MoE) model.
@@ -619,6 +630,7 @@
topk_ids,
topk,
renormalize,
dist_ctx,
)


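For reference, a hypothetical call sketch for the new `dist_ctx`-aware signature of `moe_gating_topk_softmax`. The output shapes follow `moe_gating_topk_softmax_impl_abstract_func` above; the import paths and tensor sizes are assumptions, and actually executing the op requires an Ascend build that has populated `vendor_ops_registry`.

```python
# Hypothetical usage sketch (shapes and values illustrative; needs an Ascend
# build that registers the "moe_gating_topk_softmax" vendor op to really run).
import torch

from dlinfer.ops.llm import moe_gating_topk_softmax
from dlinfer.utils.type_annotation import DlinferDistContext

dist_ctx = DlinferDistContext(dp_size=1, tp_size=1, ep_size=1)  # single-rank defaults

router_logits = torch.randn(16, 64)  # (num_tokens, num_experts)
topk = 8

routing_weights, selected_experts = moe_gating_topk_softmax(router_logits, topk, dist_ctx)
# routing_weights:  (num_tokens, topk) router weights of the selected experts
# selected_experts: (num_tokens, topk) indices of the selected experts
```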
22 changes: 22 additions & 0 deletions dlinfer/utils/type_annotation.py
@@ -1,3 +1,25 @@
# Copyright (c) 2024, DeepLink. All rights reserved.
import torch
from torch import Tensor
from dataclasses import dataclass
from typing import Optional, Sequence, Union, Any, Tuple, Callable, Dict


@dataclass
class DlinferDistContext:
dp_size: int = 1
tp_size: int = 1
ep_size: int = 1

dp_rank: int = 0
tp_rank: int = 0
ep_rank: int = 0

max_tokens_accros_dp: int = 1

tp_group: torch.distributed.ProcessGroup = None
ep_group: torch.distributed.ProcessGroup = None


linear_w8a8_scale_type = torch.Tensor
dynamic_quant_scale_type = torch.Tensor
2 changes: 0 additions & 2 deletions dlinfer/vendor/__init__.py
@@ -9,8 +9,6 @@
vendor_ops_registry = dict()
vendor_is_initialized = False
vendor_name_file = Path(__file__).parent / "vendor.yaml"
linear_w8a8_scale_type = torch.Tensor
dynamic_quant_scale_type = torch.Tensor


with open(str(vendor_name_file), "r") as f: